improvements on FSM

- change to akka.util.Duration
- add proper implicits to enable timeouts like "5 seconds"
- add concept of state timeouts, which are actually attached to states
- add timer handling for conveniently modeling timing irrespective of message "interruptions"
- add generic Buncher class as usage example and useful utility
This commit is contained in:
Roland Kuhn 2010-12-19 22:16:15 +01:00
parent c93a447bf1
commit 30c11feb78
2 changed files with 319 additions and 172 deletions

View file

@ -0,0 +1,74 @@
package akka.actor
import scala.reflect.ClassManifest
import akka.util.Duration
/*
* generic typed object buncher.
*
* To instantiate it, use the factory method like so:
* Buncher(100, 500)(x : List[AnyRef] => x foreach println)
* which will yield a fully functional and started ActorRef.
* The type of messages allowed is strongly typed to match the
* supplied processing method; other messages are discarded (and
* possibly logged).
*/
object Buncher {
trait State
case object Idle extends State
case object Active extends State
case object Flush // send out current queue immediately
case object Stop // poison pill
case class Data[A](start : Long, xs : List[A])
def apply[A : Manifest](singleTimeout : Duration,
multiTimeout : Duration)(f : List[A] => Unit) =
Actor.actorOf(new Buncher[A](singleTimeout, multiTimeout).deliver(f))
}
class Buncher[A : Manifest] private (val singleTimeout : Duration, val multiTimeout : Duration)
extends Actor with FSM[Buncher.State, Buncher.Data[A]] {
import Buncher._
import FSM._
private val manifestA = manifest[A]
private var send : List[A] => Unit = _
private def deliver(f : List[A] => Unit) = { send = f; this }
private def now = System.currentTimeMillis
private def check(m : AnyRef) = ClassManifest.fromClass(m.getClass) <:< manifestA
startWith(Idle, Data(0, Nil))
when(Idle) {
case Event(m : AnyRef, _) if check(m) =>
goto(Active) using Data(now, m.asInstanceOf[A] :: Nil)
case Event(Flush, _) => stay
case Event(Stop, _) => stop
}
when(Active, stateTimeout = Some(singleTimeout)) {
case Event(m : AnyRef, Data(start, xs)) if check(m) =>
val l = m.asInstanceOf[A] :: xs
if (now - start > multiTimeout.toMillis) {
send(l.reverse)
goto(Idle) using Data(0, Nil)
} else {
stay using Data(start, l)
}
case Event(StateTimeout, Data(_, xs)) =>
send(xs.reverse)
goto(Idle) using Data(0, Nil)
case Event(Flush, Data(_, xs)) =>
send(xs.reverse)
goto(Idle) using Data(0, Nil)
case Event(Stop, Data(_, xs)) =>
send(xs.reverse)
stop
}
initialize
}

417
akka-actor/src/main/scala/akka/actor/FSM.scala Normal file → Executable file
View file

@ -1,172 +1,245 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package akka.actor
import scala.collection.mutable
import java.util.concurrent.{ScheduledFuture, TimeUnit}
object FSM {
sealed trait Reason
case object Normal extends Reason
case object Shutdown extends Reason
case class Failure(cause: Any) extends Reason
case object StateTimeout
case class TimeoutMarker(generation: Long)
}
trait FSM[S, D] {
this: Actor =>
import FSM._
type StateFunction = scala.PartialFunction[Event, State]
/** DSL */
protected final def notifying(transitionHandler: PartialFunction[Transition, Unit]) = {
transitionEvent = transitionHandler
}
protected final def when(stateName: S)(stateFunction: StateFunction) = {
register(stateName, stateFunction)
}
protected final def startWith(stateName: S,
stateData: D,
timeout: Option[(Long, TimeUnit)] = None) = {
applyState(State(stateName, stateData, timeout))
}
protected final def goto(nextStateName: S): State = {
State(nextStateName, currentState.stateData)
}
protected final def stay(): State = {
goto(currentState.stateName)
}
protected final def stop(): State = {
stop(Normal)
}
protected final def stop(reason: Reason): State = {
stop(reason, currentState.stateData)
}
protected final def stop(reason: Reason, stateData: D): State = {
stay using stateData withStopReason(reason)
}
def whenUnhandled(stateFunction: StateFunction) = {
handleEvent = stateFunction
}
def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = {
terminateEvent = terminationHandler
}
/** FSM State data and default handlers */
private var currentState: State = _
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
private var generation: Long = 0L
private val stateFunctions = mutable.Map[S, StateFunction]()
private def register(name: S, function: StateFunction) {
if (stateFunctions contains name) {
stateFunctions(name) = stateFunctions(name) orElse function
} else {
stateFunctions(name) = function
}
}
private var handleEvent: StateFunction = {
case Event(value, stateData) =>
log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName)
stay
}
private var terminateEvent: PartialFunction[Reason, Unit] = {
case failure@Failure(_) => log.slf4j.error("Stopping because of a {}", failure)
case reason => log.slf4j.info("Stopping because of reason: {}", reason)
}
private var transitionEvent: PartialFunction[Transition, Unit] = {
case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to)
}
override final protected def receive: Receive = {
case TimeoutMarker(gen) =>
if (generation == gen) {
processEvent(StateTimeout)
}
case value => {
timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None}
generation += 1
processEvent(value)
}
}
private def processEvent(value: Any) = {
val event = Event(value, currentState.stateData)
val nextState = (stateFunctions(currentState.stateName) orElse handleEvent).apply(event)
nextState.stopReason match {
case Some(reason) => terminate(reason)
case None => makeTransition(nextState)
}
}
private def makeTransition(nextState: State) = {
if (!stateFunctions.contains(nextState.stateName)) {
terminate(Failure("Next state %s does not exist".format(nextState.stateName)))
} else {
if (currentState.stateName != nextState.stateName) {
transitionEvent.apply(Transition(currentState.stateName, nextState.stateName))
}
applyState(nextState)
}
}
private def applyState(nextState: State) = {
currentState = nextState
currentState.timeout.foreach { t =>
timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t._1, t._2))
}
}
private def terminate(reason: Reason) = {
terminateEvent.apply(reason)
self.stop
}
case class Event(event: Any, stateData: D)
case class State(stateName: S, stateData: D, timeout: Option[(Long, TimeUnit)] = None) {
def forMax(timeout: (Long, TimeUnit)): State = {
copy(timeout = Some(timeout))
}
def replying(replyValue:Any): State = {
self.sender match {
case Some(sender) => sender ! replyValue
case None => log.slf4j.error("Unable to send reply value {}, no sender reference to reply to", replyValue)
}
this
}
def using(nextStateDate: D): State = {
copy(stateData = nextStateDate)
}
private[akka] var stopReason: Option[Reason] = None
private[akka] def withStopReason(reason: Reason): State = {
stopReason = Some(reason)
this
}
}
case class Transition(from: S, to: S)
}
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package akka.actor
import akka.util._
import scala.collection.mutable
import java.util.concurrent.{ScheduledFuture, TimeUnit}
object FSM {
sealed trait Reason
case object Normal extends Reason
case object Shutdown extends Reason
case class Failure(cause: Any) extends Reason
case object StateTimeout
case class TimeoutMarker(generation: Long)
case class Timer(name : String, msg : AnyRef, repeat : Boolean) {
private var ref : Option[ScheduledFuture[AnyRef]] = _
def schedule(actor : ActorRef, timeout : Duration) {
if (repeat) {
ref = Some(Scheduler.schedule(actor, this, timeout.length, timeout.length, timeout.unit))
} else {
ref = Some(Scheduler.scheduleOnce(actor, this, timeout.length, timeout.unit))
}
}
def cancel {
ref = ref flatMap {t => t.cancel(true); None}
}
}
/*
* With these implicits in scope, you can write "5 seconds" anywhere a
* Duration or Option[Duration] is expected. This is conveniently true
* for derived classes.
*/
implicit def d2od(d : Duration) : Option[Duration] = Some(d)
implicit def p2od(p : (Long, TimeUnit)) : Duration = new Duration(p._1, p._2)
implicit def i2d(i : Int) : DurationInt = new DurationInt(i)
implicit def l2d(l : Long) : DurationLong = new DurationLong(l)
}
trait FSM[S, D] {
this: Actor =>
import FSM._
type StateFunction = scala.PartialFunction[Event, State]
type Timeout = Option[Duration]
/** DSL */
protected final def notifying(transitionHandler: PartialFunction[Transition, Unit]) = {
transitionEvent = transitionHandler
}
protected final def when(stateName: S, stateTimeout: Timeout = None)(stateFunction: StateFunction) = {
register(stateName, stateFunction, stateTimeout)
}
protected final def startWith(stateName: S,
stateData: D,
timeout: Timeout = None) = {
currentState = State(stateName, stateData, timeout)
}
protected final def goto(nextStateName: S): State = {
State(nextStateName, currentState.stateData)
}
protected final def stay(): State = {
// cannot directly use currentState because of the timeout field
goto(currentState.stateName)
}
protected final def stop(): State = {
stop(Normal)
}
protected final def stop(reason: Reason): State = {
stop(reason, currentState.stateData)
}
protected final def stop(reason: Reason, stateData: D): State = {
stay using stateData withStopReason(reason)
}
/**
* Schedule named timer to deliver message after given delay, possibly repeating.
* @param name identifier to be used with cancelTimer()
* @param msg message to be delivered
* @param timeout delay of first message delivery and between subsequent messages
* @param repeat send once if false, scheduleAtFixedRate if true
* @return current State
*/
protected final def setTimer(name : String, msg : AnyRef, timeout : Duration, repeat : Boolean):State = {
if (timers contains name) {
timers(name).cancel
}
val timer = Timer(name, msg, repeat)
timer.schedule(self, timeout)
timers(name) = timer
stay
}
/**
* Cancel named timer, ensuring that the message is not subsequently delivered (no race).
* @param name
* @return
*/
protected final def cancelTimer(name : String) = {
if (timers contains name) {
timers(name).cancel
timers -= name
}
}
protected final def timerActive_?(name : String) = timers contains name
def whenUnhandled(stateFunction: StateFunction) = {
handleEvent = stateFunction
}
def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = {
terminateEvent = terminationHandler
}
def initialize {
// check existence of initial state and setup timeout
makeTransition(currentState)
}
/** FSM State data and default handlers */
private var currentState: State = _
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
private var generation: Long = 0L
private val timers = mutable.Map[String, Timer]()
private val stateFunctions = mutable.Map[S, StateFunction]()
private val stateTimeouts = mutable.Map[S, Timeout]()
private def register(name: S, function: StateFunction, timeout: Timeout) {
if (stateFunctions contains name) {
stateFunctions(name) = stateFunctions(name) orElse function
stateTimeouts(name) = timeout orElse stateTimeouts(name)
} else {
stateFunctions(name) = function
stateTimeouts(name) = timeout
}
}
private var handleEvent: StateFunction = {
case Event(value, stateData) =>
log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName)
stay
}
private var terminateEvent: PartialFunction[Reason, Unit] = {
case failure@Failure(_) => log.slf4j.error("Stopping because of a {}", failure)
case reason => log.slf4j.info("Stopping because of reason: {}", reason)
}
private var transitionEvent: PartialFunction[Transition, Unit] = {
case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to)
}
override final protected def receive: Receive = {
case TimeoutMarker(gen) =>
if (generation == gen) {
processEvent(StateTimeout)
}
case t @ Timer(name, msg, repeat) =>
if (timerActive_?(name)) {
processEvent(msg)
if (!repeat) {
timers -= name
}
}
case value => {
timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None}
generation += 1
processEvent(value)
}
}
private def processEvent(value: Any) = {
val event = Event(value, currentState.stateData)
val nextState = (stateFunctions(currentState.stateName) orElse handleEvent).apply(event)
nextState.stopReason match {
case Some(reason) => terminate(reason)
case None => makeTransition(nextState)
}
}
private def makeTransition(nextState: State) = {
if (!stateFunctions.contains(nextState.stateName)) {
terminate(Failure("Next state %s does not exist".format(nextState.stateName)))
} else {
if (currentState.stateName != nextState.stateName) {
transitionEvent.apply(Transition(currentState.stateName, nextState.stateName))
}
currentState = nextState
currentState.timeout orElse stateTimeouts(currentState.stateName) foreach { t =>
if (t.length >= 0) {
timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t.length, t.unit))
}
}
}
}
private def terminate(reason: Reason) = {
terminateEvent.apply(reason)
self.stop
}
case class Event(event: Any, stateData: D)
case class State(stateName: S, stateData: D, timeout: Timeout = None) {
def forMax(timeout: Duration): State = {
copy(timeout = Some(timeout))
}
def replying(replyValue:Any): State = {
self.sender match {
case Some(sender) => sender ! replyValue
case None => log.slf4j.error("Unable to send reply value {}, no sender reference to reply to", replyValue)
}
this
}
def using(nextStateDate: D): State = {
copy(stateData = nextStateDate)
}
private[akka] var stopReason: Option[Reason] = None
private[akka] def withStopReason(reason: Reason): State = {
stopReason = Some(reason)
this
}
}
case class Transition(from: S, to: S)
}