Merge branch 'testkit'

This commit is contained in:
Roland Kuhn 2011-01-04 13:51:39 +01:00
commit cb3135c9b7
9 changed files with 1364 additions and 110 deletions

330
akka-actor/src/main/scala/akka/actor/FSM.scala Normal file → Executable file
View file

@ -1,13 +1,20 @@
/**
* Copyright (C) 2009-2011 Scalable Solutions AB <http://scalablesolutions.se>
*/
package akka.actor
import akka.util._
import scala.collection.mutable
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import java.util.concurrent.ScheduledFuture
object FSM {
case class CurrentState[S](fsmRef: ActorRef, state: S)
case class Transition[S](fsmRef: ActorRef, from: S, to: S)
case class SubscribeTransitionCallBack(actorRef: ActorRef)
case class UnsubscribeTransitionCallBack(actorRef: ActorRef)
sealed trait Reason
case object Normal extends Reason
case object Shutdown extends Reason
@ -15,85 +22,288 @@ object FSM {
case object StateTimeout
case class TimeoutMarker(generation: Long)
case class Timer(name: String, msg: AnyRef, repeat: Boolean) {
private var ref: Option[ScheduledFuture[AnyRef]] = _
def schedule(actor: ActorRef, timeout: Duration) {
if (repeat) {
ref = Some(Scheduler.schedule(actor, this, timeout.length, timeout.length, timeout.unit))
} else {
ref = Some(Scheduler.scheduleOnce(actor, this, timeout.length, timeout.unit))
}
}
def cancel {
if (ref.isDefined) {
ref.get.cancel(true)
ref = None
}
}
}
/*
* With these implicits in scope, you can write "5 seconds" anywhere a
* Duration or Option[Duration] is expected. This is conveniently true
* for derived classes.
*/
implicit def d2od(d: Duration): Option[Duration] = Some(d)
}
/**
* Finite State Machine actor trait. Use as follows:
*
* <pre>
* object A {
* trait State
* case class One extends State
* case class Two extends State
*
* case class Data(i : Int)
* }
*
* class A extends Actor with FSM[A.State, A.Data] {
* import A._
*
* startWith(One, Data(42))
* when(One) {
* case Event(SomeMsg, Data(x)) => ...
* case Ev(SomeMsg) => ... // convenience when data not needed
* }
* when(Two, stateTimeout = 5 seconds) { ... }
* initialize
* }
* </pre>
*
* Within the partial function the following values are returned for effecting
* state transitions:
*
* - <code>stay</code> for staying in the same state
* - <code>stay using Data(...)</code> for staying in the same state, but with
* different data
* - <code>stay forMax 5.millis</code> for staying with a state timeout; can be
* combined with <code>using</code>
* - <code>goto(...)</code> for changing into a different state; also supports
* <code>using</code> and <code>forMax</code>
* - <code>stop</code> for terminating this FSM actor
*
* Each of the above also supports the method <code>replying(AnyRef)</code> for
* sending a reply before changing state.
*
* Another feature is that other actors may subscribe for transition events by
* sending a <code>SubscribeTransitionCallback</code> message to this actor;
* use <code>UnsubscribeTransitionCallback</code> before stopping the other
* actor.
*
* State timeouts set an upper bound to the time which may pass before another
* message is received in the current state. If no external message is
* available, then upon expiry of the timeout a StateTimeout message is sent.
* Note that this message will only be received in the state for which the
* timeout was set and that any message received will cancel the timeout
* (possibly to be started again by the next transition).
*
* Another feature is the ability to install and cancel single-shot as well as
* repeated timers which arrange for the sending of a user-specified message:
*
* <pre>
* setTimer("tock", TockMsg, 1 second, true) // repeating
* setTimer("lifetime", TerminateMsg, 1 hour, false) // single-shot
* cancelTimer("tock")
* timerActive_? ("tock")
* </pre>
*/
trait FSM[S, D] {
this: Actor =>
import FSM._
type StateFunction = scala.PartialFunction[Event, State]
type StateFunction = scala.PartialFunction[Event[D], State]
type Timeout = Option[Duration]
type TransitionHandler = (S, S) => Unit
/** DSL */
protected final def notifying(transitionHandler: PartialFunction[Transition, Unit]) = {
transitionEvent = transitionHandler
}
protected final def when(stateName: S)(stateFunction: StateFunction) = {
register(stateName, stateFunction)
/* DSL */
/**
* Insert a new StateFunction at the end of the processing chain for the
* given state. If the stateTimeout parameter is set, entering this state
* without a differing explicit timeout setting will trigger a StateTimeout
* event; the same is true when using #stay.
*
* @param stateName designator for the state
* @param stateTimeout default state timeout for this state
* @param stateFunction partial function describing response to input
*/
protected final def when(stateName: S, stateTimeout: Timeout = None)(stateFunction: StateFunction) = {
register(stateName, stateFunction, stateTimeout)
}
/**
* Set initial state. Call this method from the constructor before the #initialize method.
*
* @param stateName initial state designator
* @param stateData initial state data
* @param timeout state timeout for the initial state, overriding the default timeout for that state
*/
protected final def startWith(stateName: S,
stateData: D,
timeout: Option[(Long, TimeUnit)] = None) = {
applyState(State(stateName, stateData, timeout))
timeout: Timeout = None) = {
currentState = State(stateName, stateData, timeout)
}
/**
* Produce transition to other state. Return this from a state function in
* order to effect the transition.
*
* @param nextStateName state designator for the next state
* @return state transition descriptor
*/
protected final def goto(nextStateName: S): State = {
State(nextStateName, currentState.stateData)
}
/**
* Produce "empty" transition descriptor. Return this from a state function
* when no state change is to be effected.
*
* @return descriptor for staying in current state
*/
protected final def stay(): State = {
// cannot directly use currentState because of the timeout field
goto(currentState.stateName)
}
/**
* Produce change descriptor to stop this FSM actor with reason "Normal".
*/
protected final def stop(): State = {
stop(Normal)
}
/**
* Produce change descriptor to stop this FSM actor including specified reason.
*/
protected final def stop(reason: Reason): State = {
stop(reason, currentState.stateData)
}
/**
* Produce change descriptor to stop this FSM actor including specified reason.
*/
protected final def stop(reason: Reason, stateData: D): State = {
stay using stateData withStopReason(reason)
stay using stateData withStopReason (reason)
}
def whenUnhandled(stateFunction: StateFunction) = {
handleEvent = stateFunction
/**
* Schedule named timer to deliver message after given delay, possibly repeating.
* @param name identifier to be used with cancelTimer()
* @param msg message to be delivered
* @param timeout delay of first message delivery and between subsequent messages
* @param repeat send once if false, scheduleAtFixedRate if true
* @return current state descriptor
*/
protected final def setTimer(name: String, msg: AnyRef, timeout: Duration, repeat: Boolean): State = {
if (timers contains name) {
timers(name).cancel
}
val timer = Timer(name, msg, repeat)
timer.schedule(self, timeout)
timers(name) = timer
stay
}
def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = {
/**
* Cancel named timer, ensuring that the message is not subsequently delivered (no race).
* @param name of the timer to cancel
*/
protected final def cancelTimer(name: String) = {
if (timers contains name) {
timers(name).cancel
timers -= name
}
}
/**
* Inquire whether the named timer is still active. Returns true unless the
* timer does not exist, has previously been canceled or if it was a
* single-shot timer whose message was already received.
*/
protected final def timerActive_?(name: String) = timers contains name
/**
* Set state timeout explicitly. This method can safely be used from within a
* state handler.
*/
protected final def setStateTimeout(state : S, timeout : Timeout) {
stateTimeouts(state) = timeout
}
/**
* Set handler which is called upon each state transition, i.e. not when
* staying in the same state.
*/
protected final def onTransition(transitionHandler: TransitionHandler) = {
transitionEvent = transitionHandler
}
/**
* Set handler which is called upon termination of this FSM actor.
*/
protected final def onTermination(terminationHandler: PartialFunction[StopEvent[S,D], Unit]) = {
terminateEvent = terminationHandler
}
/** FSM State data and default handlers */
/**
* Set handler which is called upon reception of unhandled messages.
*/
protected final def whenUnhandled(stateFunction: StateFunction) = {
handleEvent = stateFunction orElse handleEventDefault
}
/**
* Verify existence of initial state and setup timers. This should be the
* last call within the constructor.
*/
def initialize {
makeTransition(currentState)
}
/**FSM State data and default handlers */
private var currentState: State = _
private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None
private var generation: Long = 0L
private var transitionCallBackList: List[ActorRef] = Nil
private val timers = mutable.Map[String, Timer]()
private val stateFunctions = mutable.Map[S, StateFunction]()
private def register(name: S, function: StateFunction) {
private val stateTimeouts = mutable.Map[S, Timeout]()
private def register(name: S, function: StateFunction, timeout: Timeout) {
if (stateFunctions contains name) {
stateFunctions(name) = stateFunctions(name) orElse function
stateTimeouts(name) = timeout orElse stateTimeouts(name)
} else {
stateFunctions(name) = function
stateTimeouts(name) = timeout
}
}
private var handleEvent: StateFunction = {
private val handleEventDefault: StateFunction = {
case Event(value, stateData) =>
log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName)
stay
}
private var handleEvent: StateFunction = handleEventDefault
private var terminateEvent: PartialFunction[Reason, Unit] = {
case failure@Failure(_) => log.slf4j.error("Stopping because of a {}", failure)
case reason => log.slf4j.info("Stopping because of reason: {}", reason)
private var terminateEvent: PartialFunction[StopEvent[S,D], Unit] = {
case StopEvent(Failure(cause), _, _) =>
log.slf4j.error("Stopping because of a failure with cause {}", cause)
case StopEvent(reason, _, _) => log.slf4j.info("Stopping because of reason: {}", reason)
}
private var transitionEvent: PartialFunction[Transition, Unit] = {
case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to)
private var transitionEvent: TransitionHandler = (from, to) => {
log.slf4j.debug("Transitioning from state {} to {}", from, to)
}
override final protected def receive: Receive = {
@ -101,8 +311,24 @@ trait FSM[S, D] {
if (generation == gen) {
processEvent(StateTimeout)
}
case t@Timer(name, msg, repeat) =>
if (timerActive_?(name)) {
processEvent(msg)
if (!repeat) {
timers -= name
}
}
case SubscribeTransitionCallBack(actorRef) =>
// send current state back as reference point
actorRef ! CurrentState(self, currentState.stateName)
transitionCallBackList ::= actorRef
case UnsubscribeTransitionCallBack(actorRef) =>
transitionCallBackList = transitionCallBackList.filterNot(_ == actorRef)
case value => {
timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None}
if (timeoutFuture.isDefined) {
timeoutFuture.get.cancel(true)
timeoutFuture = None
}
generation += 1
processEvent(value)
}
@ -110,7 +336,13 @@ trait FSM[S, D] {
private def processEvent(value: Any) = {
val event = Event(value, currentState.stateData)
val nextState = (stateFunctions(currentState.stateName) orElse handleEvent).apply(event)
val stateFunc = stateFunctions(currentState.stateName)
val nextState = if (stateFunc isDefinedAt event) {
stateFunc(event)
} else {
// handleEventDefault ensures that this is always defined
handleEvent(event)
}
nextState.stopReason match {
case Some(reason) => terminate(reason)
case None => makeTransition(nextState)
@ -122,7 +354,11 @@ trait FSM[S, D] {
terminate(Failure("Next state %s does not exist".format(nextState.stateName)))
} else {
if (currentState.stateName != nextState.stateName) {
transitionEvent.apply(Transition(currentState.stateName, nextState.stateName))
transitionEvent.apply(currentState.stateName, nextState.stateName)
if (!transitionCallBackList.isEmpty) {
val transition = Transition(self, currentState.stateName, nextState.stateName)
transitionCallBackList.foreach(_ ! transition)
}
}
applyState(nextState)
}
@ -130,26 +366,43 @@ trait FSM[S, D] {
private def applyState(nextState: State) = {
currentState = nextState
currentState.timeout.foreach { t =>
timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t._1, t._2))
val timeout = if (currentState.timeout.isDefined) currentState.timeout else stateTimeouts(currentState.stateName)
if (timeout.isDefined) {
val t = timeout.get
if (t.length >= 0) {
timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t.length, t.unit))
}
}
}
private def terminate(reason: Reason) = {
terminateEvent.apply(reason)
timers.foreach{ case (timer, t) => log.slf4j.info("Canceling timer {}", timer); t.cancel}
terminateEvent.apply(StopEvent(reason, currentState.stateName, currentState.stateData))
self.stop
}
case class Event[D](event: Any, stateData: D)
object Ev {
def unapply[D](e : Event[D]) : Option[Any] = Some(e.event)
}
case class Event(event: Any, stateData: D)
case class State(stateName: S, stateData: D, timeout: Timeout = None) {
case class State(stateName: S, stateData: D, timeout: Option[(Long, TimeUnit)] = None) {
def forMax(timeout: (Long, TimeUnit)): State = {
/**
* Modify state transition descriptor to include a state timeout for the
* next state. This timeout overrides any default timeout set for the next
* state.
*/
def forMax(timeout: Duration): State = {
copy(timeout = Some(timeout))
}
def replying(replyValue:Any): State = {
/**
* Send reply to sender of the current message, if available.
*
* @return this state transition descriptor
*/
def replying(replyValue: Any): State = {
self.sender match {
case Some(sender) => sender ! replyValue
case None => log.slf4j.error("Unable to send reply value {}, no sender reference to reply to", replyValue)
@ -157,16 +410,21 @@ trait FSM[S, D] {
this
}
/**
* Modify state transition descriptor with new state data. The data will be
* set when transitioning to the new state.
*/
def using(nextStateDate: D): State = {
copy(stateData = nextStateDate)
}
private[akka] var stopReason: Option[Reason] = None
private[akka] def withStopReason(reason: Reason): State = {
stopReason = Some(reason)
this
}
}
case class Transition(from: S, to: S)
case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D)
}

View file

@ -5,17 +5,149 @@
package akka.util
import java.util.concurrent.TimeUnit
import TimeUnit._
import java.lang.{Long => JLong, Double => JDouble}
object Duration {
def apply(length: Long, unit: TimeUnit) = new Duration(length, unit)
def apply(length: Long, unit: String) = new Duration(length, timeUnit(unit))
def apply(length: Long, unit: TimeUnit) : Duration = new FiniteDuration(length, unit)
def apply(length: Double, unit: TimeUnit) : Duration = fromNanos(unit.toNanos(1) * length)
def apply(length: Long, unit: String) : Duration = new FiniteDuration(length, timeUnit(unit))
def timeUnit(unit: String) = unit.toLowerCase match {
case "nanoseconds" | "nanos" | "nanosecond" | "nano" => TimeUnit.NANOSECONDS
case "microseconds" | "micros" | "microsecond" | "micro" => TimeUnit.MICROSECONDS
case "milliseconds" | "millis" | "millisecond" | "milli" => TimeUnit.MILLISECONDS
case _ => TimeUnit.SECONDS
def fromNanos(nanos : Long) : Duration = {
if (nanos % 86400000000000L == 0) {
Duration(nanos / 86400000000000L, DAYS)
} else if (nanos % 3600000000000L == 0) {
Duration(nanos / 3600000000000L, HOURS)
} else if (nanos % 60000000000L == 0) {
Duration(nanos / 60000000000L, MINUTES)
} else if (nanos % 1000000000L == 0) {
Duration(nanos / 1000000000L, SECONDS)
} else if (nanos % 1000000L == 0) {
Duration(nanos / 1000000L, MILLISECONDS)
} else if (nanos % 1000L == 0) {
Duration(nanos / 1000L, MICROSECONDS)
} else {
Duration(nanos, NANOSECONDS)
}
}
def fromNanos(nanos : Double) : Duration = fromNanos((nanos + 0.5).asInstanceOf[Long])
/**
* Construct a Duration by parsing a String. In case of a format error, a
* RuntimeException is thrown. See `unapply(String)` for more information.
*/
def apply(s : String) : Duration = unapply(s) getOrElse error("format error")
/**
* Deconstruct a Duration into length and unit if it is finite.
*/
def unapply(d : Duration) : Option[(Long, TimeUnit)] = {
if (d.finite_?) {
Some((d.length, d.unit))
} else {
None
}
}
private val RE = ("""^\s*(\d+(?:\.\d+)?)\s*"""+ // length part
"(?:"+ // units are distinguished in separate match groups
"(d|day|days)|"+
"(h|hour|hours)|"+
"(min|minute|minutes)|"+
"(s|sec|second|seconds)|"+
"(ms|milli|millis|millisecond|milliseconds)|"+
"(µs|micro|micros|microsecond|microseconds)|"+
"(ns|nano|nanos|nanosecond|nanoseconds)"+
""")\s*$""").r // close the non-capturing group
private val REinf = """^\s*Inf\s*$""".r
private val REminf = """^\s*(?:-\s*|Minus)Inf\s*""".r
/**
* Parse String, return None if no match. Format is `"<length><unit>"`, where
* whitespace is allowed before, between and after the parts. Infinities are
* designated by `"Inf"` and `"-Inf"` or `"MinusInf"`.
*/
def unapply(s : String) : Option[Duration] = s match {
case RE(length, d, h, m, s, ms, mus, ns) =>
if ( d ne null) Some(Duration(JDouble.parseDouble(length), DAYS)) else
if ( h ne null) Some(Duration(JDouble.parseDouble(length), HOURS)) else
if ( m ne null) Some(Duration(JDouble.parseDouble(length), MINUTES)) else
if ( s ne null) Some(Duration(JDouble.parseDouble(length), SECONDS)) else
if ( ms ne null) Some(Duration(JDouble.parseDouble(length), MILLISECONDS)) else
if (mus ne null) Some(Duration(JDouble.parseDouble(length), MICROSECONDS)) else
if ( ns ne null) Some(Duration(JDouble.parseDouble(length), NANOSECONDS)) else
error("made some error in regex (should not be possible)")
case REinf() => Some(Inf)
case REminf() => Some(MinusInf)
case _ => None
}
/**
* Parse TimeUnit from string representation.
*/
def timeUnit(unit: String) = unit.toLowerCase match {
case "d" | "day" | "days" => DAYS
case "h" | "hour" | "hours" => HOURS
case "min" | "minute" | "minutes" => MINUTES
case "s" | "sec" | "second" | "seconds" => SECONDS
case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS
case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS
case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS
}
trait Infinite {
this : Duration =>
override def equals(other : Any) = false
def +(other : Duration) : Duration = this
def -(other : Duration) : Duration = this
def *(other : Double) : Duration = this
def /(other : Double) : Duration = this
def finite_? = false
def length : Long = throw new IllegalArgumentException("length not allowed on infinite Durations")
def unit : TimeUnit = throw new IllegalArgumentException("unit not allowed on infinite Durations")
def toNanos : Long = throw new IllegalArgumentException("toNanos not allowed on infinite Durations")
def toMicros : Long = throw new IllegalArgumentException("toMicros not allowed on infinite Durations")
def toMillis : Long = throw new IllegalArgumentException("toMillis not allowed on infinite Durations")
def toSeconds : Long = throw new IllegalArgumentException("toSeconds not allowed on infinite Durations")
def toMinutes : Long = throw new IllegalArgumentException("toMinutes not allowed on infinite Durations")
def toHours : Long = throw new IllegalArgumentException("toHours not allowed on infinite Durations")
def toDays : Long = throw new IllegalArgumentException("toDays not allowed on infinite Durations")
def toUnit(unit : TimeUnit) : Double = throw new IllegalArgumentException("toUnit not allowed on infinite Durations")
def printHMS = toString
}
/**
* Infinite duration: greater than any other and not equal to any other,
* including itself.
*/
object Inf extends Duration with Infinite {
override def toString = "Duration.Inf"
def >(other : Duration) = true
def >=(other : Duration) = true
def <(other : Duration) = false
def <=(other : Duration) = false
def unary_- : Duration = MinusInf
}
/**
* Infinite negative duration: lesser than any other and not equal to any other,
* including itself.
*/
object MinusInf extends Duration with Infinite {
override def toString = "Duration.MinusInf"
def >(other : Duration) = false
def >=(other : Duration) = false
def <(other : Duration) = true
def <=(other : Duration) = true
def unary_- : Duration = Inf
}
}
/**
@ -24,11 +156,11 @@ object Duration {
* <p/>
* Examples of usage from Java:
* <pre>
* import akka.util.Duration;
* import akka.util.FiniteDuration;
* import java.util.concurrent.TimeUnit;
*
* Duration duration = new Duration(100, TimeUnit.MILLISECONDS);
* Duration duration = new Duration(5, "seconds");
* Duration duration = new FiniteDuration(100, MILLISECONDS);
* Duration duration = new FiniteDuration(5, "seconds");
*
* duration.toNanos();
* </pre>
@ -39,70 +171,255 @@ object Duration {
* import akka.util.Duration
* import java.util.concurrent.TimeUnit
*
* val duration = Duration(100, TimeUnit.MILLISECONDS)
* val duration = Duration(100, MILLISECONDS)
* val duration = Duration(100, "millis")
*
* duration.toNanos
* duration < 1.second
* duration <= Duration.Inf
* </pre>
*
* <p/>
* Implicits are also provided for Int and Long. Example usage:
* Implicits are also provided for Int, Long and Double. Example usage:
* <pre>
* import akka.util.duration._
*
* val duration = 100.millis
* val duration = 100 millis
* </pre>
*
* Extractors, parsing and arithmetic are also included:
* <pre>
* val d = Duration("1.2 µs")
* val Duration(length, unit) = 5 millis
* val d2 = d * 2.5
* val d3 = d2 + 1.millisecond
* </pre>
*/
class Duration(val length: Long, val unit: TimeUnit) {
trait Duration {
def length : Long
def unit : TimeUnit
def toNanos : Long
def toMicros : Long
def toMillis : Long
def toSeconds : Long
def toMinutes : Long
def toHours : Long
def toDays : Long
def toUnit(unit : TimeUnit) : Double
def printHMS : String
def <(other : Duration) : Boolean
def <=(other : Duration) : Boolean
def >(other : Duration) : Boolean
def >=(other : Duration) : Boolean
def +(other : Duration) : Duration
def -(other : Duration) : Duration
def *(factor : Double) : Duration
def /(factor : Double) : Duration
def unary_- : Duration
def finite_? : Boolean
}
class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration {
import Duration._
def this(length: Long, unit: String) = this(length, Duration.timeUnit(unit))
def toNanos = unit.toNanos(length)
def toMicros = unit.toMicros(length)
def toMillis = unit.toMillis(length)
def toSeconds = unit.toSeconds(length)
override def toString = "Duration(" + length + ", " + unit + ")"
def toMinutes = unit.toMinutes(length)
def toHours = unit.toHours(length)
def toDays = unit.toDays(length)
def toUnit(u : TimeUnit) = long2double(toNanos) / NANOSECONDS.convert(1, u)
override def toString = this match {
case Duration(1, DAYS) => "1 day"
case Duration(x, DAYS) => x+" days"
case Duration(1, HOURS) => "1 hour"
case Duration(x, HOURS) => x+" hours"
case Duration(1, MINUTES) => "1 minute"
case Duration(x, MINUTES) => x+" minutes"
case Duration(1, SECONDS) => "1 second"
case Duration(x, SECONDS) => x+" seconds"
case Duration(1, MILLISECONDS) => "1 millisecond"
case Duration(x, MILLISECONDS) => x+" milliseconds"
case Duration(1, MICROSECONDS) => "1 microsecond"
case Duration(x, MICROSECONDS) => x+" microseconds"
case Duration(1, NANOSECONDS) => "1 nanosecond"
case Duration(x, NANOSECONDS) => x+" nanoseconds"
}
def printHMS = "%02d:%02d:%06.3f".format(toHours, toMinutes % 60, toMillis / 1000. % 60)
def <(other : Duration) = {
if (other.finite_?) {
toNanos < other.asInstanceOf[FiniteDuration].toNanos
} else {
other > this
}
}
def <=(other : Duration) = {
if (other.finite_?) {
toNanos <= other.asInstanceOf[FiniteDuration].toNanos
} else {
other >= this
}
}
def >(other : Duration) = {
if (other.finite_?) {
toNanos > other.asInstanceOf[FiniteDuration].toNanos
} else {
other < this
}
}
def >=(other : Duration) = {
if (other.finite_?) {
toNanos >= other.asInstanceOf[FiniteDuration].toNanos
} else {
other <= this
}
}
def +(other : Duration) = {
if (!other.finite_?) {
other
} else {
val nanos = toNanos + other.asInstanceOf[FiniteDuration].toNanos
fromNanos(nanos)
}
}
def -(other : Duration) = {
if (!other.finite_?) {
other
} else {
val nanos = toNanos - other.asInstanceOf[FiniteDuration].toNanos
fromNanos(nanos)
}
}
def *(factor : Double) = fromNanos(long2double(toNanos) * factor)
def /(factor : Double) = fromNanos(long2double(toNanos) / factor)
def unary_- = Duration(-length, unit)
def finite_? = true
override def equals(other : Any) =
other.isInstanceOf[FiniteDuration] &&
toNanos == other.asInstanceOf[FiniteDuration].toNanos
override def hashCode = toNanos.asInstanceOf[Int]
}
package object duration {
implicit def intToDurationInt(n: Int) = new DurationInt(n)
implicit def longToDurationLong(n: Long) = new DurationLong(n)
implicit def doubleToDurationDouble(d: Double) = new DurationDouble(d)
implicit def pairIntToDuration(p : (Int, TimeUnit)) = Duration(p._1, p._2)
implicit def pairLongToDuration(p : (Long, TimeUnit)) = Duration(p._1, p._2)
implicit def durationToPair(d : Duration) = (d.length, d.unit)
implicit def intMult(i : Int) = new {
def *(d : Duration) = d * i
}
implicit def longMult(l : Long) = new {
def *(d : Duration) = d * l
}
implicit def doubleMult(f : Double) = new {
def *(d : Duration) = d * f
}
}
class DurationInt(n: Int) {
def nanoseconds = Duration(n, TimeUnit.NANOSECONDS)
def nanos = Duration(n, TimeUnit.NANOSECONDS)
def nanosecond = Duration(n, TimeUnit.NANOSECONDS)
def nano = Duration(n, TimeUnit.NANOSECONDS)
def nanoseconds = Duration(n, NANOSECONDS)
def nanos = Duration(n, NANOSECONDS)
def nanosecond = Duration(n, NANOSECONDS)
def nano = Duration(n, NANOSECONDS)
def microseconds = Duration(n, TimeUnit.MICROSECONDS)
def micros = Duration(n, TimeUnit.MICROSECONDS)
def microsecond = Duration(n, TimeUnit.MICROSECONDS)
def micro = Duration(n, TimeUnit.MICROSECONDS)
def microseconds = Duration(n, MICROSECONDS)
def micros = Duration(n, MICROSECONDS)
def microsecond = Duration(n, MICROSECONDS)
def micro = Duration(n, MICROSECONDS)
def milliseconds = Duration(n, TimeUnit.MILLISECONDS)
def millis = Duration(n, TimeUnit.MILLISECONDS)
def millisecond = Duration(n, TimeUnit.MILLISECONDS)
def milli = Duration(n, TimeUnit.MILLISECONDS)
def milliseconds = Duration(n, MILLISECONDS)
def millis = Duration(n, MILLISECONDS)
def millisecond = Duration(n, MILLISECONDS)
def milli = Duration(n, MILLISECONDS)
def seconds = Duration(n, TimeUnit.SECONDS)
def second = Duration(n, TimeUnit.SECONDS)
def seconds = Duration(n, SECONDS)
def second = Duration(n, SECONDS)
def minutes = Duration(n, MINUTES)
def minute = Duration(n, MINUTES)
def hours = Duration(n, HOURS)
def hour = Duration(n, HOURS)
def days = Duration(n, DAYS)
def day = Duration(n, DAYS)
}
class DurationLong(n: Long) {
def nanoseconds = Duration(n, TimeUnit.NANOSECONDS)
def nanos = Duration(n, TimeUnit.NANOSECONDS)
def nanosecond = Duration(n, TimeUnit.NANOSECONDS)
def nano = Duration(n, TimeUnit.NANOSECONDS)
def nanoseconds = Duration(n, NANOSECONDS)
def nanos = Duration(n, NANOSECONDS)
def nanosecond = Duration(n, NANOSECONDS)
def nano = Duration(n, NANOSECONDS)
def microseconds = Duration(n, TimeUnit.MICROSECONDS)
def micros = Duration(n, TimeUnit.MICROSECONDS)
def microsecond = Duration(n, TimeUnit.MICROSECONDS)
def micro = Duration(n, TimeUnit.MICROSECONDS)
def microseconds = Duration(n, MICROSECONDS)
def micros = Duration(n, MICROSECONDS)
def microsecond = Duration(n, MICROSECONDS)
def micro = Duration(n, MICROSECONDS)
def milliseconds = Duration(n, TimeUnit.MILLISECONDS)
def millis = Duration(n, TimeUnit.MILLISECONDS)
def millisecond = Duration(n, TimeUnit.MILLISECONDS)
def milli = Duration(n, TimeUnit.MILLISECONDS)
def milliseconds = Duration(n, MILLISECONDS)
def millis = Duration(n, MILLISECONDS)
def millisecond = Duration(n, MILLISECONDS)
def milli = Duration(n, MILLISECONDS)
def seconds = Duration(n, TimeUnit.SECONDS)
def second = Duration(n, TimeUnit.SECONDS)
def seconds = Duration(n, SECONDS)
def second = Duration(n, SECONDS)
def minutes = Duration(n, MINUTES)
def minute = Duration(n, MINUTES)
def hours = Duration(n, HOURS)
def hour = Duration(n, HOURS)
def days = Duration(n, DAYS)
def day = Duration(n, DAYS)
}
class DurationDouble(d: Double) {
def nanoseconds = Duration(d, NANOSECONDS)
def nanos = Duration(d, NANOSECONDS)
def nanosecond = Duration(d, NANOSECONDS)
def nano = Duration(d, NANOSECONDS)
def microseconds = Duration(d, MICROSECONDS)
def micros = Duration(d, MICROSECONDS)
def microsecond = Duration(d, MICROSECONDS)
def micro = Duration(d, MICROSECONDS)
def milliseconds = Duration(d, MILLISECONDS)
def millis = Duration(d, MILLISECONDS)
def millisecond = Duration(d, MILLISECONDS)
def milli = Duration(d, MILLISECONDS)
def seconds = Duration(d, SECONDS)
def second = Duration(d, SECONDS)
def minutes = Duration(d, MINUTES)
def minute = Duration(d, MINUTES)
def hours = Duration(d, HOURS)
def hour = Duration(d, HOURS)
def days = Duration(d, DAYS)
def day = Duration(d, DAYS)
}

View file

@ -0,0 +1,434 @@
package akka.util
import akka.actor.{Actor, FSM}
import Actor._
import duration._
import java.util.concurrent.{BlockingDeque, LinkedBlockingDeque, TimeUnit}
import scala.annotation.tailrec
object TestActor {
type Ignore = Option[PartialFunction[AnyRef, Boolean]]
case class SetTimeout(d : Duration)
case class SetIgnore(i : Ignore)
}
class TestActor(queue : BlockingDeque[AnyRef]) extends Actor with FSM[Int, TestActor.Ignore] {
import FSM._
import TestActor._
startWith(0, None)
when(0, stateTimeout = 5 seconds) {
case Ev(SetTimeout(d)) =>
setStateTimeout(0, if (d.finite_?) d else None)
stay
case Ev(SetIgnore(ign)) => stay using ign
case Ev(StateTimeout) =>
stop
case Event(x : AnyRef, ign) =>
val ignore = ign map (z => if (z isDefinedAt x) z(x) else false) getOrElse false
if (!ignore) {
queue.offerLast(x)
}
stay
}
initialize
}
/**
* Test kit for testing actors. Inheriting from this trait enables reception of
* replies from actors, which are queued by an internal actor and can be
* examined using the `expectMsg...` methods. Assertions and bounds concerning
* timing are available in the form of `within` blocks.
*
* <pre>
* class Test extends TestKit {
* val test = actorOf[SomeActor].start
*
* within (1 second) {
* test ! SomeWork
* expectMsg(Result1) // bounded to 1 second
* expectMsg(Result2) // bounded to the remainder of the 1 second
* }
* }
* </pre>
*
* Beware of two points:
*
* - the internal test actor needs to be stopped, either explicitly using
* `stopTestActor` or implicitly by using its internal inactivity timeout,
* see `setTestActorTimeout`
* - this trait is not thread-safe (only one actor with one queue, one stack
* of `within` blocks); it is expected that the code is executed from a
* constructor as shown above, which makes this a non-issue, otherwise take
* care not to run tests within a single test class instance in parallel.
*
* @author Roland Kuhn
* @since 1.1
*/
trait TestKit {
private val queue = new LinkedBlockingDeque[AnyRef]()
/**
* ActorRef of the test actor. Access is provided to enable e.g.
* registration as message target.
*/
protected val testActor = actorOf(new TestActor(queue)).start
/**
* Implicit sender reference so that replies are possible for messages sent
* from the test class.
*/
protected implicit val senderOption = Some(testActor)
private var end : Duration = Duration.Inf
/*
* THIS IS A HACK: expectNoMsg and receiveWhile are bounded by `end`, but
* running them should not trigger an AssertionError, so mark their end
* time here and do not fail at the end of `within` if that time is not
* long gone.
*/
private var lastSoftTimeout : Duration = now - 5.millis
/**
* Stop test actor. Should be done at the end of the test unless relying on
* test actor timeout.
*/
def stopTestActor { testActor.stop }
/**
* Set test actor timeout. By default, the test actor shuts itself down
* after 5 seconds of inactivity. Set this to Duration.Inf to disable this
* behavior, but make sure that someone will then call `stopTestActor`,
* unless you want to leak actors, e.g. wrap test in
*
* <pre>
* try {
* ...
* } finally { stopTestActor }
* </pre>
*/
def setTestActorTimeout(d : Duration) { testActor ! TestActor.SetTimeout(d) }
/**
* Ignore all messages in the test actor for which the given partial
* function returns true.
*/
def ignoreMsg(f : PartialFunction[AnyRef, Boolean]) { testActor ! TestActor.SetIgnore(Some(f)) }
/**
* Stop ignoring messages in the test actor.
*/
def ignoreNoMsg { testActor ! TestActor.SetIgnore(None) }
/**
* Obtain current time (`System.currentTimeMillis`) as Duration.
*/
def now : Duration = System.nanoTime.nanos
/**
* Obtain time remaining for execution of the innermost enclosing `within` block.
*/
def remaining : Duration = end - now
/**
* Execute code block while bounding its execution time between `min` and
* `max`. `within` blocks may be nested. All methods in this trait which
* take maximum wait times are available in a version which implicitly uses
* the remaining time governed by the innermost enclosing `within` block.
*
* <pre>
* val ret = within(50 millis) {
* test ! "ping"
* expectMsgClass(classOf[String])
* }
* </pre>
*/
def within[T](min : Duration, max : Duration)(f : => T) : T = {
val start = now
val rem = end - start
assert (rem >= min, "required min time "+min+" not possible, only "+format(min.unit, rem)+" left")
val max_diff = if (max < rem) max else rem
val prev_end = end
end = start + max_diff
val ret = f
val diff = now - start
assert (min <= diff, "block took "+format(min.unit, diff)+", should at least have been "+min)
/*
* caution: HACK AHEAD
*/
if (now - lastSoftTimeout > 5.millis) {
assert (diff <= max_diff, "block took "+format(max.unit, diff)+", exceeding "+format(max.unit, max_diff))
} else {
lastSoftTimeout -= 5.millis
}
end = prev_end
ret
}
/**
* Same as calling `within(0 seconds, max)(f)`.
*/
def within[T](max : Duration)(f : => T) : T = within(0 seconds, max)(f)
/**
* Same as `expectMsg`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsg(obj : Any) : AnyRef = expectMsg(remaining, obj)
/**
* Receive one message from the test actor and assert that it equals the
* given object. Wait time is bounded by the given duration, with an
* AssertionFailure being thrown in case of timeout.
*
* @return the received object
*/
def expectMsg(max : Duration, obj : Any) : AnyRef = {
val o = receiveOne(max)
assert (o ne null, "timeout during expectMsg")
assert (obj == o, "expected "+obj+", found "+o)
o
}
/**
* Same as `expectMsg`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsg[T](f : PartialFunction[Any, T]) : T = expectMsg(remaining)(f)
/**
* Receive one message from the test actor and assert that the given
* partial function accepts it. Wait time is bounded by the given duration,
* with an AssertionFailure being thrown in case of timeout.
*
* Use this variant to implement more complicated or conditional
* processing.
*
* @return the received object as transformed by the partial function
*/
def expectMsg[T](max : Duration)(f : PartialFunction[Any, T]) : T = {
val o = receiveOne(max)
assert (o ne null, "timeout during expectMsg")
assert (f.isDefinedAt(o), "does not match: "+o)
f(o)
}
/**
* Same as `expectMsgClass`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgClass[C](c : Class[C]) : C = expectMsgClass(remaining, c)
/**
* Receive one message from the test actor and assert that it conforms to
* the given class. Wait time is bounded by the given duration, with an
* AssertionFailure being thrown in case of timeout.
*
* @return the received object
*/
def expectMsgClass[C](max : Duration, c : Class[C]) : C = {
val o = receiveOne(max)
assert (o ne null, "timeout during expectMsgClass")
assert (c isInstance o, "expected "+c+", found "+o.getClass)
o.asInstanceOf[C]
}
/**
* Same as `expectMsgAnyOf`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgAnyOf(obj : Any*) : AnyRef = expectMsgAnyOf(remaining, obj : _*)
/**
* Receive one message from the test actor and assert that it equals one of
* the given objects. Wait time is bounded by the given duration, with an
* AssertionFailure being thrown in case of timeout.
*
* @return the received object
*/
def expectMsgAnyOf(max : Duration, obj : Any*) : AnyRef = {
val o = receiveOne(max)
assert (o ne null, "timeout during expectMsgAnyOf")
assert (obj exists (_ == o), "found unexpected "+o)
o
}
/**
* Same as `expectMsgAnyClassOf`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgAnyClassOf(obj : Class[_]*) : AnyRef = expectMsgAnyClassOf(remaining, obj : _*)
/**
* Receive one message from the test actor and assert that it conforms to
* one of the given classes. Wait time is bounded by the given duration,
* with an AssertionFailure being thrown in case of timeout.
*
* @return the received object
*/
def expectMsgAnyClassOf(max : Duration, obj : Class[_]*) : AnyRef = {
val o = receiveOne(max)
assert (o ne null, "timeout during expectMsgAnyClassOf")
assert (obj exists (_ isInstance o), "found unexpected "+o)
o
}
/**
* Same as `expectMsgAllOf`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgAllOf(obj : Any*) { expectMsgAllOf(remaining, obj : _*) }
/**
* Receive a number of messages from the test actor matching the given
* number of objects and assert that for each given object one is received
* which equals it. This construct is useful when the order in which the
* objects are received is not fixed. Wait time is bounded by the given
* duration, with an AssertionFailure being thrown in case of timeout.
*
* <pre>
* within(1 second) {
* dispatcher ! SomeWork1()
* dispatcher ! SomeWork2()
* expectMsgAllOf(Result1(), Result2())
* }
* </pre>
*/
def expectMsgAllOf(max : Duration, obj : Any*) {
val recv = receiveN(obj.size, now + max)
assert (obj forall (x => recv exists (x == _)), "not found all")
}
/**
* Same as `expectMsgAllClassOf`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgAllClassOf(obj : Class[_]*) { expectMsgAllClassOf(remaining, obj : _*) }
/**
* Receive a number of messages from the test actor matching the given
* number of classes and assert that for each given class one is received
* which is of that class (equality, not conformance). This construct is
* useful when the order in which the objects are received is not fixed.
* Wait time is bounded by the given duration, with an AssertionFailure
* being thrown in case of timeout.
*/
def expectMsgAllClassOf(max : Duration, obj : Class[_]*) {
val recv = receiveN(obj.size, now + max)
assert (obj forall (x => recv exists (_.getClass eq x)), "not found all")
}
/**
* Same as `expectMsgAllConformingOf`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectMsgAllConformingOf(obj : Class[_]*) { expectMsgAllClassOf(remaining, obj : _*) }
/**
* Receive a number of messages from the test actor matching the given
* number of classes and assert that for each given class one is received
* which conforms to that class. This construct is useful when the order in
* which the objects are received is not fixed. Wait time is bounded by
* the given duration, with an AssertionFailure being thrown in case of
* timeout.
*
* Beware that one object may satisfy all given class constraints, which
* may be counter-intuitive.
*/
def expectMsgAllConformingOf(max : Duration, obj : Class[_]*) {
val recv = receiveN(obj.size, now + max)
assert (obj forall (x => recv exists (x isInstance _)), "not found all")
}
/**
* Same as `expectNoMsg`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def expectNoMsg { expectNoMsg(remaining) }
/**
* Assert that no message is received for the specified time.
*/
def expectNoMsg(max : Duration) {
val o = receiveOne(max)
assert (o eq null, "received unexpected message "+o)
lastSoftTimeout = now
}
/**
* Same as `receiveWhile`, but takes the maximum wait time from the innermost
* enclosing `within` block.
*/
def receiveWhile[T](f : PartialFunction[AnyRef, T]) : Seq[T] = receiveWhile(remaining)(f)
/**
* Receive a series of messages as long as the given partial function
* accepts them or the idle timeout is met or the overall maximum duration
* is elapsed. Returns the sequence of messages.
*
* Beware that the maximum duration is not implicitly bounded by or taken
* from the innermost enclosing `within` block, as it is not an error to
* hit the `max` duration in this case.
*
* One possible use of this method is for testing whether messages of
* certain characteristics are generated at a certain rate:
*
* <pre>
* test ! ScheduleTicks(100 millis)
* val series = receiveWhile(750 millis) {
* case Tick(count) => count
* }
* assert(series == (1 to 7).toList)
* </pre>
*/
def receiveWhile[T](max : Duration)(f : PartialFunction[AnyRef, T]) : Seq[T] = {
val stop = now + max
@tailrec def doit(acc : List[T]) : List[T] = {
receiveOne(stop - now) match {
case null =>
acc.reverse
case o if (f isDefinedAt o) =>
doit(f(o) :: acc)
case o =>
queue.offerFirst(o)
acc.reverse
}
}
val ret = doit(Nil)
lastSoftTimeout = now
ret
}
private def receiveN(n : Int, stop : Duration) : Seq[AnyRef] = {
for { x <- 1 to n } yield {
val timeout = stop - now
val o = receiveOne(timeout)
assert (o ne null, "timeout while expecting "+n+" messages")
o
}
}
private def receiveOne(max : Duration) : AnyRef = {
if (max == 0.seconds) {
queue.pollFirst
} else if (max.finite_?) {
queue.pollFirst(max.length, max.unit)
} else {
queue.takeFirst
}
}
private def format(u : TimeUnit, d : Duration) = "%.3f %s".format(d.toUnit(u), u.toString.toLowerCase)
}
// vim: set ts=2 sw=2 et:

View file

@ -6,19 +6,24 @@ package akka.actor
import org.scalatest.junit.JUnitSuite
import org.junit.Test
import FSM._
import org.multiverse.api.latches.StandardLatch
import java.util.concurrent.TimeUnit
import akka.util.duration._
object FSMActorSpec {
import FSM._
val unlockedLatch = new StandardLatch
val lockedLatch = new StandardLatch
val unhandledLatch = new StandardLatch
val terminatedLatch = new StandardLatch
val transitionLatch = new StandardLatch
val initialStateLatch = new StandardLatch
val transitionCallBackLatch = new StandardLatch
sealed trait LockState
case object Locked extends LockState
@ -26,11 +31,8 @@ object FSMActorSpec {
class Lock(code: String, timeout: (Long, TimeUnit)) extends Actor with FSM[LockState, CodeState] {
notifying {
case Transition(Locked, Open) => transitionLatch.open
case Transition(_, _) => ()
}
startWith(Locked, CodeState("", code))
when(Locked) {
case Event(digit: Char, CodeState(soFar, code)) => {
soFar + digit match {
@ -47,7 +49,7 @@ object FSMActorSpec {
}
}
case Event("hello", _) => stay replying "world"
case Event("bye", _) => stop
case Event("bye", _) => stop(Shutdown)
}
when(Open) {
@ -57,8 +59,6 @@ object FSMActorSpec {
}
}
startWith(Locked, CodeState("", code))
whenUnhandled {
case Event(_, stateData) => {
log.slf4j.info("Unhandled")
@ -67,10 +67,21 @@ object FSMActorSpec {
}
}
onTermination {
case reason => terminatedLatch.open
onTransition(transitionHandler)
def transitionHandler(from: LockState, to: LockState) = {
if (from == Locked && to == Open) transitionLatch.open
}
onTermination {
case StopEvent(Shutdown, Locked, _) =>
// stop is called from lockstate with shutdown as reason...
terminatedLatch.open
}
// initialize the lock
initialize
private def doLock() {
log.slf4j.info("Locked")
lockedLatch.open
@ -88,12 +99,21 @@ object FSMActorSpec {
class FSMActorSpec extends JUnitSuite {
import FSMActorSpec._
@Test
def unlockTheLock = {
// lock that locked after being open for 1 sec
val lock = Actor.actorOf(new Lock("33221", (1, TimeUnit.SECONDS))).start
val transitionTester = Actor.actorOf(new Actor { def receive = {
case Transition(_, _, _) => transitionCallBackLatch.open
case CurrentState(_, Locked) => initialStateLatch.open
}}).start
lock ! SubscribeTransitionCallBack(transitionTester)
assert(initialStateLatch.tryAwait(1, TimeUnit.SECONDS))
lock ! '3'
lock ! '3'
lock ! '2'
@ -102,8 +122,10 @@ class FSMActorSpec extends JUnitSuite {
assert(unlockedLatch.tryAwait(1, TimeUnit.SECONDS))
assert(transitionLatch.tryAwait(1, TimeUnit.SECONDS))
assert(transitionCallBackLatch.tryAwait(1, TimeUnit.SECONDS))
assert(lockedLatch.tryAwait(2, TimeUnit.SECONDS))
lock ! "not_handled"
assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS))

View file

@ -0,0 +1,142 @@
package akka.actor
import akka.util.TestKit
import akka.util.duration._
import org.scalatest.WordSpec
import org.scalatest.matchers.MustMatchers
class FSMTimingSpec
extends WordSpec
with MustMatchers
with TestKit {
import FSMTimingSpec._
import FSM._
val fsm = Actor.actorOf(new StateMachine(testActor)).start
fsm ! SubscribeTransitionCallBack(testActor)
expectMsg(100 millis, CurrentState(fsm, Initial))
ignoreMsg {
case Transition(_, Initial, _) => true
}
"A Finite State Machine" must {
"receive StateTimeout" in {
within (50 millis, 150 millis) {
fsm ! TestStateTimeout
expectMsg(Transition(fsm, TestStateTimeout, Initial))
expectNoMsg
}
}
"receive single-shot timer" in {
within (50 millis, 150 millis) {
fsm ! TestSingleTimer
expectMsg(Tick)
expectMsg(Transition(fsm, TestSingleTimer, Initial))
expectNoMsg
}
}
"receive and cancel a repeated timer" in {
fsm ! TestRepeatedTimer
val seq = receiveWhile(550 millis) {
case Tick => Tick
}
seq must have length (5)
within(250 millis) {
fsm ! Cancel
expectMsg(Transition(fsm, TestRepeatedTimer, Initial))
expectNoMsg
}
}
"notify unhandled messages" in {
fsm ! TestUnhandled
within(100 millis) {
fsm ! Tick
expectNoMsg
}
within(100 millis) {
fsm ! SetHandler
fsm ! Tick
expectMsg(Unhandled(Tick))
expectNoMsg
}
within(100 millis) {
fsm ! Unhandled("test")
expectNoMsg
}
within(100 millis) {
fsm ! Cancel
expectMsg(Transition(fsm, TestUnhandled, Initial))
}
}
}
}
object FSMTimingSpec {
trait State
case object Initial extends State
case object TestStateTimeout extends State
case object TestSingleTimer extends State
case object TestRepeatedTimer extends State
case object TestUnhandled extends State
case object Tick
case object Cancel
case object SetHandler
case class Unhandled(msg : AnyRef)
class StateMachine(tester : ActorRef) extends Actor with FSM[State, Unit] {
import FSM._
startWith(Initial, ())
when(Initial) {
case Ev(TestSingleTimer) =>
setTimer("tester", Tick, 100 millis, false)
goto(TestSingleTimer)
case Ev(TestRepeatedTimer) =>
setTimer("tester", Tick, 100 millis, true)
goto(TestRepeatedTimer)
case Ev(x : FSMTimingSpec.State) => goto(x)
}
when(TestStateTimeout, stateTimeout = 100 millis) {
case Ev(StateTimeout) => goto(Initial)
}
when(TestSingleTimer) {
case Ev(Tick) =>
tester ! Tick
goto(Initial)
}
when(TestRepeatedTimer) {
case Ev(Tick) =>
tester ! Tick
stay
case Ev(Cancel) =>
cancelTimer("tester")
goto(Initial)
}
when(TestUnhandled) {
case Ev(SetHandler) =>
whenUnhandled {
case Ev(Tick) =>
tester ! Unhandled(Tick)
stay
}
stay
case Ev(Cancel) =>
goto(Initial)
}
}
}
// vim: set ts=2 sw=2 et:

View file

@ -0,0 +1,75 @@
package sample.fsm.buncher
import scala.reflect.ClassManifest
import akka.util.Duration
import akka.actor.{FSM, Actor}
/*
* generic typed object buncher.
*
* To instantiate it, use the factory method like so:
* Buncher(100, 500)(x : List[AnyRef] => x foreach println)
* which will yield a fully functional and started ActorRef.
* The type of messages allowed is strongly typed to match the
* supplied processing method; other messages are discarded (and
* possibly logged).
*/
object Buncher {
trait State
case object Idle extends State
case object Active extends State
case object Flush // send out current queue immediately
case object Stop // poison pill
case class Data[A](start : Long, xs : List[A])
def apply[A : Manifest](singleTimeout : Duration,
multiTimeout : Duration)(f : List[A] => Unit) =
Actor.actorOf(new Buncher[A](singleTimeout, multiTimeout).deliver(f))
}
class Buncher[A : Manifest] private (val singleTimeout : Duration, val multiTimeout : Duration)
extends Actor with FSM[Buncher.State, Buncher.Data[A]] {
import Buncher._
import FSM._
private val manifestA = manifest[A]
private var send : List[A] => Unit = _
private def deliver(f : List[A] => Unit) = { send = f; this }
private def now = System.currentTimeMillis
private def check(m : AnyRef) = ClassManifest.fromClass(m.getClass) <:< manifestA
startWith(Idle, Data(0, Nil))
when(Idle) {
case Event(m : AnyRef, _) if check(m) =>
goto(Active) using Data(now, m.asInstanceOf[A] :: Nil)
case Event(Flush, _) => stay
case Event(Stop, _) => stop
}
when(Active, stateTimeout = Some(singleTimeout)) {
case Event(m : AnyRef, Data(start, xs)) if check(m) =>
val l = m.asInstanceOf[A] :: xs
if (now - start > multiTimeout.toMillis) {
send(l.reverse)
goto(Idle) using Data(0, Nil)
} else {
stay using Data(start, l)
}
case Event(StateTimeout, Data(_, xs)) =>
send(xs.reverse)
goto(Idle) using Data(0, Nil)
case Event(Flush, Data(_, xs)) =>
send(xs.reverse)
goto(Idle) using Data(0, Nil)
case Event(Stop, Data(_, xs)) =>
send(xs.reverse)
stop
}
initialize
}

View file

@ -3,8 +3,8 @@ package sample.fsm.dining.fsm
import akka.actor.{ActorRef, Actor, FSM}
import akka.actor.FSM._
import Actor._
import java.util.concurrent.TimeUnit
import TimeUnit._
import akka.util.Duration
import akka.util.duration._
/*
* Some messages for the chopstick
@ -33,6 +33,9 @@ case class TakenBy(hakker: Option[ActorRef])
class Chopstick(name: String) extends Actor with FSM[ChopstickState, TakenBy] {
self.id = name
// A chopstick begins its existence as available and taken by no one
startWith(Available, TakenBy(None))
// When a chopstick is available, it can be taken by a some hakker
when(Available) {
case Event(Take, _) =>
@ -49,8 +52,8 @@ class Chopstick(name: String) extends Actor with FSM[ChopstickState, TakenBy] {
goto(Available) using TakenBy(None)
}
// A chopstick begins its existence as available and taken by no one
startWith(Available, TakenBy(None))
// Initialze the chopstick
initialize
}
/**
@ -81,10 +84,13 @@ case class TakenChopsticks(left: Option[ActorRef], right: Option[ActorRef])
class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor with FSM[FSMHakkerState, TakenChopsticks] {
self.id = name
//All hakkers start waiting
startWith(Waiting, TakenChopsticks(None, None))
when(Waiting) {
case Event(Think, _) =>
log.info("%s starts to think", name)
startThinking(5, SECONDS)
startThinking(5 seconds)
}
//When a hakker is thinking it can become hungry
@ -118,12 +124,12 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit
case Event(Busy(chopstick), TakenChopsticks(leftOption, rightOption)) =>
leftOption.foreach(_ ! Put)
rightOption.foreach(_ ! Put)
startThinking(10, MILLISECONDS)
startThinking(10 milliseconds)
}
private def startEating(left: ActorRef, right: ActorRef): State = {
log.info("%s has picked up %s and %s, and starts to eat", name, left.id, right.id)
goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5, SECONDS)
goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5 seconds)
}
// When the results of the other grab comes back,
@ -132,9 +138,9 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit
when(FirstChopstickDenied) {
case Event(Taken(secondChopstick), _) =>
secondChopstick ! Put
startThinking(10, MILLISECONDS)
startThinking(10 milliseconds)
case Event(Busy(chopstick), _) =>
startThinking(10, MILLISECONDS)
startThinking(10 milliseconds)
}
// When a hakker is eating, he can decide to start to think,
@ -144,15 +150,15 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit
log.info("%s puts down his chopsticks and starts to think", name)
left ! Put
right ! Put
startThinking(5, SECONDS)
startThinking(5 seconds)
}
private def startThinking(period: Int, timeUnit: TimeUnit): State = {
goto(Thinking) using TakenChopsticks(None, None) forMax (period, timeUnit)
}
// Initialize the hakker
initialize
//All hakkers start waiting
startWith(Waiting, TakenChopsticks(None, None))
private def startThinking(duration: Duration): State = {
goto(Thinking) using TakenChopsticks(None, None) forMax duration
}
}
/*

View file

@ -5,7 +5,7 @@ import akka.transactor.Atomically;
import akka.actor.ActorRef;
import akka.actor.UntypedActor;
import akka.stm.*;
import akka.util.Duration;
import akka.util.FiniteDuration;
import org.multiverse.api.StmUtils;
@ -17,7 +17,7 @@ public class UntypedCoordinatedCounter extends UntypedActor {
private String name;
private Ref<Integer> count = new Ref(0);
private TransactionFactory txFactory = new TransactionFactoryBuilder()
.setTimeout(new Duration(3, TimeUnit.SECONDS))
.setTimeout(new FiniteDuration(3, TimeUnit.SECONDS))
.build();
public UntypedCoordinatedCounter(String name) {

View file

@ -4,7 +4,7 @@ import akka.transactor.UntypedTransactor;
import akka.transactor.SendTo;
import akka.actor.ActorRef;
import akka.stm.*;
import akka.util.Duration;
import akka.util.FiniteDuration;
import org.multiverse.api.StmUtils;
@ -23,7 +23,7 @@ public class UntypedCounter extends UntypedTransactor {
@Override public TransactionFactory transactionFactory() {
return new TransactionFactoryBuilder()
.setTimeout(new Duration(3, TimeUnit.SECONDS))
.setTimeout(new FiniteDuration(3, TimeUnit.SECONDS))
.build();
}
@ -74,4 +74,4 @@ public class UntypedCounter extends UntypedTransactor {
return true;
} else return false;
}
}
}