From 3c71c4aa82fd7b6f6e8119ef7898420b9240adf7 Mon Sep 17 00:00:00 2001 From: momania Date: Wed, 24 Nov 2010 13:35:32 +0100 Subject: [PATCH] - fix race condition with timeout - improved stopping mechanism - renamed 'until' to 'forMax' for less confusion - ability to specif timeunit for timeout --- .../src/main/scala/akka/actor/FSM.scala | 91 ++++++++++++------- .../scala/akka/actor/actor/FSMActorSpec.scala | 13 ++- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/akka-actor/src/main/scala/akka/actor/FSM.scala b/akka-actor/src/main/scala/akka/actor/FSM.scala index df88db5ade..c801309e19 100644 --- a/akka-actor/src/main/scala/akka/actor/FSM.scala +++ b/akka-actor/src/main/scala/akka/actor/FSM.scala @@ -21,8 +21,10 @@ trait FSM[S, D] { register(stateName, stateFunction) } - protected final def startWith(stateName: S, stateData: D, timeout: Option[Long] = None) = { - setState(State(stateName, stateData, timeout)) + protected final def startWith(stateName: S, + stateData: D, + timeout: Option[(Long, TimeUnit)] = None) = { + applyState(State(stateName, stateData, timeout)) } protected final def goto(nextStateName: S): State = { @@ -42,8 +44,7 @@ trait FSM[S, D] { } protected final def stop(reason: Reason, stateData: D): State = { - self ! Stop(reason, stateData) - stay + stay using stateData withStopReason(reason) } def whenUnhandled(stateFunction: StateFunction) = { @@ -57,13 +58,15 @@ trait FSM[S, D] { /** FSM State data and default handlers */ private var currentState: State = _ private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None + private var generation: Long = 0L - private val transitions = mutable.Map[S, StateFunction]() + + private val stateFunctions = mutable.Map[S, StateFunction]() private def register(name: S, function: StateFunction) { - if (transitions contains name) { - transitions(name) = transitions(name) orElse function + if (stateFunctions contains name) { + stateFunctions(name) = stateFunctions(name) orElse function } else { - transitions(name) = function + stateFunctions(name) = function } } @@ -83,40 +86,55 @@ trait FSM[S, D] { } override final protected def receive: Receive = { - case Stop(reason, stateData) => - terminateEvent.apply(reason) - self.stop - case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) => - log.trace("Ignoring StateTimeout - ") - // state timeout when new message in queue, skip this timeout + case TimeoutMarker(gen) => + if (generation == gen) { + processEvent(StateTimeout) + } case value => { timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None} - val event = Event(value, currentState.stateData) - val nextState = (transitions(currentState.stateName) orElse handleEvent).apply(event) - setState(nextState) + generation += 1 + processEvent(value) } } - private def setState(nextState: State) = { - if (!transitions.contains(nextState.stateName)) { - stop(Failure("Next state %s does not exist".format(nextState.stateName))) - } else { - if (currentState != null && currentState.stateName != nextState.stateName) { - transitionEvent.apply(Transition(currentState.stateName, nextState.stateName)) - } - currentState = nextState - currentState.timeout.foreach { - t => - timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS)) - } + private def processEvent(value: Any) = { + val event = Event(value, currentState.stateData) + val nextState = (stateFunctions(currentState.stateName) orElse handleEvent).apply(event) + nextState.stopReason match { + case Some(reason) => terminate(reason) + case None => makeTransition(nextState) } } + private def makeTransition(nextState: State) = { + if (!stateFunctions.contains(nextState.stateName)) { + terminate(Failure("Next state %s does not exist".format(nextState.stateName))) + } else { + if (currentState.stateName != nextState.stateName) { + transitionEvent.apply(Transition(currentState.stateName, nextState.stateName)) + } + applyState(nextState) + } + } + + private def applyState(nextState: State) = { + currentState = nextState + currentState.timeout.foreach { t => + timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t._1, t._2)) + } + } + + private def terminate(reason: Reason) = { + terminateEvent.apply(reason) + self.stop + } + + case class Event(event: Any, stateData: D) - case class State(stateName: S, stateData: D, timeout: Option[Long] = None) { + case class State(stateName: S, stateData: D, timeout: Option[(Long, TimeUnit)] = None) { - def until(timeout: Long): State = { + def forMax(timeout: (Long, TimeUnit)): State = { copy(timeout = Some(timeout)) } @@ -131,6 +149,12 @@ trait FSM[S, D] { def using(nextStateDate: D): State = { copy(stateData = nextStateDate) } + + private[akka] var stopReason: Option[Reason] = None + private[akka] def withStopReason(reason: Reason): State = { + stopReason = Some(reason) + this + } } sealed trait Reason @@ -139,8 +163,7 @@ trait FSM[S, D] { case class Failure(cause: Any) extends Reason case object StateTimeout + case class TimeoutMarker(generation: Long) case class Transition(from: S, to: S) - - private case class Stop(reason: Reason, stateData: D) -} +} \ No newline at end of file diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala index 146c57759c..dd8de44bea 100644 --- a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala @@ -23,7 +23,7 @@ object FSMActorSpec { case object Locked extends LockState case object Open extends LockState - class Lock(code: String, timeout: Int) extends Actor with FSM[LockState, CodeState] { + class Lock(code: String, timeout: (Long, TimeUnit)) extends Actor with FSM[LockState, CodeState] { notifying { case Transition(Locked, Open) => transitionLatch.open @@ -37,7 +37,7 @@ object FSMActorSpec { stay using CodeState(incomplete, code) case codeTry if (codeTry == code) => { doUnlock - goto(Open) using CodeState("", code) until timeout + goto(Open) using CodeState("", code) forMax timeout } case wrong => { log.error("Wrong code %s", wrong) @@ -46,11 +46,11 @@ object FSMActorSpec { } } case Event("hello", _) => stay replying "world" - case Event("bye", _) => stop(Shutdown) + case Event("bye", _) => stop } when(Open) { - case Event(StateTimeout, stateData) => { + case Event(StateTimeout, _) => { doLock goto(Locked) } @@ -91,7 +91,7 @@ class FSMActorSpec extends JUnitSuite { def unlockTheLock = { // lock that locked after being open for 1 sec - val lock = Actor.actorOf(new Lock("33221", 1000)).start + val lock = Actor.actorOf(new Lock("33221", (1, TimeUnit.SECONDS))).start lock ! '3' lock ! '3' @@ -122,5 +122,4 @@ class FSMActorSpec extends JUnitSuite { tester ! Bye assert(terminatedLatch.tryAwait(2, TimeUnit.SECONDS)) } -} - +} \ No newline at end of file