diff --git a/akka-actor/src/main/scala/akka/actor/FSM.scala b/akka-actor/src/main/scala/akka/actor/FSM.scala index ae59c641bb..373d372677 100755 --- a/akka-actor/src/main/scala/akka/actor/FSM.scala +++ b/akka-actor/src/main/scala/akka/actor/FSM.scala @@ -9,50 +9,57 @@ import scala.collection.mutable import java.util.concurrent.{ScheduledFuture, TimeUnit} object FSM { + sealed trait Reason case object Normal extends Reason case object Shutdown extends Reason case class Failure(cause: Any) extends Reason + case class Event[D](event: Any, stateData: D) + + case class Transition[S](from: S, to: S) + case class SubscribeTransitionCallBack(actorRef: ActorRef) + case class UnsubscribeTransitionCallBack(actorRef: ActorRef) + case object StateTimeout case class TimeoutMarker(generation: Long) - case class Timer(name : String, msg : AnyRef, repeat : Boolean) { - private var ref : Option[ScheduledFuture[AnyRef]] = _ - def schedule(actor : ActorRef, timeout : Duration) { - if (repeat) { - ref = Some(Scheduler.schedule(actor, this, timeout.length, timeout.length, timeout.unit)) - } else { - ref = Some(Scheduler.scheduleOnce(actor, this, timeout.length, timeout.unit)) - } - } - def cancel { - ref = ref flatMap {t => t.cancel(true); None} - } + + case class Timer(name: String, msg: AnyRef, repeat: Boolean) { + private var ref: Option[ScheduledFuture[AnyRef]] = _ + + def schedule(actor: ActorRef, timeout: Duration) { + if (repeat) { + ref = Some(Scheduler.schedule(actor, this, timeout.length, timeout.length, timeout.unit)) + } else { + ref = Some(Scheduler.scheduleOnce(actor, this, timeout.length, timeout.unit)) + } + } + + def cancel { + ref = ref flatMap { + t => t.cancel(true); None + } + } } - + /* - * With these implicits in scope, you can write "5 seconds" anywhere a - * Duration or Option[Duration] is expected. This is conveniently true - * for derived classes. - */ - implicit def d2od(d : Duration) : Option[Duration] = Some(d) - implicit def p2od(p : (Long, TimeUnit)) : Duration = new Duration(p._1, p._2) - implicit def i2d(i : Int) : DurationInt = new DurationInt(i) - implicit def l2d(l : Long) : DurationLong = new DurationLong(l) + * With these implicits in scope, you can write "5 seconds" anywhere a + * Duration or Option[Duration] is expected. This is conveniently true + * for derived classes. + */ + implicit def d2od(d: Duration): Option[Duration] = Some(d) + implicit def p2od(p: (Long, TimeUnit)): Duration = new Duration(p._1, p._2) } trait FSM[S, D] { this: Actor => + import FSM._ - type StateFunction = scala.PartialFunction[Event, State] + type StateFunction = scala.PartialFunction[Event[D], State] type Timeout = Option[Duration] - /** DSL */ - protected final def notifying(transitionHandler: PartialFunction[Transition, Unit]) = { - transitionEvent = transitionHandler - } - + /**DSL */ protected final def when(stateName: S, stateTimeout: Timeout = None)(stateFunction: StateFunction) = { register(stateName, stateFunction, stateTimeout) } @@ -60,7 +67,7 @@ trait FSM[S, D] { protected final def startWith(stateName: S, stateData: D, timeout: Timeout = None) = { - currentState = State(stateName, stateData, timeout) + applyState(State(stateName, stateData, timeout)) } protected final def goto(nextStateName: S): State = { @@ -81,7 +88,7 @@ trait FSM[S, D] { } protected final def stop(reason: Reason, stateData: D): State = { - stay using stateData withStopReason(reason) + stay using stateData withStopReason (reason) } /** @@ -92,53 +99,60 @@ trait FSM[S, D] { * @param repeat send once if false, scheduleAtFixedRate if true * @return current State */ - protected final def setTimer(name : String, msg : AnyRef, timeout : Duration, repeat : Boolean):State = { + protected final def setTimer(name: String, msg: AnyRef, timeout: Duration, repeat: Boolean): State = { if (timers contains name) { - timers(name).cancel + timers(name).cancel } - val timer = Timer(name, msg, repeat) + val timer = Timer(name, msg, repeat) timer.schedule(self, timeout) timers(name) = timer stay } - + /** * Cancel named timer, ensuring that the message is not subsequently delivered (no race). * @param name * @return */ - protected final def cancelTimer(name : String) = { - if (timers contains name) { - timers(name).cancel - timers -= name - } + protected final def cancelTimer(name: String) = { + if (timers contains name) { + timers(name).cancel + timers -= name + } } - - protected final def timerActive_?(name : String) = timers contains name - def whenUnhandled(stateFunction: StateFunction) = { + protected final def timerActive_?(name: String) = timers contains name + + /**callbacks */ + protected final def onTransition(transitionHandler: PartialFunction[Transition[S], Unit]) = { + transitionEvent = transitionHandler + } + + protected final def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = { + terminateEvent = terminationHandler + } + + protected final def whenUnhandled(stateFunction: StateFunction) = { handleEvent = stateFunction } - def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = { - terminateEvent = terminationHandler - } - def initialize { - // check existence of initial state and setup timeout - makeTransition(currentState) + // check existence of initial state and setup timeout + makeTransition(currentState) } - /** FSM State data and default handlers */ + /**FSM State data and default handlers */ private var currentState: State = _ private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None private var generation: Long = 0L + private var transitionCallBackList: List[ActorRef] = Nil + private val timers = mutable.Map[String, Timer]() private val stateFunctions = mutable.Map[S, StateFunction]() private val stateTimeouts = mutable.Map[S, Timeout]() - + private def register(name: S, function: StateFunction, timeout: Timeout) { if (stateFunctions contains name) { stateFunctions(name) = stateFunctions(name) orElse function @@ -160,7 +174,7 @@ trait FSM[S, D] { case reason => log.slf4j.info("Stopping because of reason: {}", reason) } - private var transitionEvent: PartialFunction[Transition, Unit] = { + private var transitionEvent: PartialFunction[Transition[S], Unit] = { case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to) } @@ -169,15 +183,23 @@ trait FSM[S, D] { if (generation == gen) { processEvent(StateTimeout) } - case t @ Timer(name, msg, repeat) => + case t@Timer(name, msg, repeat) => if (timerActive_?(name)) { - processEvent(msg) - if (!repeat) { - timers -= name - } + processEvent(msg) + if (!repeat) { + timers -= name + } } + case SubscribeTransitionCallBack(actorRef) => + // send current state back as reference point + actorRef ! currentState.stateName + transitionCallBackList ::= actorRef + case UnsubscribeTransitionCallBack(actorRef) => + transitionCallBackList = transitionCallBackList.filterNot(_ == actorRef) case value => { - timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None} + timeoutFuture = timeoutFuture.flatMap{ + ref => ref.cancel(true); None + } generation += 1 processEvent(value) } @@ -197,14 +219,21 @@ trait FSM[S, D] { terminate(Failure("Next state %s does not exist".format(nextState.stateName))) } else { if (currentState.stateName != nextState.stateName) { - transitionEvent.apply(Transition(currentState.stateName, nextState.stateName)) - } - currentState = nextState - currentState.timeout orElse stateTimeouts(currentState.stateName) foreach { t => - if (t.length >= 0) { - timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t.length, t.unit)) - } + val transition = Transition(currentState.stateName, nextState.stateName) + transitionEvent.apply(transition) + transitionCallBackList.foreach(_ ! transition) } + applyState(nextState) + } + } + + private def applyState(nextState: State) = { + currentState = nextState + currentState.timeout orElse stateTimeouts(currentState.stateName) foreach { + t => + if (t.length >= 0) { + timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t.length, t.unit)) + } } } @@ -214,15 +243,13 @@ trait FSM[S, D] { } - case class Event(event: Any, stateData: D) - case class State(stateName: S, stateData: D, timeout: Timeout = None) { def forMax(timeout: Duration): State = { copy(timeout = Some(timeout)) } - - def replying(replyValue:Any): State = { + + def replying(replyValue: Any): State = { self.sender match { case Some(sender) => sender ! replyValue case None => log.slf4j.error("Unable to send reply value {}, no sender reference to reply to", replyValue) @@ -235,11 +262,11 @@ trait FSM[S, D] { } private[akka] var stopReason: Option[Reason] = None + private[akka] def withStopReason(reason: Reason): State = { stopReason = Some(reason) this } } - case class Transition(from: S, to: S) } diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala index ed9a433b73..ab7c052e8a 100644 --- a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala @@ -6,19 +6,21 @@ package akka.actor import org.scalatest.junit.JUnitSuite import org.junit.Test +import FSM._ import org.multiverse.api.latches.StandardLatch import java.util.concurrent.TimeUnit object FSMActorSpec { - import FSM._ + val unlockedLatch = new StandardLatch val lockedLatch = new StandardLatch val unhandledLatch = new StandardLatch val terminatedLatch = new StandardLatch val transitionLatch = new StandardLatch + val transitionCallBackLatch = new StandardLatch sealed trait LockState case object Locked extends LockState @@ -26,11 +28,6 @@ object FSMActorSpec { class Lock(code: String, timeout: (Long, TimeUnit)) extends Actor with FSM[LockState, CodeState] { - notifying { - case Transition(Locked, Open) => transitionLatch.open - case Transition(_, _) => () - } - when(Locked) { case Event(digit: Char, CodeState(soFar, code)) => { soFar + digit match { @@ -57,8 +54,6 @@ object FSMActorSpec { } } - startWith(Locked, CodeState("", code)) - whenUnhandled { case Event(_, stateData) => { log.slf4j.info("Unhandled") @@ -67,10 +62,17 @@ object FSMActorSpec { } } + onTransition { + case Transition(Locked, Open) => transitionLatch.open + case Transition(_, _) => () + } + onTermination { case reason => terminatedLatch.open } + startWith(Locked, CodeState("", code)) + private def doLock() { log.slf4j.info("Locked") lockedLatch.open @@ -88,12 +90,19 @@ object FSMActorSpec { class FSMActorSpec extends JUnitSuite { import FSMActorSpec._ + @Test def unlockTheLock = { // lock that locked after being open for 1 sec val lock = Actor.actorOf(new Lock("33221", (1, TimeUnit.SECONDS))).start + val transitionTester = Actor.actorOf(new Actor { def receive = { + case Transition(_, _) => transitionCallBackLatch.open + }}).start + + lock ! SubscribeTransitionCallBack(transitionTester) + lock ! '3' lock ! '3' lock ! '2' @@ -102,8 +111,10 @@ class FSMActorSpec extends JUnitSuite { assert(unlockedLatch.tryAwait(1, TimeUnit.SECONDS)) assert(transitionLatch.tryAwait(1, TimeUnit.SECONDS)) + assert(transitionCallBackLatch.tryAwait(1, TimeUnit.SECONDS)) assert(lockedLatch.tryAwait(2, TimeUnit.SECONDS)) + lock ! "not_handled" assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS)) diff --git a/akka-actor/src/main/scala/akka/actor/Buncher.scala b/akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala old mode 100755 new mode 100644 similarity index 80% rename from akka-actor/src/main/scala/akka/actor/Buncher.scala rename to akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala index ce926e9d15..e9d54cccbe --- a/akka-actor/src/main/scala/akka/actor/Buncher.scala +++ b/akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala @@ -1,18 +1,19 @@ -package akka.actor +package sample.fsm.buncher import scala.reflect.ClassManifest import akka.util.Duration +import akka.actor.{FSM, Actor} /* - * generic typed object buncher. - * - * To instantiate it, use the factory method like so: - * Buncher(100, 500)(x : List[AnyRef] => x foreach println) - * which will yield a fully functional and started ActorRef. - * The type of messages allowed is strongly typed to match the - * supplied processing method; other messages are discarded (and - * possibly logged). - */ +* generic typed object buncher. +* +* To instantiate it, use the factory method like so: +* Buncher(100, 500)(x : List[AnyRef] => x foreach println) +* which will yield a fully functional and started ActorRef. +* The type of messages allowed is strongly typed to match the +* supplied processing method; other messages are discarded (and +* possibly logged). +*/ object Buncher { trait State case object Idle extends State diff --git a/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala b/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala index 4104f6c18f..1147e6d0bd 100644 --- a/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala +++ b/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala @@ -3,8 +3,8 @@ package sample.fsm.dining.fsm import akka.actor.{ActorRef, Actor, FSM} import akka.actor.FSM._ import Actor._ -import java.util.concurrent.TimeUnit -import TimeUnit._ +import akka.util.Duration +import akka.util.duration._ /* * Some messages for the chopstick @@ -84,7 +84,7 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit when(Waiting) { case Event(Think, _) => log.info("%s starts to think", name) - startThinking(5, SECONDS) + startThinking(5 seconds) } //When a hakker is thinking it can become hungry @@ -118,12 +118,12 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit case Event(Busy(chopstick), TakenChopsticks(leftOption, rightOption)) => leftOption.foreach(_ ! Put) rightOption.foreach(_ ! Put) - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) } private def startEating(left: ActorRef, right: ActorRef): State = { log.info("%s has picked up %s and %s, and starts to eat", name, left.id, right.id) - goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5, SECONDS) + goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5 seconds) } // When the results of the other grab comes back, @@ -132,9 +132,9 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit when(FirstChopstickDenied) { case Event(Taken(secondChopstick), _) => secondChopstick ! Put - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) case Event(Busy(chopstick), _) => - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) } // When a hakker is eating, he can decide to start to think, @@ -144,11 +144,11 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit log.info("%s puts down his chopsticks and starts to think", name) left ! Put right ! Put - startThinking(5, SECONDS) + startThinking(5 seconds) } - private def startThinking(period: Int, timeUnit: TimeUnit): State = { - goto(Thinking) using TakenChopsticks(None, None) forMax (period, timeUnit) + private def startThinking(duration: Duration): State = { + goto(Thinking) using TakenChopsticks(None, None) forMax duration } //All hakkers start waiting