polishing up code
This commit is contained in:
parent
efb99fc3a7
commit
847e0c7b04
2 changed files with 61 additions and 46 deletions
|
|
@ -12,27 +12,15 @@ trait FSM[S, D] {
|
|||
|
||||
type StateFunction = scala.PartialFunction[Event, State]
|
||||
|
||||
private var currentState: State = _
|
||||
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
|
||||
|
||||
private val transitions = mutable.Map[S, StateFunction]()
|
||||
|
||||
private def register(name: S, function: StateFunction) {
|
||||
if (transitions contains name) {
|
||||
transitions(name) = transitions(name) orElse function
|
||||
} else {
|
||||
transitions(name) = function
|
||||
}
|
||||
/** DSL */
|
||||
protected final def inState(stateName: S)(stateFunction: StateFunction) = {
|
||||
register(stateName, stateFunction)
|
||||
}
|
||||
|
||||
protected final def setInitialState(stateName: S, stateData: D, timeout: Option[Long] = None) = {
|
||||
setState(State(stateName, stateData, timeout))
|
||||
}
|
||||
|
||||
protected final def inState(stateName: S)(stateFunction: StateFunction) = {
|
||||
register(stateName, stateFunction)
|
||||
}
|
||||
|
||||
protected final def goto(nextStateName: S): State = {
|
||||
State(nextStateName, currentState.stateData)
|
||||
}
|
||||
|
|
@ -41,14 +29,6 @@ trait FSM[S, D] {
|
|||
goto(currentState.stateName)
|
||||
}
|
||||
|
||||
protected final def reply(replyValue: Any): State = {
|
||||
self.sender.foreach(_ ! replyValue)
|
||||
stay()
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop
|
||||
*/
|
||||
protected final def stop(): State = {
|
||||
stop(Normal)
|
||||
}
|
||||
|
|
@ -58,59 +38,82 @@ trait FSM[S, D] {
|
|||
}
|
||||
|
||||
protected final def stop(reason: Reason, stateData: D): State = {
|
||||
log.info("Stopped because of reason: %s", reason)
|
||||
terminate(reason, currentState.stateName, stateData)
|
||||
self.stop
|
||||
State(currentState.stateName, stateData)
|
||||
self ! Stop(reason, stateData)
|
||||
stay
|
||||
}
|
||||
|
||||
def terminate(reason: Reason, stateName: S, stateData: D) = ()
|
||||
|
||||
def whenUnhandled(stateFunction: StateFunction) = {
|
||||
handleEvent = stateFunction
|
||||
}
|
||||
|
||||
def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = {
|
||||
terminateEvent = terminationHandler
|
||||
}
|
||||
|
||||
/** FSM State data and default handlers */
|
||||
private var currentState: State = _
|
||||
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
|
||||
|
||||
private val transitions = mutable.Map[S, StateFunction]()
|
||||
private def register(name: S, function: StateFunction) {
|
||||
if (transitions contains name) {
|
||||
transitions(name) = transitions(name) orElse function
|
||||
} else {
|
||||
transitions(name) = function
|
||||
}
|
||||
}
|
||||
|
||||
private var handleEvent: StateFunction = {
|
||||
case Event(value, stateData) =>
|
||||
log.warning("Event %s not handled in state %s - keeping current state with data %s", value, currentState.stateName, stateData)
|
||||
currentState
|
||||
log.warning("Event %s not handled in state %s, staying at current state", value, currentState.stateName)
|
||||
stay
|
||||
}
|
||||
|
||||
private var terminateEvent: PartialFunction[Reason, Unit] = {
|
||||
case failure@Failure(_) => log.error("Stopping because of a %s", failure)
|
||||
case reason => log.info("Stopping because of reason: %s", reason)
|
||||
}
|
||||
|
||||
override final protected def receive: Receive = {
|
||||
case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) => ()
|
||||
case Stop(reason, stateData) =>
|
||||
terminateEvent.apply(reason)
|
||||
self.stop
|
||||
case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) =>
|
||||
log.trace("Ignoring StateTimeout - ")
|
||||
// state timeout when new message in queue, skip this timeout
|
||||
case value => {
|
||||
timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None}
|
||||
val event = Event(value, currentState.stateData)
|
||||
val nextState = (transitions(currentState.stateName) orElse handleEvent).apply(event)
|
||||
if (self.isRunning) {
|
||||
setState(nextState)
|
||||
}
|
||||
setState(nextState)
|
||||
}
|
||||
}
|
||||
|
||||
private def setState(nextState: State) = {
|
||||
if (!transitions.contains(nextState.stateName)) {
|
||||
stop(Failure("Next state %s not available".format(nextState.stateName)))
|
||||
stop(Failure("Next state %s does not exist".format(nextState.stateName)))
|
||||
} else {
|
||||
currentState = nextState
|
||||
currentState.timeout.foreach {t => timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS))}
|
||||
currentState.timeout.foreach {
|
||||
t =>
|
||||
timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class Event(event: Any, stateData: D)
|
||||
|
||||
case class State(stateName: S, stateData: D, timeout: Option[Long] = None) {
|
||||
|
||||
def until(timeout: Long): State = {
|
||||
copy(timeout = Some(timeout))
|
||||
}
|
||||
|
||||
def then(nextStateName: S): State = {
|
||||
copy(stateName = nextStateName)
|
||||
}
|
||||
|
||||
def replying(replyValue:Any): State = {
|
||||
self.sender.foreach(_ ! replyValue)
|
||||
self.sender match {
|
||||
case Some(sender) => sender ! replyValue
|
||||
case None => log.error("Unable to send reply value %s, no sender reference to reply to", replyValue)
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
|
|
@ -125,4 +128,6 @@ trait FSM[S, D] {
|
|||
case class Failure(cause: Any) extends Reason
|
||||
|
||||
case object StateTimeout
|
||||
|
||||
private case class Stop(reason: Reason, stateData: D)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ object FSMActorSpec {
|
|||
val unlockedLatch = new StandardLatch
|
||||
val lockedLatch = new StandardLatch
|
||||
val unhandledLatch = new StandardLatch
|
||||
val terminatedLatch = new StandardLatch
|
||||
|
||||
sealed trait LockState
|
||||
case object Locked extends LockState
|
||||
|
|
@ -39,6 +40,7 @@ object FSMActorSpec {
|
|||
}
|
||||
}
|
||||
case Event("hello", _) => stay replying "world"
|
||||
case Event("bye", _) => stop(Shutdown)
|
||||
}
|
||||
|
||||
inState(Open) {
|
||||
|
|
@ -58,6 +60,10 @@ object FSMActorSpec {
|
|||
}
|
||||
}
|
||||
|
||||
onTermination {
|
||||
case reason => terminatedLatch.open
|
||||
}
|
||||
|
||||
private def doLock() {
|
||||
log.info("Locked")
|
||||
lockedLatch.open
|
||||
|
|
@ -94,16 +100,20 @@ class FSMActorSpec extends JUnitSuite {
|
|||
assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS))
|
||||
|
||||
val answerLatch = new StandardLatch
|
||||
object Go
|
||||
object Hello
|
||||
object Bye
|
||||
val tester = Actor.actorOf(new Actor {
|
||||
protected def receive = {
|
||||
case Go => lock ! "hello"
|
||||
case Hello => lock ! "hello"
|
||||
case "world" => answerLatch.open
|
||||
|
||||
case Bye => lock ! "bye"
|
||||
}
|
||||
}).start
|
||||
tester ! Go
|
||||
tester ! Hello
|
||||
assert(answerLatch.tryAwait(2, TimeUnit.SECONDS))
|
||||
|
||||
tester ! Bye
|
||||
assert(terminatedLatch.tryAwait(2, TimeUnit.SECONDS))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue