polishing up code

This commit is contained in:
imn 2010-10-27 15:45:30 +02:00
parent efb99fc3a7
commit 847e0c7b04
2 changed files with 61 additions and 46 deletions

View file

@ -12,27 +12,15 @@ trait FSM[S, D] {
type StateFunction = scala.PartialFunction[Event, State]
private var currentState: State = _
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
private val transitions = mutable.Map[S, StateFunction]()
private def register(name: S, function: StateFunction) {
if (transitions contains name) {
transitions(name) = transitions(name) orElse function
} else {
transitions(name) = function
}
/** DSL */
protected final def inState(stateName: S)(stateFunction: StateFunction) = {
register(stateName, stateFunction)
}
protected final def setInitialState(stateName: S, stateData: D, timeout: Option[Long] = None) = {
setState(State(stateName, stateData, timeout))
}
protected final def inState(stateName: S)(stateFunction: StateFunction) = {
register(stateName, stateFunction)
}
protected final def goto(nextStateName: S): State = {
State(nextStateName, currentState.stateData)
}
@ -41,14 +29,6 @@ trait FSM[S, D] {
goto(currentState.stateName)
}
protected final def reply(replyValue: Any): State = {
self.sender.foreach(_ ! replyValue)
stay()
}
/**
* Stop
*/
protected final def stop(): State = {
stop(Normal)
}
@ -58,59 +38,82 @@ trait FSM[S, D] {
}
protected final def stop(reason: Reason, stateData: D): State = {
log.info("Stopped because of reason: %s", reason)
terminate(reason, currentState.stateName, stateData)
self.stop
State(currentState.stateName, stateData)
self ! Stop(reason, stateData)
stay
}
def terminate(reason: Reason, stateName: S, stateData: D) = ()
def whenUnhandled(stateFunction: StateFunction) = {
handleEvent = stateFunction
}
def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = {
terminateEvent = terminationHandler
}
/** FSM State data and default handlers */
private var currentState: State = _
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
private val transitions = mutable.Map[S, StateFunction]()
private def register(name: S, function: StateFunction) {
if (transitions contains name) {
transitions(name) = transitions(name) orElse function
} else {
transitions(name) = function
}
}
private var handleEvent: StateFunction = {
case Event(value, stateData) =>
log.warning("Event %s not handled in state %s - keeping current state with data %s", value, currentState.stateName, stateData)
currentState
log.warning("Event %s not handled in state %s, staying at current state", value, currentState.stateName)
stay
}
private var terminateEvent: PartialFunction[Reason, Unit] = {
case failure@Failure(_) => log.error("Stopping because of a %s", failure)
case reason => log.info("Stopping because of reason: %s", reason)
}
override final protected def receive: Receive = {
case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) => ()
case Stop(reason, stateData) =>
terminateEvent.apply(reason)
self.stop
case StateTimeout if (self.dispatcher.mailboxSize(self) > 0) =>
log.trace("Ignoring StateTimeout - ")
// state timeout when new message in queue, skip this timeout
case value => {
timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None}
val event = Event(value, currentState.stateData)
val nextState = (transitions(currentState.stateName) orElse handleEvent).apply(event)
if (self.isRunning) {
setState(nextState)
}
setState(nextState)
}
}
private def setState(nextState: State) = {
if (!transitions.contains(nextState.stateName)) {
stop(Failure("Next state %s not available".format(nextState.stateName)))
stop(Failure("Next state %s does not exist".format(nextState.stateName)))
} else {
currentState = nextState
currentState.timeout.foreach {t => timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS))}
currentState.timeout.foreach {
t =>
timeoutFuture = Some(Scheduler.scheduleOnce(self, StateTimeout, t, TimeUnit.MILLISECONDS))
}
}
}
case class Event(event: Any, stateData: D)
case class State(stateName: S, stateData: D, timeout: Option[Long] = None) {
def until(timeout: Long): State = {
copy(timeout = Some(timeout))
}
def then(nextStateName: S): State = {
copy(stateName = nextStateName)
}
def replying(replyValue:Any): State = {
self.sender.foreach(_ ! replyValue)
self.sender match {
case Some(sender) => sender ! replyValue
case None => log.error("Unable to send reply value %s, no sender reference to reply to", replyValue)
}
this
}
@ -125,4 +128,6 @@ trait FSM[S, D] {
case class Failure(cause: Any) extends Reason
case object StateTimeout
private case class Stop(reason: Reason, stateData: D)
}

View file

@ -16,6 +16,7 @@ object FSMActorSpec {
val unlockedLatch = new StandardLatch
val lockedLatch = new StandardLatch
val unhandledLatch = new StandardLatch
val terminatedLatch = new StandardLatch
sealed trait LockState
case object Locked extends LockState
@ -39,6 +40,7 @@ object FSMActorSpec {
}
}
case Event("hello", _) => stay replying "world"
case Event("bye", _) => stop(Shutdown)
}
inState(Open) {
@ -58,6 +60,10 @@ object FSMActorSpec {
}
}
onTermination {
case reason => terminatedLatch.open
}
private def doLock() {
log.info("Locked")
lockedLatch.open
@ -94,16 +100,20 @@ class FSMActorSpec extends JUnitSuite {
assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS))
val answerLatch = new StandardLatch
object Go
object Hello
object Bye
val tester = Actor.actorOf(new Actor {
protected def receive = {
case Go => lock ! "hello"
case Hello => lock ! "hello"
case "world" => answerLatch.open
case Bye => lock ! "bye"
}
}).start
tester ! Go
tester ! Hello
assert(answerLatch.tryAwait(2, TimeUnit.SECONDS))
tester ! Bye
assert(terminatedLatch.tryAwait(2, TimeUnit.SECONDS))
}
}