diff --git a/akka-actor/src/main/scala/akka/actor/FSM.scala b/akka-actor/src/main/scala/akka/actor/FSM.scala index 683b86b09e..0a8f3da77b 100755 --- a/akka-actor/src/main/scala/akka/actor/FSM.scala +++ b/akka-actor/src/main/scala/akka/actor/FSM.scala @@ -6,15 +6,10 @@ package akka.actor import akka.util._ import scala.collection.mutable -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.ScheduledFuture object FSM { - case class Event[D](event: Any, stateData: D) - object Ev { - def unapply[D](e : Event[D]) : Option[Any] = Some(e.event) - } - case class Transition[S](from: S, to: S) case class SubscribeTransitionCallBack(actorRef: ActorRef) case class UnsubscribeTransitionCallBack(actorRef: ActorRef) @@ -23,7 +18,6 @@ object FSM { case object Normal extends Reason case object Shutdown extends Reason case class Failure(cause: Any) extends Reason - case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D) case object StateTimeout case class TimeoutMarker(generation: Long) @@ -124,6 +118,7 @@ trait FSM[S, D] { type StateFunction = scala.PartialFunction[Event[D], State] type Timeout = Option[Duration] + type TransitionHandler = (S, S) => Unit /* DSL */ @@ -245,7 +240,7 @@ trait FSM[S, D] { * Set handler which is called upon each state transition, i.e. not when * staying in the same state. */ - protected final def onTransition(transitionHandler: PartialFunction[Transition[S], Unit]) = { + protected final def onTransition(transitionHandler: TransitionHandler) = { transitionEvent = transitionHandler } @@ -293,12 +288,12 @@ trait FSM[S, D] { } } - private var handleEvent: StateFunction = handleEventDefault private val handleEventDefault: StateFunction = { case Event(value, stateData) => log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName) stay } + private var handleEvent: StateFunction = handleEventDefault private var terminateEvent: PartialFunction[StopEvent[S,D], Unit] = { case StopEvent(Failure(cause), _, _) => @@ -306,8 +301,8 @@ trait FSM[S, D] { case StopEvent(reason, _, _) => log.slf4j.info("Stopping because of reason: {}", reason) } - private var transitionEvent: PartialFunction[Transition[S], Unit] = { - case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to) + private var transitionEvent: TransitionHandler = (from, to) => { + log.slf4j.debug("Transitioning from state {} to {}", from, to) } override final protected def receive: Receive = { @@ -358,9 +353,11 @@ trait FSM[S, D] { terminate(Failure("Next state %s does not exist".format(nextState.stateName))) } else { if (currentState.stateName != nextState.stateName) { - val transition = Transition(currentState.stateName, nextState.stateName) - transitionEvent.apply(transition) - transitionCallBackList.foreach(_ ! transition) + transitionEvent.apply(currentState.stateName, nextState.stateName) + if (!transitionCallBackList.isEmpty) { + val transition = Transition(currentState.stateName, nextState.stateName) + transitionCallBackList.foreach(_ ! transition) + } } applyState(nextState) } @@ -378,10 +375,15 @@ trait FSM[S, D] { } private def terminate(reason: Reason) = { + timers.foreach{ case (timer, t) => log.slf4j.info("Canceling timer {}", timer); t.cancel} terminateEvent.apply(StopEvent(reason, currentState.stateName, currentState.stateData)) self.stop } + case class Event[D](event: Any, stateData: D) + object Ev { + def unapply[D](e : Event[D]) : Option[Any] = Some(e.event) + } case class State(stateName: S, stateData: D, timeout: Timeout = None) { @@ -423,4 +425,5 @@ trait FSM[S, D] { } } + case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D) } diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala index 397517e05e..f9dc185597 100644 --- a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala @@ -67,10 +67,10 @@ object FSMActorSpec { } } - onTransition { + onTransition((oldState, newState) => Transition(oldState, newState) match { case Transition(Locked, Open) => transitionLatch.open case Transition(_, _) => () - } + }) onTermination { case StopEvent(Shutdown, Locked, _) => diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala index b0521bbc1b..c0ed57c47c 100644 --- a/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala @@ -55,11 +55,21 @@ class FSMTimingSpec } "notify unhandled messages" in { - within(200 millis) { - fsm ! Cancel - expectMsg(Unhandled(Cancel)) + fsm ! TestUnhandled + within(100 millis) { + fsm ! Tick expectNoMsg } + within(100 millis) { + fsm ! SetHandler + fsm ! Tick + expectMsg(Unhandled(Tick)) + expectNoMsg + } + within(100 millis) { + fsm ! Cancel + expectMsg(Transition(TestUnhandled, Initial)) + } } } @@ -73,9 +83,11 @@ object FSMTimingSpec { case object TestStateTimeout extends State case object TestSingleTimer extends State case object TestRepeatedTimer extends State + case object TestUnhandled extends State case object Tick case object Cancel + case object SetHandler case class Unhandled(msg : AnyRef) @@ -84,13 +96,13 @@ object FSMTimingSpec { startWith(Initial, ()) when(Initial) { - case Ev(TestStateTimeout) => goto(TestStateTimeout) case Ev(TestSingleTimer) => setTimer("tester", Tick, 100 millis, false) goto(TestSingleTimer) case Ev(TestRepeatedTimer) => setTimer("tester", Tick, 100 millis, true) goto(TestRepeatedTimer) + case Ev(x : FSMTimingSpec.State) => goto(x) } when(TestStateTimeout, stateTimeout = 100 millis) { case Ev(StateTimeout) => goto(Initial) @@ -108,14 +120,19 @@ object FSMTimingSpec { cancelTimer("tester") goto(Initial) } - - whenUnhandled { - case Ev(msg : AnyRef) => - tester ! Unhandled(msg) + when(TestUnhandled) { + case Ev(SetHandler) => + whenUnhandled { + case Ev(msg : AnyRef) => + tester ! Unhandled(msg) + stay + } stay + case Ev(Cancel) => + goto(Initial) } } } -// vim: set ts=4 sw=4 et: +// vim: set ts=2 sw=2 et: