diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala new file mode 100644 index 0000000000..3d217dfcfd --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala @@ -0,0 +1,201 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl.fusing + +import akka.actor.{ NoSerializationVerificationNeeded, ActorRef } +import akka.stream.scaladsl.{ Keep, Source, Sink } +import akka.stream.{ Attributes, Inlet, SinkShape, ActorMaterializer } +import akka.stream.stage.{ InHandler, AsyncCallback, GraphStageLogic, GraphStageWithMaterializedValue } +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.Utils._ + +import scala.concurrent.{ Await, Promise, Future } +import scala.concurrent.duration._ + +class KeepGoingStageSpec extends AkkaSpec { + + implicit val mat = ActorMaterializer() + + trait PingCmd extends NoSerializationVerificationNeeded + case class Register(probe: ActorRef) extends PingCmd + case object Ping extends PingCmd + case object CompleteStage extends PingCmd + case object FailStage extends PingCmd + case object Throw extends PingCmd + + trait PingEvt extends NoSerializationVerificationNeeded + case object Pong extends PingEvt + case object PostStop extends PingEvt + case object UpstreamCompleted extends PingEvt + case object EndOfEventHandler extends PingEvt + + case class PingRef(private val cb: AsyncCallback[PingCmd]) { + def register(probe: ActorRef): Unit = cb.invoke(Register(probe)) + def ping(): Unit = cb.invoke(Ping) + def stop(): Unit = cb.invoke(CompleteStage) + def fail(): Unit = cb.invoke(FailStage) + def throwEx(): Unit = cb.invoke(Throw) + } + + class PingableSink(keepAlive: Boolean) extends GraphStageWithMaterializedValue[SinkShape[Int], Future[PingRef]] { + val shape = SinkShape[Int](Inlet("ping.in")) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[PingRef]) = { + val promise = Promise[PingRef]() + + val logic = new GraphStageLogic(shape) { + private var listener: Option[ActorRef] = None + + override def preStart(): Unit = { + promise.trySuccess(PingRef(getAsyncCallback(onCommand))) + } + + private def onCommand(cmd: PingCmd): Unit = cmd match { + case Register(probe) ⇒ listener = Some(probe) + case Ping ⇒ listener.foreach(_ ! Pong) + case CompleteStage ⇒ + completeStage() + listener.foreach(_ ! EndOfEventHandler) + case FailStage ⇒ + failStage(TE("test")) + listener.foreach(_ ! EndOfEventHandler) + case Throw ⇒ + try { + throw TE("test") + } finally listener.foreach(_ ! EndOfEventHandler) + } + + setHandler(shape.inlet, new InHandler { + override def onPush(): Unit = pull(shape.inlet) + + // Ignore finish + override def onUpstreamFinish(): Unit = listener.foreach(_ ! UpstreamCompleted) + }) + + override def keepGoingAfterAllPortsClosed: Boolean = keepAlive + + override def postStop(): Unit = listener.foreach(_ ! PostStop) + } + + (logic, promise.future) + } + } + + "A stage with keep-going" must { + + "still be alive after all ports have been closed until explicitly closed" in assertAllStagesStopped { + val (maybePromise, pingerFuture) = Source.maybe[Int].toMat(new PingableSink(keepAlive = true))(Keep.both).run() + val pinger = Await.result(pingerFuture, 3.seconds) + + pinger.register(testActor) + + // Before completion + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + maybePromise.trySuccess(None) + expectMsg(UpstreamCompleted) + + expectNoMsg(200.millis) + + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + pinger.stop() + // PostStop should not be concurrent with the event handler. This event here tests this. + expectMsg(EndOfEventHandler) + expectMsg(PostStop) + + } + + "still be alive after all ports have been closed until explicitly failed" in assertAllStagesStopped { + val (maybePromise, pingerFuture) = Source.maybe[Int].toMat(new PingableSink(keepAlive = true))(Keep.both).run() + val pinger = Await.result(pingerFuture, 3.seconds) + + pinger.register(testActor) + + // Before completion + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + maybePromise.trySuccess(None) + expectMsg(UpstreamCompleted) + + expectNoMsg(200.millis) + + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + pinger.fail() + // PostStop should not be concurrent with the event handler. This event here tests this. + expectMsg(EndOfEventHandler) + expectMsg(PostStop) + + } + + "still be alive after all ports have been closed until implicitly failed (via exception)" in assertAllStagesStopped { + val (maybePromise, pingerFuture) = Source.maybe[Int].toMat(new PingableSink(keepAlive = true))(Keep.both).run() + val pinger = Await.result(pingerFuture, 3.seconds) + + pinger.register(testActor) + + // Before completion + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + maybePromise.trySuccess(None) + expectMsg(UpstreamCompleted) + + expectNoMsg(200.millis) + + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + pinger.throwEx() + // PostStop should not be concurrent with the event handler. This event here tests this. + expectMsg(EndOfEventHandler) + expectMsg(PostStop) + + } + + "close down early if keepAlive is not requested" in assertAllStagesStopped { + val (maybePromise, pingerFuture) = Source.maybe[Int].toMat(new PingableSink(keepAlive = false))(Keep.both).run() + val pinger = Await.result(pingerFuture, 3.seconds) + + pinger.register(testActor) + + // Before completion + pinger.ping() + expectMsg(Pong) + + pinger.ping() + expectMsg(Pong) + + maybePromise.trySuccess(None) + expectMsg(UpstreamCompleted) + expectMsg(PostStop) + + } + + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index e6f980ffb6..ae4f1174cf 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -357,7 +357,8 @@ private[stream] final class GraphInterpreter( // Counts how many active connections a stage has. Once it reaches zero, the stage is automatically stopped. private[this] val shutdownCounter = Array.tabulate(assembly.stages.length) { i ⇒ val shape = assembly.stages(i).shape - shape.inlets.size + shape.outlets.size + val keepGoing = if (logics(i).keepGoingAfterAllPortsClosed) 1 else 0 + shape.inlets.size + shape.outlets.size + keepGoing } // An event queue implemented as a circular buffer @@ -618,6 +619,12 @@ private[stream] final class GraphInterpreter( } } + // Call only for keep-alive stages + def closeKeptAliveStageIfNeeded(stageId: Int): Unit = + if (stageId != Boundary && shutdownCounter(stageId) == 1) { + shutdownCounter(stageId) = 0 + } + private def finalizeStage(logic: GraphStageLogic): Unit = { try { logic.postStop() diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 7d62676c3a..e5cdef8e17 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -149,8 +149,9 @@ object GraphStageLogic { * * The lifecycle hooks [[preStart()]] and [[postStop()]] * * Methods for performing stream processing actions, like pulling or pushing elements * - * The stage logic is always stopped once all its input and output ports have been closed, i.e. it is not possible to - * keep the stage alive for further processing once it does not have any open ports. + * The stage logic is always once all its input and output ports have been closed, i.e. it is not possible to + * keep the stage alive for further processing once it does not have any open ports. This can be changed by + * overriding `keepGoingAfterAllPortsClosed` to return true. */ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) { import GraphInterpreter._ @@ -394,6 +395,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } i += 1 } + if (keepGoingAfterAllPortsClosed) interpreter.closeKeptAliveStageIfNeeded(stageId) } /** @@ -409,6 +411,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: interpreter.fail(portToConn(i), ex) i += 1 } + if (keepGoingAfterAllPortsClosed) interpreter.closeKeptAliveStageIfNeeded(stageId) } /** @@ -718,6 +721,12 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Invoked after processing of external events stopped because the stage is about to stop or fail. */ def postStop(): Unit = () + + /** + * If this method returns true when all ports had been closed then the stage is not stopped until + * completeStage() or failStage() are explicitly called + */ + def keepGoingAfterAllPortsClosed: Boolean = false } /**