diff --git a/akka-actor/src/main/scala/actor/FSM.scala b/akka-actor/src/main/scala/actor/FSM.scala index c5eb00a6fd..eac861d358 100644 --- a/akka-actor/src/main/scala/actor/FSM.scala +++ b/akka-actor/src/main/scala/actor/FSM.scala @@ -12,27 +12,15 @@ trait FSM[S, D] { type StateFunction = scala.PartialFunction[Event, State] - private var currentState: State = _ - private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None - - private val transitions = mutable.Map[S, StateFunction]() - - private def register(name: S, function: StateFunction) { - if (transitions contains name) { - transitions(name) = transitions(name) orElse function - } else { - transitions(name) = function - } + /** DSL */ + protected final def inState(stateName: S)(stateFunction: StateFunction) = { + register(stateName, stateFunction) } protected final def setInitialState(stateName: S, stateData: D, timeout: Option[Long] = None) = { setState(State(stateName, stateData, timeout)) } - protected final def inState(stateName: S)(stateFunction: StateFunction) = { - register(stateName, stateFunction) - } - protected final def goto(nextStateName: S): State = { State(nextStateName, currentState.stateData) } @@ -41,14 +29,6 @@ trait FSM[S, D] { goto(currentState.stateName) } - protected final def reply(replyValue: Any): State = { - self.sender.foreach(_ ! replyValue) - stay() - } - - /** - * Stop - */ protected final def stop(): State = { stop(Normal) } @@ -58,59 +38,82 @@ trait FSM[S, D] { } protected final def stop(reason: Reason, stateData: D): State = { - log.info("Stopped because of reason: %s", reason) - terminate(reason, currentState.stateName, stateData) - self.stop - State(currentState.stateName, stateData) + self ! Stop(reason, stateData) + stay } - def terminate(reason: Reason, stateName: S, stateData: D) = () - def whenUnhandled(stateFunction: StateFunction) = { handleEvent = stateFunction } + def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = { + terminateEvent = terminationHandler + } + + /** FSM State data and default handlers */ + private var currentState: State = _ + private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None + + private val transitions = mutable.Map[S, StateFunction]() + private def register(name: S, function: StateFunction) { + if (transitions contains name) { + transitions(name) = transitions(name) orElse function + } else { + transitions(name) = function + } + } + private var handleEvent: StateFunction = { case Event(value, stateData) => - log.warning("Event %s not handled in state %s - keeping current state with data %s", value, currentState.stateName, stateData) - currentState + log.warning("Event %s not handled in state %s, staying at current state", value, currentState.stateName) + stay + } + + private var terminateEvent: PartialFunction[Reason, Unit] = { + case failure@Failure(_) => log.error("Stopping because of a %s", failure) + case reason => log.info("Stopping because of reason: %s", reason) } override final protected def receive: Receive = { - case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) => () + case Stop(reason, stateData) => + terminateEvent.apply(reason) + self.stop + case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) => + log.trace("Ignoring StateTimeout - ") // state timeout when new message in queue, skip this timeout case value => { timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None} val event = Event(value, currentState.stateData) val nextState = (transitions(currentState.stateName) orElse handleEvent).apply(event) - if (self.isRunning) { - setState(nextState) - } + setState(nextState) } } private def setState(nextState: State) = { if (!transitions.contains(nextState.stateName)) { - stop(Failure("Next state %s not available".format(nextState.stateName))) + stop(Failure("Next state %s does not exist".format(nextState.stateName))) } else { currentState = nextState - currentState.timeout.foreach {t => timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS))} + currentState.timeout.foreach { + t => + timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS)) + } } } case class Event(event: Any, stateData: D) case class State(stateName: S, stateData: D, timeout: Option[Long] = None) { + def until(timeout: Long): State = { copy(timeout = Some(timeout)) } - def then(nextStateName: S): State = { - copy(stateName = nextStateName) - } - def replying(replyValue:Any): State = { - self.sender.foreach(_ ! replyValue) + self.sender match { + case Some(sender) => sender ! replyValue + case None => log.error("Unable to send reply value %s, no sender reference to reply to", replyValue) + } this } @@ -125,4 +128,6 @@ trait FSM[S, D] { case class Failure(cause: Any) extends Reason case object StateTimeout + + private case class Stop(reason: Reason, stateData: D) } diff --git a/akka-actor/src/test/scala/actor/actor/FSMActorSpec.scala b/akka-actor/src/test/scala/actor/actor/FSMActorSpec.scala index 8646dd5561..dc6893c820 100644 --- a/akka-actor/src/test/scala/actor/actor/FSMActorSpec.scala +++ b/akka-actor/src/test/scala/actor/actor/FSMActorSpec.scala @@ -16,6 +16,7 @@ object FSMActorSpec { val unlockedLatch = new StandardLatch val lockedLatch = new StandardLatch val unhandledLatch = new StandardLatch + val terminatedLatch = new StandardLatch sealed trait LockState case object Locked extends LockState @@ -39,6 +40,7 @@ object FSMActorSpec { } } case Event("hello", _) => stay replying "world" + case Event("bye", _) => stop(Shutdown) } inState(Open) { @@ -58,6 +60,10 @@ object FSMActorSpec { } } + onTermination { + case reason => terminatedLatch.open + } + private def doLock() { log.info("Locked") lockedLatch.open @@ -94,16 +100,20 @@ class FSMActorSpec extends JUnitSuite { assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS)) val answerLatch = new StandardLatch - object Go + object Hello + object Bye val tester = Actor.actorOf(new Actor { protected def receive = { - case Go => lock ! "hello" + case Hello => lock ! "hello" case "world" => answerLatch.open - + case Bye => lock ! "bye" } }).start - tester ! Go + tester ! Hello assert(answerLatch.tryAwait(2, TimeUnit.SECONDS)) + + tester ! Bye + assert(terminatedLatch.tryAwait(2, TimeUnit.SECONDS)) } }