merge Irmo's changes and add test case for whenUnhandled

This commit is contained in:
Roland Kuhn 2011-01-03 21:01:11 +01:00
commit 12d4942853
3 changed files with 45 additions and 25 deletions

View file

@ -6,15 +6,10 @@ package akka.actor
import akka.util._
import scala.collection.mutable
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import java.util.concurrent.ScheduledFuture
object FSM {
case class Event[D](event: Any, stateData: D)
object Ev {
def unapply[D](e : Event[D]) : Option[Any] = Some(e.event)
}
case class Transition[S](from: S, to: S)
case class SubscribeTransitionCallBack(actorRef: ActorRef)
case class UnsubscribeTransitionCallBack(actorRef: ActorRef)
@ -23,7 +18,6 @@ object FSM {
case object Normal extends Reason
case object Shutdown extends Reason
case class Failure(cause: Any) extends Reason
case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D)
case object StateTimeout
case class TimeoutMarker(generation: Long)
@ -124,6 +118,7 @@ trait FSM[S, D] {
type StateFunction = scala.PartialFunction[Event[D], State]
type Timeout = Option[Duration]
type TransitionHandler = (S, S) => Unit
/* DSL */
@ -245,7 +240,7 @@ trait FSM[S, D] {
* Set handler which is called upon each state transition, i.e. not when
* staying in the same state.
*/
protected final def onTransition(transitionHandler: PartialFunction[Transition[S], Unit]) = {
protected final def onTransition(transitionHandler: TransitionHandler) = {
transitionEvent = transitionHandler
}
@ -293,12 +288,12 @@ trait FSM[S, D] {
}
}
private var handleEvent: StateFunction = handleEventDefault
private val handleEventDefault: StateFunction = {
case Event(value, stateData) =>
log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName)
stay
}
private var handleEvent: StateFunction = handleEventDefault
private var terminateEvent: PartialFunction[StopEvent[S,D], Unit] = {
case StopEvent(Failure(cause), _, _) =>
@ -306,8 +301,8 @@ trait FSM[S, D] {
case StopEvent(reason, _, _) => log.slf4j.info("Stopping because of reason: {}", reason)
}
private var transitionEvent: PartialFunction[Transition[S], Unit] = {
case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to)
private var transitionEvent: TransitionHandler = (from, to) => {
log.slf4j.debug("Transitioning from state {} to {}", from, to)
}
override final protected def receive: Receive = {
@ -358,9 +353,11 @@ trait FSM[S, D] {
terminate(Failure("Next state %s does not exist".format(nextState.stateName)))
} else {
if (currentState.stateName != nextState.stateName) {
val transition = Transition(currentState.stateName, nextState.stateName)
transitionEvent.apply(transition)
transitionCallBackList.foreach(_ ! transition)
transitionEvent.apply(currentState.stateName, nextState.stateName)
if (!transitionCallBackList.isEmpty) {
val transition = Transition(currentState.stateName, nextState.stateName)
transitionCallBackList.foreach(_ ! transition)
}
}
applyState(nextState)
}
@ -378,10 +375,15 @@ trait FSM[S, D] {
}
private def terminate(reason: Reason) = {
timers.foreach{ case (timer, t) => log.slf4j.info("Canceling timer {}", timer); t.cancel}
terminateEvent.apply(StopEvent(reason, currentState.stateName, currentState.stateData))
self.stop
}
case class Event[D](event: Any, stateData: D)
object Ev {
def unapply[D](e : Event[D]) : Option[Any] = Some(e.event)
}
case class State(stateName: S, stateData: D, timeout: Timeout = None) {
@ -423,4 +425,5 @@ trait FSM[S, D] {
}
}
case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D)
}

View file

@ -67,10 +67,10 @@ object FSMActorSpec {
}
}
onTransition {
onTransition((oldState, newState) => Transition(oldState, newState) match {
case Transition(Locked, Open) => transitionLatch.open
case Transition(_, _) => ()
}
})
onTermination {
case StopEvent(Shutdown, Locked, _) =>

View file

@ -55,11 +55,21 @@ class FSMTimingSpec
}
"notify unhandled messages" in {
within(200 millis) {
fsm ! Cancel
expectMsg(Unhandled(Cancel))
fsm ! TestUnhandled
within(100 millis) {
fsm ! Tick
expectNoMsg
}
within(100 millis) {
fsm ! SetHandler
fsm ! Tick
expectMsg(Unhandled(Tick))
expectNoMsg
}
within(100 millis) {
fsm ! Cancel
expectMsg(Transition(TestUnhandled, Initial))
}
}
}
@ -73,9 +83,11 @@ object FSMTimingSpec {
case object TestStateTimeout extends State
case object TestSingleTimer extends State
case object TestRepeatedTimer extends State
case object TestUnhandled extends State
case object Tick
case object Cancel
case object SetHandler
case class Unhandled(msg : AnyRef)
@ -84,13 +96,13 @@ object FSMTimingSpec {
startWith(Initial, ())
when(Initial) {
case Ev(TestStateTimeout) => goto(TestStateTimeout)
case Ev(TestSingleTimer) =>
setTimer("tester", Tick, 100 millis, false)
goto(TestSingleTimer)
case Ev(TestRepeatedTimer) =>
setTimer("tester", Tick, 100 millis, true)
goto(TestRepeatedTimer)
case Ev(x : FSMTimingSpec.State) => goto(x)
}
when(TestStateTimeout, stateTimeout = 100 millis) {
case Ev(StateTimeout) => goto(Initial)
@ -108,14 +120,19 @@ object FSMTimingSpec {
cancelTimer("tester")
goto(Initial)
}
whenUnhandled {
case Ev(msg : AnyRef) =>
tester ! Unhandled(msg)
when(TestUnhandled) {
case Ev(SetHandler) =>
whenUnhandled {
case Ev(msg : AnyRef) =>
tester ! Unhandled(msg)
stay
}
stay
case Ev(Cancel) =>
goto(Initial)
}
}
}
// vim: set ts=4 sw=4 et:
// vim: set ts=2 sw=2 et: