diff --git a/akka-actor-tests/src/test/scala/akka/actor/FSMActorSpec.scala b/akka-actor-tests/src/test/scala/akka/actor/FSMActorSpec.scala index 07dcd4c02b..3de2186dba 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/FSMActorSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/FSMActorSpec.scala @@ -4,6 +4,8 @@ package akka.actor +import akka.actor.FSM.StateTimeout + import language.postfixOps import org.scalatest.{ BeforeAndAfterAll, BeforeAndAfterEach } import akka.testkit._ @@ -340,6 +342,40 @@ class FSMActorSpec extends AkkaSpec(Map("akka.actor.debug.fsm" -> true)) with Im expectMsg(Transition(fsmref, 0, 1)) } + "allow cancelling stateTimeout by issuing forMax(Duration.Inf)" in { + val sys = ActorSystem("fsmEvent") + val p = TestProbe()(sys) + + val OverrideTimeoutToInf = "override-timeout-to-inf" + + val fsm = sys.actorOf(Props(new Actor with FSM[String, String] { + + startWith("init", "") + + when("init", stateTimeout = 1.second) { + case Event(StateTimeout, _) ⇒ + p.ref ! StateTimeout + stay() + + case Event(OverrideTimeoutToInf, _) ⇒ + p.ref ! OverrideTimeoutToInf + stay() forMax Duration.Inf + } + + initialize() + })) + + try { + p.expectMsg(FSM.StateTimeout) + + fsm ! OverrideTimeoutToInf + p.expectMsg(OverrideTimeoutToInf) + p.expectNoMsg(3.seconds) + } finally { + TestKit.shutdownActorSystem(sys) + } + } + } } diff --git a/akka-actor/src/main/scala/akka/actor/FSM.scala b/akka-actor/src/main/scala/akka/actor/FSM.scala index c87fbb355d..62163883bf 100644 --- a/akka-actor/src/main/scala/akka/actor/FSM.scala +++ b/akka-actor/src/main/scala/akka/actor/FSM.scala @@ -8,6 +8,7 @@ import scala.concurrent.duration.Duration import scala.collection.mutable import akka.routing.{ Deafen, Listen, Listeners } import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration._ object FSM { @@ -118,6 +119,9 @@ object FSM { */ final case class LogEntry[S, D](stateName: S, stateData: D, event: Any) + /** Used by `forMax` to signal "cancel stateTimeout" */ + private final val SomeMaxFiniteDuration = Some(Long.MaxValue.nanos) + /** * This captures all of the managed state of the [[akka.actor.FSM]]: the state * name, the state data, possibly custom timeout, stop reason and replies @@ -141,8 +145,10 @@ object FSM { */ def forMax(timeout: Duration): State[S, D] = timeout match { case f: FiniteDuration ⇒ copy(timeout = Some(f)) - case _ ⇒ copy(timeout = None) - } + case Duration.Inf ⇒ copy(timeout = SomeMaxFiniteDuration) // we map the Infinite duration to a special marker, + case _ ⇒ copy(timeout = None) // that means "cancel stateTimeout". This marker is needed + } // so we do not have to break source/binary compat. + // TODO: Can be removed once we can break State#timeout signature to `Option[Duration]` /** * Send reply to sender of the current message, if available. @@ -648,13 +654,18 @@ trait FSM[S, D] extends Actor with Listeners with ActorLogging { this.nextState = null } currentState = nextState - val timeout = if (currentState.timeout.isDefined) currentState.timeout else stateTimeouts(currentState.stateName) - if (timeout.isDefined) { - val t = timeout.get - if (t.isFinite && t.length >= 0) { - import context.dispatcher - timeoutFuture = Some(context.system.scheduler.scheduleOnce(t, self, TimeoutMarker(generation))) - } + + def scheduleTimeout(d: FiniteDuration): Some[Cancellable] = { + import context.dispatcher + Some(context.system.scheduler.scheduleOnce(d, self, TimeoutMarker(generation))) + } + + currentState.timeout match { + case SomeMaxFiniteDuration ⇒ // effectively disable stateTimeout + case Some(d: FiniteDuration) if d.length >= 0 ⇒ timeoutFuture = scheduleTimeout(d) + case _ ⇒ + val timeout = stateTimeouts(currentState.stateName) + if (timeout.isDefined) timeoutFuture = scheduleTimeout(timeout.get) } } } @@ -1126,7 +1137,7 @@ abstract class AbstractFSM[S, D] extends FSM[S, D] { * @return current state descriptor */ final def setTimer(name: String, msg: Any, timeout: FiniteDuration): Unit = - setTimer(name, msg, timeout, false); + setTimer(name, msg, timeout, false) /** * Default reason if calling `stop()`.