diff --git a/akka-actor/src/main/scala/akka/actor/FSM.scala b/akka-actor/src/main/scala/akka/actor/FSM.scala old mode 100644 new mode 100755 index 17c0d1dd64..f84c35837a --- a/akka-actor/src/main/scala/akka/actor/FSM.scala +++ b/akka-actor/src/main/scala/akka/actor/FSM.scala @@ -1,13 +1,20 @@ /** * Copyright (C) 2009-2011 Scalable Solutions AB */ - package akka.actor +import akka.util._ + import scala.collection.mutable -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.ScheduledFuture object FSM { + + case class CurrentState[S](fsmRef: ActorRef, state: S) + case class Transition[S](fsmRef: ActorRef, from: S, to: S) + case class SubscribeTransitionCallBack(actorRef: ActorRef) + case class UnsubscribeTransitionCallBack(actorRef: ActorRef) + sealed trait Reason case object Normal extends Reason case object Shutdown extends Reason @@ -15,85 +22,288 @@ object FSM { case object StateTimeout case class TimeoutMarker(generation: Long) + + case class Timer(name: String, msg: AnyRef, repeat: Boolean) { + private var ref: Option[ScheduledFuture[AnyRef]] = _ + + def schedule(actor: ActorRef, timeout: Duration) { + if (repeat) { + ref = Some(Scheduler.schedule(actor, this, timeout.length, timeout.length, timeout.unit)) + } else { + ref = Some(Scheduler.scheduleOnce(actor, this, timeout.length, timeout.unit)) + } + } + + def cancel { + if (ref.isDefined) { + ref.get.cancel(true) + ref = None + } + } + } + + /* + * With these implicits in scope, you can write "5 seconds" anywhere a + * Duration or Option[Duration] is expected. This is conveniently true + * for derived classes. + */ + implicit def d2od(d: Duration): Option[Duration] = Some(d) } +/** + * Finite State Machine actor trait. Use as follows: + * + *
+ *   object A {
+ *     trait State
+ *     case class One extends State
+ *     case class Two extends State
+ *
+ *     case class Data(i : Int)
+ *   }
+ *
+ *   class A extends Actor with FSM[A.State, A.Data] {
+ *     import A._
+ *
+ *     startWith(One, Data(42))
+ *     when(One) {
+ *         case Event(SomeMsg, Data(x)) => ...
+ *         case Ev(SomeMsg) => ... // convenience when data not needed
+ *     }
+ *     when(Two, stateTimeout = 5 seconds) { ... }
+ *     initialize
+ *   }
+ * 
+ * + * Within the partial function the following values are returned for effecting + * state transitions: + * + * - stay for staying in the same state + * - stay using Data(...) for staying in the same state, but with + * different data + * - stay forMax 5.millis for staying with a state timeout; can be + * combined with using + * - goto(...) for changing into a different state; also supports + * using and forMax + * - stop for terminating this FSM actor + * + * Each of the above also supports the method replying(AnyRef) for + * sending a reply before changing state. + * + * Another feature is that other actors may subscribe for transition events by + * sending a SubscribeTransitionCallback message to this actor; + * use UnsubscribeTransitionCallback before stopping the other + * actor. + * + * State timeouts set an upper bound to the time which may pass before another + * message is received in the current state. If no external message is + * available, then upon expiry of the timeout a StateTimeout message is sent. + * Note that this message will only be received in the state for which the + * timeout was set and that any message received will cancel the timeout + * (possibly to be started again by the next transition). + * + * Another feature is the ability to install and cancel single-shot as well as + * repeated timers which arrange for the sending of a user-specified message: + * + *
+ *   setTimer("tock", TockMsg, 1 second, true) // repeating
+ *   setTimer("lifetime", TerminateMsg, 1 hour, false) // single-shot
+ *   cancelTimer("tock")
+ *   timerActive_? ("tock")
+ * 
+ */ trait FSM[S, D] { this: Actor => + import FSM._ - type StateFunction = scala.PartialFunction[Event, State] + type StateFunction = scala.PartialFunction[Event[D], State] + type Timeout = Option[Duration] + type TransitionHandler = (S, S) => Unit - /** DSL */ - protected final def notifying(transitionHandler: PartialFunction[Transition, Unit]) = { - transitionEvent = transitionHandler - } - - protected final def when(stateName: S)(stateFunction: StateFunction) = { - register(stateName, stateFunction) + /* DSL */ + + /** + * Insert a new StateFunction at the end of the processing chain for the + * given state. If the stateTimeout parameter is set, entering this state + * without a differing explicit timeout setting will trigger a StateTimeout + * event; the same is true when using #stay. + * + * @param stateName designator for the state + * @param stateTimeout default state timeout for this state + * @param stateFunction partial function describing response to input + */ + protected final def when(stateName: S, stateTimeout: Timeout = None)(stateFunction: StateFunction) = { + register(stateName, stateFunction, stateTimeout) } + /** + * Set initial state. Call this method from the constructor before the #initialize method. + * + * @param stateName initial state designator + * @param stateData initial state data + * @param timeout state timeout for the initial state, overriding the default timeout for that state + */ protected final def startWith(stateName: S, stateData: D, - timeout: Option[(Long, TimeUnit)] = None) = { - applyState(State(stateName, stateData, timeout)) + timeout: Timeout = None) = { + currentState = State(stateName, stateData, timeout) } + /** + * Produce transition to other state. Return this from a state function in + * order to effect the transition. + * + * @param nextStateName state designator for the next state + * @return state transition descriptor + */ protected final def goto(nextStateName: S): State = { State(nextStateName, currentState.stateData) } + /** + * Produce "empty" transition descriptor. Return this from a state function + * when no state change is to be effected. + * + * @return descriptor for staying in current state + */ protected final def stay(): State = { + // cannot directly use currentState because of the timeout field goto(currentState.stateName) } + /** + * Produce change descriptor to stop this FSM actor with reason "Normal". + */ protected final def stop(): State = { stop(Normal) } + /** + * Produce change descriptor to stop this FSM actor including specified reason. + */ protected final def stop(reason: Reason): State = { stop(reason, currentState.stateData) } + /** + * Produce change descriptor to stop this FSM actor including specified reason. + */ protected final def stop(reason: Reason, stateData: D): State = { - stay using stateData withStopReason(reason) + stay using stateData withStopReason (reason) } - def whenUnhandled(stateFunction: StateFunction) = { - handleEvent = stateFunction + /** + * Schedule named timer to deliver message after given delay, possibly repeating. + * @param name identifier to be used with cancelTimer() + * @param msg message to be delivered + * @param timeout delay of first message delivery and between subsequent messages + * @param repeat send once if false, scheduleAtFixedRate if true + * @return current state descriptor + */ + protected final def setTimer(name: String, msg: AnyRef, timeout: Duration, repeat: Boolean): State = { + if (timers contains name) { + timers(name).cancel + } + val timer = Timer(name, msg, repeat) + timer.schedule(self, timeout) + timers(name) = timer + stay } - def onTermination(terminationHandler: PartialFunction[Reason, Unit]) = { + /** + * Cancel named timer, ensuring that the message is not subsequently delivered (no race). + * @param name of the timer to cancel + */ + protected final def cancelTimer(name: String) = { + if (timers contains name) { + timers(name).cancel + timers -= name + } + } + + /** + * Inquire whether the named timer is still active. Returns true unless the + * timer does not exist, has previously been canceled or if it was a + * single-shot timer whose message was already received. + */ + protected final def timerActive_?(name: String) = timers contains name + + /** + * Set state timeout explicitly. This method can safely be used from within a + * state handler. + */ + protected final def setStateTimeout(state : S, timeout : Timeout) { + stateTimeouts(state) = timeout + } + + /** + * Set handler which is called upon each state transition, i.e. not when + * staying in the same state. + */ + protected final def onTransition(transitionHandler: TransitionHandler) = { + transitionEvent = transitionHandler + } + + /** + * Set handler which is called upon termination of this FSM actor. + */ + protected final def onTermination(terminationHandler: PartialFunction[StopEvent[S,D], Unit]) = { terminateEvent = terminationHandler } - /** FSM State data and default handlers */ + /** + * Set handler which is called upon reception of unhandled messages. + */ + protected final def whenUnhandled(stateFunction: StateFunction) = { + handleEvent = stateFunction orElse handleEventDefault + } + + /** + * Verify existence of initial state and setup timers. This should be the + * last call within the constructor. + */ + def initialize { + makeTransition(currentState) + } + + /**FSM State data and default handlers */ private var currentState: State = _ private var timeoutFuture: Option[ScheduledFuture[AnyRef]] = None private var generation: Long = 0L + private var transitionCallBackList: List[ActorRef] = Nil + + private val timers = mutable.Map[String, Timer]() private val stateFunctions = mutable.Map[S, StateFunction]() - private def register(name: S, function: StateFunction) { + private val stateTimeouts = mutable.Map[S, Timeout]() + + private def register(name: S, function: StateFunction, timeout: Timeout) { if (stateFunctions contains name) { stateFunctions(name) = stateFunctions(name) orElse function + stateTimeouts(name) = timeout orElse stateTimeouts(name) } else { stateFunctions(name) = function + stateTimeouts(name) = timeout } } - private var handleEvent: StateFunction = { + private val handleEventDefault: StateFunction = { case Event(value, stateData) => log.slf4j.warn("Event {} not handled in state {}, staying at current state", value, currentState.stateName) stay } + private var handleEvent: StateFunction = handleEventDefault - private var terminateEvent: PartialFunction[Reason, Unit] = { - case failure@Failure(_) => log.slf4j.error("Stopping because of a {}", failure) - case reason => log.slf4j.info("Stopping because of reason: {}", reason) + private var terminateEvent: PartialFunction[StopEvent[S,D], Unit] = { + case StopEvent(Failure(cause), _, _) => + log.slf4j.error("Stopping because of a failure with cause {}", cause) + case StopEvent(reason, _, _) => log.slf4j.info("Stopping because of reason: {}", reason) } - private var transitionEvent: PartialFunction[Transition, Unit] = { - case Transition(from, to) => log.slf4j.debug("Transitioning from state {} to {}", from, to) + private var transitionEvent: TransitionHandler = (from, to) => { + log.slf4j.debug("Transitioning from state {} to {}", from, to) } override final protected def receive: Receive = { @@ -101,8 +311,24 @@ trait FSM[S, D] { if (generation == gen) { processEvent(StateTimeout) } + case t@Timer(name, msg, repeat) => + if (timerActive_?(name)) { + processEvent(msg) + if (!repeat) { + timers -= name + } + } + case SubscribeTransitionCallBack(actorRef) => + // send current state back as reference point + actorRef ! CurrentState(self, currentState.stateName) + transitionCallBackList ::= actorRef + case UnsubscribeTransitionCallBack(actorRef) => + transitionCallBackList = transitionCallBackList.filterNot(_ == actorRef) case value => { - timeoutFuture = timeoutFuture.flatMap {ref => ref.cancel(true); None} + if (timeoutFuture.isDefined) { + timeoutFuture.get.cancel(true) + timeoutFuture = None + } generation += 1 processEvent(value) } @@ -110,7 +336,13 @@ trait FSM[S, D] { private def processEvent(value: Any) = { val event = Event(value, currentState.stateData) - val nextState = (stateFunctions(currentState.stateName) orElse handleEvent).apply(event) + val stateFunc = stateFunctions(currentState.stateName) + val nextState = if (stateFunc isDefinedAt event) { + stateFunc(event) + } else { + // handleEventDefault ensures that this is always defined + handleEvent(event) + } nextState.stopReason match { case Some(reason) => terminate(reason) case None => makeTransition(nextState) @@ -122,7 +354,11 @@ trait FSM[S, D] { terminate(Failure("Next state %s does not exist".format(nextState.stateName))) } else { if (currentState.stateName != nextState.stateName) { - transitionEvent.apply(Transition(currentState.stateName, nextState.stateName)) + transitionEvent.apply(currentState.stateName, nextState.stateName) + if (!transitionCallBackList.isEmpty) { + val transition = Transition(self, currentState.stateName, nextState.stateName) + transitionCallBackList.foreach(_ ! transition) + } } applyState(nextState) } @@ -130,26 +366,43 @@ trait FSM[S, D] { private def applyState(nextState: State) = { currentState = nextState - currentState.timeout.foreach { t => - timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t._1, t._2)) + val timeout = if (currentState.timeout.isDefined) currentState.timeout else stateTimeouts(currentState.stateName) + if (timeout.isDefined) { + val t = timeout.get + if (t.length >= 0) { + timeoutFuture = Some(Scheduler.scheduleOnce(self, TimeoutMarker(generation), t.length, t.unit)) + } } } private def terminate(reason: Reason) = { - terminateEvent.apply(reason) + timers.foreach{ case (timer, t) => log.slf4j.info("Canceling timer {}", timer); t.cancel} + terminateEvent.apply(StopEvent(reason, currentState.stateName, currentState.stateData)) self.stop } + case class Event[D](event: Any, stateData: D) + object Ev { + def unapply[D](e : Event[D]) : Option[Any] = Some(e.event) + } - case class Event(event: Any, stateData: D) + case class State(stateName: S, stateData: D, timeout: Timeout = None) { - case class State(stateName: S, stateData: D, timeout: Option[(Long, TimeUnit)] = None) { - - def forMax(timeout: (Long, TimeUnit)): State = { + /** + * Modify state transition descriptor to include a state timeout for the + * next state. This timeout overrides any default timeout set for the next + * state. + */ + def forMax(timeout: Duration): State = { copy(timeout = Some(timeout)) } - def replying(replyValue:Any): State = { + /** + * Send reply to sender of the current message, if available. + * + * @return this state transition descriptor + */ + def replying(replyValue: Any): State = { self.sender match { case Some(sender) => sender ! replyValue case None => log.slf4j.error("Unable to send reply value {}, no sender reference to reply to", replyValue) @@ -157,16 +410,21 @@ trait FSM[S, D] { this } + /** + * Modify state transition descriptor with new state data. The data will be + * set when transitioning to the new state. + */ def using(nextStateDate: D): State = { copy(stateData = nextStateDate) } private[akka] var stopReason: Option[Reason] = None + private[akka] def withStopReason(reason: Reason): State = { stopReason = Some(reason) this } } - case class Transition(from: S, to: S) + case class StopEvent[S, D](reason: Reason, currentState: S, stateData: D) } diff --git a/akka-actor/src/main/scala/akka/util/Duration.scala b/akka-actor/src/main/scala/akka/util/Duration.scala index 1ab14164fe..743ce0fc4c 100644 --- a/akka-actor/src/main/scala/akka/util/Duration.scala +++ b/akka-actor/src/main/scala/akka/util/Duration.scala @@ -5,17 +5,149 @@ package akka.util import java.util.concurrent.TimeUnit +import TimeUnit._ +import java.lang.{Long => JLong, Double => JDouble} object Duration { - def apply(length: Long, unit: TimeUnit) = new Duration(length, unit) - def apply(length: Long, unit: String) = new Duration(length, timeUnit(unit)) + def apply(length: Long, unit: TimeUnit) : Duration = new FiniteDuration(length, unit) + def apply(length: Double, unit: TimeUnit) : Duration = fromNanos(unit.toNanos(1) * length) + def apply(length: Long, unit: String) : Duration = new FiniteDuration(length, timeUnit(unit)) - def timeUnit(unit: String) = unit.toLowerCase match { - case "nanoseconds" | "nanos" | "nanosecond" | "nano" => TimeUnit.NANOSECONDS - case "microseconds" | "micros" | "microsecond" | "micro" => TimeUnit.MICROSECONDS - case "milliseconds" | "millis" | "millisecond" | "milli" => TimeUnit.MILLISECONDS - case _ => TimeUnit.SECONDS + def fromNanos(nanos : Long) : Duration = { + if (nanos % 86400000000000L == 0) { + Duration(nanos / 86400000000000L, DAYS) + } else if (nanos % 3600000000000L == 0) { + Duration(nanos / 3600000000000L, HOURS) + } else if (nanos % 60000000000L == 0) { + Duration(nanos / 60000000000L, MINUTES) + } else if (nanos % 1000000000L == 0) { + Duration(nanos / 1000000000L, SECONDS) + } else if (nanos % 1000000L == 0) { + Duration(nanos / 1000000L, MILLISECONDS) + } else if (nanos % 1000L == 0) { + Duration(nanos / 1000L, MICROSECONDS) + } else { + Duration(nanos, NANOSECONDS) + } } + + def fromNanos(nanos : Double) : Duration = fromNanos((nanos + 0.5).asInstanceOf[Long]) + + /** + * Construct a Duration by parsing a String. In case of a format error, a + * RuntimeException is thrown. See `unapply(String)` for more information. + */ + def apply(s : String) : Duration = unapply(s) getOrElse error("format error") + + /** + * Deconstruct a Duration into length and unit if it is finite. + */ + def unapply(d : Duration) : Option[(Long, TimeUnit)] = { + if (d.finite_?) { + Some((d.length, d.unit)) + } else { + None + } + } + + private val RE = ("""^\s*(\d+(?:\.\d+)?)\s*"""+ // length part + "(?:"+ // units are distinguished in separate match groups + "(d|day|days)|"+ + "(h|hour|hours)|"+ + "(min|minute|minutes)|"+ + "(s|sec|second|seconds)|"+ + "(ms|milli|millis|millisecond|milliseconds)|"+ + "(µs|micro|micros|microsecond|microseconds)|"+ + "(ns|nano|nanos|nanosecond|nanoseconds)"+ + """)\s*$""").r // close the non-capturing group + private val REinf = """^\s*Inf\s*$""".r + private val REminf = """^\s*(?:-\s*|Minus)Inf\s*""".r + + /** + * Parse String, return None if no match. Format is `""`, where + * whitespace is allowed before, between and after the parts. Infinities are + * designated by `"Inf"` and `"-Inf"` or `"MinusInf"`. + */ + def unapply(s : String) : Option[Duration] = s match { + case RE(length, d, h, m, s, ms, mus, ns) => + if ( d ne null) Some(Duration(JDouble.parseDouble(length), DAYS)) else + if ( h ne null) Some(Duration(JDouble.parseDouble(length), HOURS)) else + if ( m ne null) Some(Duration(JDouble.parseDouble(length), MINUTES)) else + if ( s ne null) Some(Duration(JDouble.parseDouble(length), SECONDS)) else + if ( ms ne null) Some(Duration(JDouble.parseDouble(length), MILLISECONDS)) else + if (mus ne null) Some(Duration(JDouble.parseDouble(length), MICROSECONDS)) else + if ( ns ne null) Some(Duration(JDouble.parseDouble(length), NANOSECONDS)) else + error("made some error in regex (should not be possible)") + case REinf() => Some(Inf) + case REminf() => Some(MinusInf) + case _ => None + } + + /** + * Parse TimeUnit from string representation. + */ + def timeUnit(unit: String) = unit.toLowerCase match { + case "d" | "day" | "days" => DAYS + case "h" | "hour" | "hours" => HOURS + case "min" | "minute" | "minutes" => MINUTES + case "s" | "sec" | "second" | "seconds" => SECONDS + case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS + case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS + case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS + } + + trait Infinite { + this : Duration => + + override def equals(other : Any) = false + + def +(other : Duration) : Duration = this + def -(other : Duration) : Duration = this + def *(other : Double) : Duration = this + def /(other : Double) : Duration = this + + def finite_? = false + + def length : Long = throw new IllegalArgumentException("length not allowed on infinite Durations") + def unit : TimeUnit = throw new IllegalArgumentException("unit not allowed on infinite Durations") + def toNanos : Long = throw new IllegalArgumentException("toNanos not allowed on infinite Durations") + def toMicros : Long = throw new IllegalArgumentException("toMicros not allowed on infinite Durations") + def toMillis : Long = throw new IllegalArgumentException("toMillis not allowed on infinite Durations") + def toSeconds : Long = throw new IllegalArgumentException("toSeconds not allowed on infinite Durations") + def toMinutes : Long = throw new IllegalArgumentException("toMinutes not allowed on infinite Durations") + def toHours : Long = throw new IllegalArgumentException("toHours not allowed on infinite Durations") + def toDays : Long = throw new IllegalArgumentException("toDays not allowed on infinite Durations") + def toUnit(unit : TimeUnit) : Double = throw new IllegalArgumentException("toUnit not allowed on infinite Durations") + + def printHMS = toString + } + + /** + * Infinite duration: greater than any other and not equal to any other, + * including itself. + */ + object Inf extends Duration with Infinite { + override def toString = "Duration.Inf" + def >(other : Duration) = true + def >=(other : Duration) = true + def <(other : Duration) = false + def <=(other : Duration) = false + def unary_- : Duration = MinusInf + } + + /** + * Infinite negative duration: lesser than any other and not equal to any other, + * including itself. + */ + object MinusInf extends Duration with Infinite { + override def toString = "Duration.MinusInf" + def >(other : Duration) = false + def >=(other : Duration) = false + def <(other : Duration) = true + def <=(other : Duration) = true + def unary_- : Duration = Inf + } + } /** @@ -24,11 +156,11 @@ object Duration { *

* Examples of usage from Java: *

- * import akka.util.Duration;
+ * import akka.util.FiniteDuration;
  * import java.util.concurrent.TimeUnit;
  *
- * Duration duration = new Duration(100, TimeUnit.MILLISECONDS);
- * Duration duration = new Duration(5, "seconds");
+ * Duration duration = new FiniteDuration(100, MILLISECONDS);
+ * Duration duration = new FiniteDuration(5, "seconds");
  *
  * duration.toNanos();
  * 
@@ -39,70 +171,255 @@ object Duration { * import akka.util.Duration * import java.util.concurrent.TimeUnit * - * val duration = Duration(100, TimeUnit.MILLISECONDS) + * val duration = Duration(100, MILLISECONDS) * val duration = Duration(100, "millis") * * duration.toNanos + * duration < 1.second + * duration <= Duration.Inf * * *

- * Implicits are also provided for Int and Long. Example usage: + * Implicits are also provided for Int, Long and Double. Example usage: *

  * import akka.util.duration._
  *
- * val duration = 100.millis
+ * val duration = 100 millis
+ * 
+ * + * Extractors, parsing and arithmetic are also included: + *
+ * val d = Duration("1.2 µs")
+ * val Duration(length, unit) = 5 millis
+ * val d2 = d * 2.5
+ * val d3 = d2 + 1.millisecond
  * 
*/ -class Duration(val length: Long, val unit: TimeUnit) { +trait Duration { + def length : Long + def unit : TimeUnit + def toNanos : Long + def toMicros : Long + def toMillis : Long + def toSeconds : Long + def toMinutes : Long + def toHours : Long + def toDays : Long + def toUnit(unit : TimeUnit) : Double + def printHMS : String + def <(other : Duration) : Boolean + def <=(other : Duration) : Boolean + def >(other : Duration) : Boolean + def >=(other : Duration) : Boolean + def +(other : Duration) : Duration + def -(other : Duration) : Duration + def *(factor : Double) : Duration + def /(factor : Double) : Duration + def unary_- : Duration + def finite_? : Boolean +} + +class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration { + import Duration._ + def this(length: Long, unit: String) = this(length, Duration.timeUnit(unit)) + def toNanos = unit.toNanos(length) def toMicros = unit.toMicros(length) def toMillis = unit.toMillis(length) def toSeconds = unit.toSeconds(length) - override def toString = "Duration(" + length + ", " + unit + ")" + def toMinutes = unit.toMinutes(length) + def toHours = unit.toHours(length) + def toDays = unit.toDays(length) + def toUnit(u : TimeUnit) = long2double(toNanos) / NANOSECONDS.convert(1, u) + + override def toString = this match { + case Duration(1, DAYS) => "1 day" + case Duration(x, DAYS) => x+" days" + case Duration(1, HOURS) => "1 hour" + case Duration(x, HOURS) => x+" hours" + case Duration(1, MINUTES) => "1 minute" + case Duration(x, MINUTES) => x+" minutes" + case Duration(1, SECONDS) => "1 second" + case Duration(x, SECONDS) => x+" seconds" + case Duration(1, MILLISECONDS) => "1 millisecond" + case Duration(x, MILLISECONDS) => x+" milliseconds" + case Duration(1, MICROSECONDS) => "1 microsecond" + case Duration(x, MICROSECONDS) => x+" microseconds" + case Duration(1, NANOSECONDS) => "1 nanosecond" + case Duration(x, NANOSECONDS) => x+" nanoseconds" + } + + def printHMS = "%02d:%02d:%06.3f".format(toHours, toMinutes % 60, toMillis / 1000. % 60) + + def <(other : Duration) = { + if (other.finite_?) { + toNanos < other.asInstanceOf[FiniteDuration].toNanos + } else { + other > this + } + } + + def <=(other : Duration) = { + if (other.finite_?) { + toNanos <= other.asInstanceOf[FiniteDuration].toNanos + } else { + other >= this + } + } + + def >(other : Duration) = { + if (other.finite_?) { + toNanos > other.asInstanceOf[FiniteDuration].toNanos + } else { + other < this + } + } + + def >=(other : Duration) = { + if (other.finite_?) { + toNanos >= other.asInstanceOf[FiniteDuration].toNanos + } else { + other <= this + } + } + + def +(other : Duration) = { + if (!other.finite_?) { + other + } else { + val nanos = toNanos + other.asInstanceOf[FiniteDuration].toNanos + fromNanos(nanos) + } + } + + def -(other : Duration) = { + if (!other.finite_?) { + other + } else { + val nanos = toNanos - other.asInstanceOf[FiniteDuration].toNanos + fromNanos(nanos) + } + } + + def *(factor : Double) = fromNanos(long2double(toNanos) * factor) + + def /(factor : Double) = fromNanos(long2double(toNanos) / factor) + + def unary_- = Duration(-length, unit) + + def finite_? = true + + override def equals(other : Any) = + other.isInstanceOf[FiniteDuration] && + toNanos == other.asInstanceOf[FiniteDuration].toNanos + + override def hashCode = toNanos.asInstanceOf[Int] } package object duration { implicit def intToDurationInt(n: Int) = new DurationInt(n) implicit def longToDurationLong(n: Long) = new DurationLong(n) + implicit def doubleToDurationDouble(d: Double) = new DurationDouble(d) + + implicit def pairIntToDuration(p : (Int, TimeUnit)) = Duration(p._1, p._2) + implicit def pairLongToDuration(p : (Long, TimeUnit)) = Duration(p._1, p._2) + implicit def durationToPair(d : Duration) = (d.length, d.unit) + + implicit def intMult(i : Int) = new { + def *(d : Duration) = d * i + } + implicit def longMult(l : Long) = new { + def *(d : Duration) = d * l + } + implicit def doubleMult(f : Double) = new { + def *(d : Duration) = d * f + } } class DurationInt(n: Int) { - def nanoseconds = Duration(n, TimeUnit.NANOSECONDS) - def nanos = Duration(n, TimeUnit.NANOSECONDS) - def nanosecond = Duration(n, TimeUnit.NANOSECONDS) - def nano = Duration(n, TimeUnit.NANOSECONDS) + def nanoseconds = Duration(n, NANOSECONDS) + def nanos = Duration(n, NANOSECONDS) + def nanosecond = Duration(n, NANOSECONDS) + def nano = Duration(n, NANOSECONDS) - def microseconds = Duration(n, TimeUnit.MICROSECONDS) - def micros = Duration(n, TimeUnit.MICROSECONDS) - def microsecond = Duration(n, TimeUnit.MICROSECONDS) - def micro = Duration(n, TimeUnit.MICROSECONDS) + def microseconds = Duration(n, MICROSECONDS) + def micros = Duration(n, MICROSECONDS) + def microsecond = Duration(n, MICROSECONDS) + def micro = Duration(n, MICROSECONDS) - def milliseconds = Duration(n, TimeUnit.MILLISECONDS) - def millis = Duration(n, TimeUnit.MILLISECONDS) - def millisecond = Duration(n, TimeUnit.MILLISECONDS) - def milli = Duration(n, TimeUnit.MILLISECONDS) + def milliseconds = Duration(n, MILLISECONDS) + def millis = Duration(n, MILLISECONDS) + def millisecond = Duration(n, MILLISECONDS) + def milli = Duration(n, MILLISECONDS) - def seconds = Duration(n, TimeUnit.SECONDS) - def second = Duration(n, TimeUnit.SECONDS) + def seconds = Duration(n, SECONDS) + def second = Duration(n, SECONDS) + + def minutes = Duration(n, MINUTES) + def minute = Duration(n, MINUTES) + + def hours = Duration(n, HOURS) + def hour = Duration(n, HOURS) + + def days = Duration(n, DAYS) + def day = Duration(n, DAYS) } class DurationLong(n: Long) { - def nanoseconds = Duration(n, TimeUnit.NANOSECONDS) - def nanos = Duration(n, TimeUnit.NANOSECONDS) - def nanosecond = Duration(n, TimeUnit.NANOSECONDS) - def nano = Duration(n, TimeUnit.NANOSECONDS) + def nanoseconds = Duration(n, NANOSECONDS) + def nanos = Duration(n, NANOSECONDS) + def nanosecond = Duration(n, NANOSECONDS) + def nano = Duration(n, NANOSECONDS) - def microseconds = Duration(n, TimeUnit.MICROSECONDS) - def micros = Duration(n, TimeUnit.MICROSECONDS) - def microsecond = Duration(n, TimeUnit.MICROSECONDS) - def micro = Duration(n, TimeUnit.MICROSECONDS) + def microseconds = Duration(n, MICROSECONDS) + def micros = Duration(n, MICROSECONDS) + def microsecond = Duration(n, MICROSECONDS) + def micro = Duration(n, MICROSECONDS) - def milliseconds = Duration(n, TimeUnit.MILLISECONDS) - def millis = Duration(n, TimeUnit.MILLISECONDS) - def millisecond = Duration(n, TimeUnit.MILLISECONDS) - def milli = Duration(n, TimeUnit.MILLISECONDS) + def milliseconds = Duration(n, MILLISECONDS) + def millis = Duration(n, MILLISECONDS) + def millisecond = Duration(n, MILLISECONDS) + def milli = Duration(n, MILLISECONDS) - def seconds = Duration(n, TimeUnit.SECONDS) - def second = Duration(n, TimeUnit.SECONDS) + def seconds = Duration(n, SECONDS) + def second = Duration(n, SECONDS) + + def minutes = Duration(n, MINUTES) + def minute = Duration(n, MINUTES) + + def hours = Duration(n, HOURS) + def hour = Duration(n, HOURS) + + def days = Duration(n, DAYS) + def day = Duration(n, DAYS) +} + +class DurationDouble(d: Double) { + def nanoseconds = Duration(d, NANOSECONDS) + def nanos = Duration(d, NANOSECONDS) + def nanosecond = Duration(d, NANOSECONDS) + def nano = Duration(d, NANOSECONDS) + + def microseconds = Duration(d, MICROSECONDS) + def micros = Duration(d, MICROSECONDS) + def microsecond = Duration(d, MICROSECONDS) + def micro = Duration(d, MICROSECONDS) + + def milliseconds = Duration(d, MILLISECONDS) + def millis = Duration(d, MILLISECONDS) + def millisecond = Duration(d, MILLISECONDS) + def milli = Duration(d, MILLISECONDS) + + def seconds = Duration(d, SECONDS) + def second = Duration(d, SECONDS) + + def minutes = Duration(d, MINUTES) + def minute = Duration(d, MINUTES) + + def hours = Duration(d, HOURS) + def hour = Duration(d, HOURS) + + def days = Duration(d, DAYS) + def day = Duration(d, DAYS) } diff --git a/akka-actor/src/main/scala/akka/util/TestKit.scala b/akka-actor/src/main/scala/akka/util/TestKit.scala new file mode 100644 index 0000000000..bb400ff992 --- /dev/null +++ b/akka-actor/src/main/scala/akka/util/TestKit.scala @@ -0,0 +1,434 @@ +package akka.util + +import akka.actor.{Actor, FSM} +import Actor._ +import duration._ + +import java.util.concurrent.{BlockingDeque, LinkedBlockingDeque, TimeUnit} + +import scala.annotation.tailrec + +object TestActor { + type Ignore = Option[PartialFunction[AnyRef, Boolean]] + + case class SetTimeout(d : Duration) + case class SetIgnore(i : Ignore) +} + +class TestActor(queue : BlockingDeque[AnyRef]) extends Actor with FSM[Int, TestActor.Ignore] { + import FSM._ + import TestActor._ + + startWith(0, None) + when(0, stateTimeout = 5 seconds) { + case Ev(SetTimeout(d)) => + setStateTimeout(0, if (d.finite_?) d else None) + stay + case Ev(SetIgnore(ign)) => stay using ign + case Ev(StateTimeout) => + stop + case Event(x : AnyRef, ign) => + val ignore = ign map (z => if (z isDefinedAt x) z(x) else false) getOrElse false + if (!ignore) { + queue.offerLast(x) + } + stay + } + initialize +} + +/** + * Test kit for testing actors. Inheriting from this trait enables reception of + * replies from actors, which are queued by an internal actor and can be + * examined using the `expectMsg...` methods. Assertions and bounds concerning + * timing are available in the form of `within` blocks. + * + *
+ * class Test extends TestKit {
+ *     val test = actorOf[SomeActor].start
+ *
+ *     within (1 second) {
+ *       test ! SomeWork
+ *       expectMsg(Result1) // bounded to 1 second
+ *       expectMsg(Result2) // bounded to the remainder of the 1 second
+ *     }
+ * }
+ * 
+ * + * Beware of two points: + * + * - the internal test actor needs to be stopped, either explicitly using + * `stopTestActor` or implicitly by using its internal inactivity timeout, + * see `setTestActorTimeout` + * - this trait is not thread-safe (only one actor with one queue, one stack + * of `within` blocks); it is expected that the code is executed from a + * constructor as shown above, which makes this a non-issue, otherwise take + * care not to run tests within a single test class instance in parallel. + * + * @author Roland Kuhn + * @since 1.1 + */ +trait TestKit { + + private val queue = new LinkedBlockingDeque[AnyRef]() + + /** + * ActorRef of the test actor. Access is provided to enable e.g. + * registration as message target. + */ + protected val testActor = actorOf(new TestActor(queue)).start + + /** + * Implicit sender reference so that replies are possible for messages sent + * from the test class. + */ + protected implicit val senderOption = Some(testActor) + + private var end : Duration = Duration.Inf + /* + * THIS IS A HACK: expectNoMsg and receiveWhile are bounded by `end`, but + * running them should not trigger an AssertionError, so mark their end + * time here and do not fail at the end of `within` if that time is not + * long gone. + */ + private var lastSoftTimeout : Duration = now - 5.millis + + /** + * Stop test actor. Should be done at the end of the test unless relying on + * test actor timeout. + */ + def stopTestActor { testActor.stop } + + /** + * Set test actor timeout. By default, the test actor shuts itself down + * after 5 seconds of inactivity. Set this to Duration.Inf to disable this + * behavior, but make sure that someone will then call `stopTestActor`, + * unless you want to leak actors, e.g. wrap test in + * + *
+   *   try {
+   *     ...
+   *   } finally { stopTestActor }
+   * 
+ */ + def setTestActorTimeout(d : Duration) { testActor ! TestActor.SetTimeout(d) } + + /** + * Ignore all messages in the test actor for which the given partial + * function returns true. + */ + def ignoreMsg(f : PartialFunction[AnyRef, Boolean]) { testActor ! TestActor.SetIgnore(Some(f)) } + + /** + * Stop ignoring messages in the test actor. + */ + def ignoreNoMsg { testActor ! TestActor.SetIgnore(None) } + + /** + * Obtain current time (`System.currentTimeMillis`) as Duration. + */ + def now : Duration = System.nanoTime.nanos + + /** + * Obtain time remaining for execution of the innermost enclosing `within` block. + */ + def remaining : Duration = end - now + + /** + * Execute code block while bounding its execution time between `min` and + * `max`. `within` blocks may be nested. All methods in this trait which + * take maximum wait times are available in a version which implicitly uses + * the remaining time governed by the innermost enclosing `within` block. + * + *
+   * val ret = within(50 millis) {
+   *         test ! "ping"
+   *         expectMsgClass(classOf[String])
+   *       }
+   * 
+ */ + def within[T](min : Duration, max : Duration)(f : => T) : T = { + val start = now + val rem = end - start + assert (rem >= min, "required min time "+min+" not possible, only "+format(min.unit, rem)+" left") + + val max_diff = if (max < rem) max else rem + val prev_end = end + end = start + max_diff + + val ret = f + + val diff = now - start + assert (min <= diff, "block took "+format(min.unit, diff)+", should at least have been "+min) + /* + * caution: HACK AHEAD + */ + if (now - lastSoftTimeout > 5.millis) { + assert (diff <= max_diff, "block took "+format(max.unit, diff)+", exceeding "+format(max.unit, max_diff)) + } else { + lastSoftTimeout -= 5.millis + } + + end = prev_end + ret + } + + /** + * Same as calling `within(0 seconds, max)(f)`. + */ + def within[T](max : Duration)(f : => T) : T = within(0 seconds, max)(f) + + /** + * Same as `expectMsg`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsg(obj : Any) : AnyRef = expectMsg(remaining, obj) + + /** + * Receive one message from the test actor and assert that it equals the + * given object. Wait time is bounded by the given duration, with an + * AssertionFailure being thrown in case of timeout. + * + * @return the received object + */ + def expectMsg(max : Duration, obj : Any) : AnyRef = { + val o = receiveOne(max) + assert (o ne null, "timeout during expectMsg") + assert (obj == o, "expected "+obj+", found "+o) + o + } + + /** + * Same as `expectMsg`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsg[T](f : PartialFunction[Any, T]) : T = expectMsg(remaining)(f) + + /** + * Receive one message from the test actor and assert that the given + * partial function accepts it. Wait time is bounded by the given duration, + * with an AssertionFailure being thrown in case of timeout. + * + * Use this variant to implement more complicated or conditional + * processing. + * + * @return the received object as transformed by the partial function + */ + def expectMsg[T](max : Duration)(f : PartialFunction[Any, T]) : T = { + val o = receiveOne(max) + assert (o ne null, "timeout during expectMsg") + assert (f.isDefinedAt(o), "does not match: "+o) + f(o) + } + + /** + * Same as `expectMsgClass`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgClass[C](c : Class[C]) : C = expectMsgClass(remaining, c) + + /** + * Receive one message from the test actor and assert that it conforms to + * the given class. Wait time is bounded by the given duration, with an + * AssertionFailure being thrown in case of timeout. + * + * @return the received object + */ + def expectMsgClass[C](max : Duration, c : Class[C]) : C = { + val o = receiveOne(max) + assert (o ne null, "timeout during expectMsgClass") + assert (c isInstance o, "expected "+c+", found "+o.getClass) + o.asInstanceOf[C] + } + + /** + * Same as `expectMsgAnyOf`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgAnyOf(obj : Any*) : AnyRef = expectMsgAnyOf(remaining, obj : _*) + + /** + * Receive one message from the test actor and assert that it equals one of + * the given objects. Wait time is bounded by the given duration, with an + * AssertionFailure being thrown in case of timeout. + * + * @return the received object + */ + def expectMsgAnyOf(max : Duration, obj : Any*) : AnyRef = { + val o = receiveOne(max) + assert (o ne null, "timeout during expectMsgAnyOf") + assert (obj exists (_ == o), "found unexpected "+o) + o + } + + /** + * Same as `expectMsgAnyClassOf`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgAnyClassOf(obj : Class[_]*) : AnyRef = expectMsgAnyClassOf(remaining, obj : _*) + + /** + * Receive one message from the test actor and assert that it conforms to + * one of the given classes. Wait time is bounded by the given duration, + * with an AssertionFailure being thrown in case of timeout. + * + * @return the received object + */ + def expectMsgAnyClassOf(max : Duration, obj : Class[_]*) : AnyRef = { + val o = receiveOne(max) + assert (o ne null, "timeout during expectMsgAnyClassOf") + assert (obj exists (_ isInstance o), "found unexpected "+o) + o + } + + /** + * Same as `expectMsgAllOf`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgAllOf(obj : Any*) { expectMsgAllOf(remaining, obj : _*) } + + /** + * Receive a number of messages from the test actor matching the given + * number of objects and assert that for each given object one is received + * which equals it. This construct is useful when the order in which the + * objects are received is not fixed. Wait time is bounded by the given + * duration, with an AssertionFailure being thrown in case of timeout. + * + *
+   * within(1 second) {
+   *   dispatcher ! SomeWork1()
+   *   dispatcher ! SomeWork2()
+   *   expectMsgAllOf(Result1(), Result2())
+   * }
+   * 
+ */ + def expectMsgAllOf(max : Duration, obj : Any*) { + val recv = receiveN(obj.size, now + max) + assert (obj forall (x => recv exists (x == _)), "not found all") + } + + /** + * Same as `expectMsgAllClassOf`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgAllClassOf(obj : Class[_]*) { expectMsgAllClassOf(remaining, obj : _*) } + + /** + * Receive a number of messages from the test actor matching the given + * number of classes and assert that for each given class one is received + * which is of that class (equality, not conformance). This construct is + * useful when the order in which the objects are received is not fixed. + * Wait time is bounded by the given duration, with an AssertionFailure + * being thrown in case of timeout. + */ + def expectMsgAllClassOf(max : Duration, obj : Class[_]*) { + val recv = receiveN(obj.size, now + max) + assert (obj forall (x => recv exists (_.getClass eq x)), "not found all") + } + + /** + * Same as `expectMsgAllConformingOf`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectMsgAllConformingOf(obj : Class[_]*) { expectMsgAllClassOf(remaining, obj : _*) } + + /** + * Receive a number of messages from the test actor matching the given + * number of classes and assert that for each given class one is received + * which conforms to that class. This construct is useful when the order in + * which the objects are received is not fixed. Wait time is bounded by + * the given duration, with an AssertionFailure being thrown in case of + * timeout. + * + * Beware that one object may satisfy all given class constraints, which + * may be counter-intuitive. + */ + def expectMsgAllConformingOf(max : Duration, obj : Class[_]*) { + val recv = receiveN(obj.size, now + max) + assert (obj forall (x => recv exists (x isInstance _)), "not found all") + } + + /** + * Same as `expectNoMsg`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def expectNoMsg { expectNoMsg(remaining) } + + /** + * Assert that no message is received for the specified time. + */ + def expectNoMsg(max : Duration) { + val o = receiveOne(max) + assert (o eq null, "received unexpected message "+o) + lastSoftTimeout = now + } + + /** + * Same as `receiveWhile`, but takes the maximum wait time from the innermost + * enclosing `within` block. + */ + def receiveWhile[T](f : PartialFunction[AnyRef, T]) : Seq[T] = receiveWhile(remaining)(f) + + /** + * Receive a series of messages as long as the given partial function + * accepts them or the idle timeout is met or the overall maximum duration + * is elapsed. Returns the sequence of messages. + * + * Beware that the maximum duration is not implicitly bounded by or taken + * from the innermost enclosing `within` block, as it is not an error to + * hit the `max` duration in this case. + * + * One possible use of this method is for testing whether messages of + * certain characteristics are generated at a certain rate: + * + *
+   * test ! ScheduleTicks(100 millis)
+   * val series = receiveWhile(750 millis) {
+   *     case Tick(count) => count
+   * }
+   * assert(series == (1 to 7).toList)
+   * 
+ */ + def receiveWhile[T](max : Duration)(f : PartialFunction[AnyRef, T]) : Seq[T] = { + val stop = now + max + + @tailrec def doit(acc : List[T]) : List[T] = { + receiveOne(stop - now) match { + case null => + acc.reverse + case o if (f isDefinedAt o) => + doit(f(o) :: acc) + case o => + queue.offerFirst(o) + acc.reverse + } + } + + val ret = doit(Nil) + lastSoftTimeout = now + ret + } + + private def receiveN(n : Int, stop : Duration) : Seq[AnyRef] = { + for { x <- 1 to n } yield { + val timeout = stop - now + val o = receiveOne(timeout) + assert (o ne null, "timeout while expecting "+n+" messages") + o + } + } + + private def receiveOne(max : Duration) : AnyRef = { + if (max == 0.seconds) { + queue.pollFirst + } else if (max.finite_?) { + queue.pollFirst(max.length, max.unit) + } else { + queue.takeFirst + } + } + + private def format(u : TimeUnit, d : Duration) = "%.3f %s".format(d.toUnit(u), u.toString.toLowerCase) +} + +// vim: set ts=2 sw=2 et: diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala index d948865002..b25e96e87c 100644 --- a/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMActorSpec.scala @@ -6,19 +6,24 @@ package akka.actor import org.scalatest.junit.JUnitSuite import org.junit.Test +import FSM._ import org.multiverse.api.latches.StandardLatch import java.util.concurrent.TimeUnit +import akka.util.duration._ + object FSMActorSpec { - import FSM._ + val unlockedLatch = new StandardLatch val lockedLatch = new StandardLatch val unhandledLatch = new StandardLatch val terminatedLatch = new StandardLatch val transitionLatch = new StandardLatch + val initialStateLatch = new StandardLatch + val transitionCallBackLatch = new StandardLatch sealed trait LockState case object Locked extends LockState @@ -26,11 +31,8 @@ object FSMActorSpec { class Lock(code: String, timeout: (Long, TimeUnit)) extends Actor with FSM[LockState, CodeState] { - notifying { - case Transition(Locked, Open) => transitionLatch.open - case Transition(_, _) => () - } - + startWith(Locked, CodeState("", code)) + when(Locked) { case Event(digit: Char, CodeState(soFar, code)) => { soFar + digit match { @@ -47,7 +49,7 @@ object FSMActorSpec { } } case Event("hello", _) => stay replying "world" - case Event("bye", _) => stop + case Event("bye", _) => stop(Shutdown) } when(Open) { @@ -57,8 +59,6 @@ object FSMActorSpec { } } - startWith(Locked, CodeState("", code)) - whenUnhandled { case Event(_, stateData) => { log.slf4j.info("Unhandled") @@ -67,10 +67,21 @@ object FSMActorSpec { } } - onTermination { - case reason => terminatedLatch.open + onTransition(transitionHandler) + + def transitionHandler(from: LockState, to: LockState) = { + if (from == Locked && to == Open) transitionLatch.open } + onTermination { + case StopEvent(Shutdown, Locked, _) => + // stop is called from lockstate with shutdown as reason... + terminatedLatch.open + } + + // initialize the lock + initialize + private def doLock() { log.slf4j.info("Locked") lockedLatch.open @@ -88,12 +99,21 @@ object FSMActorSpec { class FSMActorSpec extends JUnitSuite { import FSMActorSpec._ + @Test def unlockTheLock = { // lock that locked after being open for 1 sec val lock = Actor.actorOf(new Lock("33221", (1, TimeUnit.SECONDS))).start + val transitionTester = Actor.actorOf(new Actor { def receive = { + case Transition(_, _, _) => transitionCallBackLatch.open + case CurrentState(_, Locked) => initialStateLatch.open + }}).start + + lock ! SubscribeTransitionCallBack(transitionTester) + assert(initialStateLatch.tryAwait(1, TimeUnit.SECONDS)) + lock ! '3' lock ! '3' lock ! '2' @@ -102,8 +122,10 @@ class FSMActorSpec extends JUnitSuite { assert(unlockedLatch.tryAwait(1, TimeUnit.SECONDS)) assert(transitionLatch.tryAwait(1, TimeUnit.SECONDS)) + assert(transitionCallBackLatch.tryAwait(1, TimeUnit.SECONDS)) assert(lockedLatch.tryAwait(2, TimeUnit.SECONDS)) + lock ! "not_handled" assert(unhandledLatch.tryAwait(2, TimeUnit.SECONDS)) diff --git a/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala b/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala new file mode 100644 index 0000000000..b13a61b82f --- /dev/null +++ b/akka-actor/src/test/scala/akka/actor/actor/FSMTimingSpec.scala @@ -0,0 +1,142 @@ +package akka.actor + +import akka.util.TestKit +import akka.util.duration._ + +import org.scalatest.WordSpec +import org.scalatest.matchers.MustMatchers + +class FSMTimingSpec + extends WordSpec + with MustMatchers + with TestKit { + + import FSMTimingSpec._ + import FSM._ + + val fsm = Actor.actorOf(new StateMachine(testActor)).start + fsm ! SubscribeTransitionCallBack(testActor) + expectMsg(100 millis, CurrentState(fsm, Initial)) + + ignoreMsg { + case Transition(_, Initial, _) => true + } + + "A Finite State Machine" must { + + "receive StateTimeout" in { + within (50 millis, 150 millis) { + fsm ! TestStateTimeout + expectMsg(Transition(fsm, TestStateTimeout, Initial)) + expectNoMsg + } + } + + "receive single-shot timer" in { + within (50 millis, 150 millis) { + fsm ! TestSingleTimer + expectMsg(Tick) + expectMsg(Transition(fsm, TestSingleTimer, Initial)) + expectNoMsg + } + } + + "receive and cancel a repeated timer" in { + fsm ! TestRepeatedTimer + val seq = receiveWhile(550 millis) { + case Tick => Tick + } + seq must have length (5) + within(250 millis) { + fsm ! Cancel + expectMsg(Transition(fsm, TestRepeatedTimer, Initial)) + expectNoMsg + } + } + + "notify unhandled messages" in { + fsm ! TestUnhandled + within(100 millis) { + fsm ! Tick + expectNoMsg + } + within(100 millis) { + fsm ! SetHandler + fsm ! Tick + expectMsg(Unhandled(Tick)) + expectNoMsg + } + within(100 millis) { + fsm ! Unhandled("test") + expectNoMsg + } + within(100 millis) { + fsm ! Cancel + expectMsg(Transition(fsm, TestUnhandled, Initial)) + } + } + + } + +} + +object FSMTimingSpec { + + trait State + case object Initial extends State + case object TestStateTimeout extends State + case object TestSingleTimer extends State + case object TestRepeatedTimer extends State + case object TestUnhandled extends State + + case object Tick + case object Cancel + case object SetHandler + + case class Unhandled(msg : AnyRef) + + class StateMachine(tester : ActorRef) extends Actor with FSM[State, Unit] { + import FSM._ + + startWith(Initial, ()) + when(Initial) { + case Ev(TestSingleTimer) => + setTimer("tester", Tick, 100 millis, false) + goto(TestSingleTimer) + case Ev(TestRepeatedTimer) => + setTimer("tester", Tick, 100 millis, true) + goto(TestRepeatedTimer) + case Ev(x : FSMTimingSpec.State) => goto(x) + } + when(TestStateTimeout, stateTimeout = 100 millis) { + case Ev(StateTimeout) => goto(Initial) + } + when(TestSingleTimer) { + case Ev(Tick) => + tester ! Tick + goto(Initial) + } + when(TestRepeatedTimer) { + case Ev(Tick) => + tester ! Tick + stay + case Ev(Cancel) => + cancelTimer("tester") + goto(Initial) + } + when(TestUnhandled) { + case Ev(SetHandler) => + whenUnhandled { + case Ev(Tick) => + tester ! Unhandled(Tick) + stay + } + stay + case Ev(Cancel) => + goto(Initial) + } + } + +} + +// vim: set ts=2 sw=2 et: diff --git a/akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala b/akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala new file mode 100644 index 0000000000..8c232255d0 --- /dev/null +++ b/akka-samples/akka-sample-fsm/src/main/scala/Buncher.scala @@ -0,0 +1,75 @@ +package sample.fsm.buncher + +import scala.reflect.ClassManifest +import akka.util.Duration +import akka.actor.{FSM, Actor} + +/* +* generic typed object buncher. +* +* To instantiate it, use the factory method like so: +* Buncher(100, 500)(x : List[AnyRef] => x foreach println) +* which will yield a fully functional and started ActorRef. +* The type of messages allowed is strongly typed to match the +* supplied processing method; other messages are discarded (and +* possibly logged). +*/ +object Buncher { + trait State + case object Idle extends State + case object Active extends State + + case object Flush // send out current queue immediately + case object Stop // poison pill + + case class Data[A](start : Long, xs : List[A]) + + def apply[A : Manifest](singleTimeout : Duration, + multiTimeout : Duration)(f : List[A] => Unit) = + Actor.actorOf(new Buncher[A](singleTimeout, multiTimeout).deliver(f)) +} + +class Buncher[A : Manifest] private (val singleTimeout : Duration, val multiTimeout : Duration) + extends Actor with FSM[Buncher.State, Buncher.Data[A]] { + import Buncher._ + import FSM._ + + private val manifestA = manifest[A] + + private var send : List[A] => Unit = _ + private def deliver(f : List[A] => Unit) = { send = f; this } + + private def now = System.currentTimeMillis + private def check(m : AnyRef) = ClassManifest.fromClass(m.getClass) <:< manifestA + + startWith(Idle, Data(0, Nil)) + + when(Idle) { + case Event(m : AnyRef, _) if check(m) => + goto(Active) using Data(now, m.asInstanceOf[A] :: Nil) + case Event(Flush, _) => stay + case Event(Stop, _) => stop + } + + when(Active, stateTimeout = Some(singleTimeout)) { + case Event(m : AnyRef, Data(start, xs)) if check(m) => + val l = m.asInstanceOf[A] :: xs + if (now - start > multiTimeout.toMillis) { + send(l.reverse) + goto(Idle) using Data(0, Nil) + } else { + stay using Data(start, l) + } + case Event(StateTimeout, Data(_, xs)) => + send(xs.reverse) + goto(Idle) using Data(0, Nil) + case Event(Flush, Data(_, xs)) => + send(xs.reverse) + goto(Idle) using Data(0, Nil) + case Event(Stop, Data(_, xs)) => + send(xs.reverse) + stop + } + + initialize +} diff --git a/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala b/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala index 4104f6c18f..7a4641c35e 100644 --- a/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala +++ b/akka-samples/akka-sample-fsm/src/main/scala/DiningHakkersOnFsm.scala @@ -3,8 +3,8 @@ package sample.fsm.dining.fsm import akka.actor.{ActorRef, Actor, FSM} import akka.actor.FSM._ import Actor._ -import java.util.concurrent.TimeUnit -import TimeUnit._ +import akka.util.Duration +import akka.util.duration._ /* * Some messages for the chopstick @@ -33,6 +33,9 @@ case class TakenBy(hakker: Option[ActorRef]) class Chopstick(name: String) extends Actor with FSM[ChopstickState, TakenBy] { self.id = name + // A chopstick begins its existence as available and taken by no one + startWith(Available, TakenBy(None)) + // When a chopstick is available, it can be taken by a some hakker when(Available) { case Event(Take, _) => @@ -49,8 +52,8 @@ class Chopstick(name: String) extends Actor with FSM[ChopstickState, TakenBy] { goto(Available) using TakenBy(None) } - // A chopstick begins its existence as available and taken by no one - startWith(Available, TakenBy(None)) + // Initialze the chopstick + initialize } /** @@ -81,10 +84,13 @@ case class TakenChopsticks(left: Option[ActorRef], right: Option[ActorRef]) class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor with FSM[FSMHakkerState, TakenChopsticks] { self.id = name + //All hakkers start waiting + startWith(Waiting, TakenChopsticks(None, None)) + when(Waiting) { case Event(Think, _) => log.info("%s starts to think", name) - startThinking(5, SECONDS) + startThinking(5 seconds) } //When a hakker is thinking it can become hungry @@ -118,12 +124,12 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit case Event(Busy(chopstick), TakenChopsticks(leftOption, rightOption)) => leftOption.foreach(_ ! Put) rightOption.foreach(_ ! Put) - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) } private def startEating(left: ActorRef, right: ActorRef): State = { log.info("%s has picked up %s and %s, and starts to eat", name, left.id, right.id) - goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5, SECONDS) + goto(Eating) using TakenChopsticks(Some(left), Some(right)) forMax (5 seconds) } // When the results of the other grab comes back, @@ -132,9 +138,9 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit when(FirstChopstickDenied) { case Event(Taken(secondChopstick), _) => secondChopstick ! Put - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) case Event(Busy(chopstick), _) => - startThinking(10, MILLISECONDS) + startThinking(10 milliseconds) } // When a hakker is eating, he can decide to start to think, @@ -144,15 +150,15 @@ class FSMHakker(name: String, left: ActorRef, right: ActorRef) extends Actor wit log.info("%s puts down his chopsticks and starts to think", name) left ! Put right ! Put - startThinking(5, SECONDS) + startThinking(5 seconds) } - private def startThinking(period: Int, timeUnit: TimeUnit): State = { - goto(Thinking) using TakenChopsticks(None, None) forMax (period, timeUnit) - } + // Initialize the hakker + initialize - //All hakkers start waiting - startWith(Waiting, TakenChopsticks(None, None)) + private def startThinking(duration: Duration): State = { + goto(Thinking) using TakenChopsticks(None, None) forMax duration + } } /* diff --git a/akka-stm/src/test/java/akka/transactor/test/UntypedCoordinatedCounter.java b/akka-stm/src/test/java/akka/transactor/test/UntypedCoordinatedCounter.java index b1030106de..9e36409728 100644 --- a/akka-stm/src/test/java/akka/transactor/test/UntypedCoordinatedCounter.java +++ b/akka-stm/src/test/java/akka/transactor/test/UntypedCoordinatedCounter.java @@ -5,7 +5,7 @@ import akka.transactor.Atomically; import akka.actor.ActorRef; import akka.actor.UntypedActor; import akka.stm.*; -import akka.util.Duration; +import akka.util.FiniteDuration; import org.multiverse.api.StmUtils; @@ -17,7 +17,7 @@ public class UntypedCoordinatedCounter extends UntypedActor { private String name; private Ref count = new Ref(0); private TransactionFactory txFactory = new TransactionFactoryBuilder() - .setTimeout(new Duration(3, TimeUnit.SECONDS)) + .setTimeout(new FiniteDuration(3, TimeUnit.SECONDS)) .build(); public UntypedCoordinatedCounter(String name) { diff --git a/akka-stm/src/test/java/akka/transactor/test/UntypedCounter.java b/akka-stm/src/test/java/akka/transactor/test/UntypedCounter.java index d343ceea31..325b06ba73 100644 --- a/akka-stm/src/test/java/akka/transactor/test/UntypedCounter.java +++ b/akka-stm/src/test/java/akka/transactor/test/UntypedCounter.java @@ -4,7 +4,7 @@ import akka.transactor.UntypedTransactor; import akka.transactor.SendTo; import akka.actor.ActorRef; import akka.stm.*; -import akka.util.Duration; +import akka.util.FiniteDuration; import org.multiverse.api.StmUtils; @@ -23,7 +23,7 @@ public class UntypedCounter extends UntypedTransactor { @Override public TransactionFactory transactionFactory() { return new TransactionFactoryBuilder() - .setTimeout(new Duration(3, TimeUnit.SECONDS)) + .setTimeout(new FiniteDuration(3, TimeUnit.SECONDS)) .build(); } @@ -74,4 +74,4 @@ public class UntypedCounter extends UntypedTransactor { return true; } else return false; } -} \ No newline at end of file +}