diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala index 4ca5cd9195..b7079e6450 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala @@ -95,11 +95,14 @@ object GraphInterpreterSpecKit { stages: Array[GraphStageWithMaterializedValue[_ <: Shape, _]], upstreams: Array[UpstreamBoundaryStageLogic[_]], downstreams: Array[DownstreamBoundaryStageLogic[_]], - attributes: Array[Attributes] = Array.empty) + attributes: Array[Attributes] = Array.empty)(implicit system: ActorSystem) : (Array[GraphStageLogic], SMap[Inlet[_], GraphStageLogic], SMap[Outlet[_], GraphStageLogic]) = { if (attributes.nonEmpty && attributes.length != stages.length) throw new IllegalArgumentException("Attributes must be either empty or one per stage") + @silent("deprecated") + val defaultAttributes = ActorMaterializerSettings(system).toAttributes + var inOwners = SMap.empty[Inlet[_], GraphStageLogic] var outOwners = SMap.empty[Outlet[_], GraphStageLogic] @@ -108,6 +111,7 @@ object GraphInterpreterSpecKit { while (idx < upstreams.length) { val upstream = upstreams(idx) + upstream.attributes = defaultAttributes upstream.stageId = idx logics(idx) = upstream upstream.out.id = 0 @@ -120,11 +124,13 @@ object GraphInterpreterSpecKit { val stage = stages(stageIdx) setPortIds(stage.shape) - val stageAttributes = + val stageAttributes = defaultAttributes and { if (attributes.nonEmpty) stage.traversalBuilder.attributes and attributes(stageIdx) else stage.traversalBuilder.attributes + } val logic = stage.createLogicAndMaterializedValue(stageAttributes)._1 + logic.attributes = stageAttributes logic.stageId = idx var inletIdx = 0 @@ -151,6 +157,7 @@ object GraphInterpreterSpecKit { var downstreamIdx = 0 while (downstreamIdx < downstreams.length) { val downstream = downstreams(downstreamIdx) + downstream.attributes = defaultAttributes downstream.stageId = idx logics(idx) = downstream downstream.in.id = 0 @@ -243,6 +250,8 @@ trait GraphInterpreterSpecKit extends StreamSpec { import GraphInterpreterSpecKit._ val logger = Logging(system, "InterpreterSpecKit") + @silent("deprecated") + val defaultAttributes = ActorMaterializerSettings(system).toAttributes abstract class Builder { private var _interpreter: GraphInterpreter = _ @@ -312,6 +321,11 @@ trait GraphInterpreterSpecKit extends StreamSpec { } def manualInit(logics: Array[GraphStageLogic], connections: Array[Connection]): Unit = { + // set some default attributes where missing + logics.foreach { l => + if (l.attributes == Attributes.none) l.attributes = defaultAttributes + } + _interpreter = new GraphInterpreter( NoMaterializer, logger, diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/CancellationStrategySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/CancellationStrategySpec.scala new file mode 100644 index 0000000000..e78be45765 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/CancellationStrategySpec.scala @@ -0,0 +1,276 @@ +/* + * Copyright (C) 2015-2019 Lightbend Inc. + */ + +package akka.stream.scaladsl + +import akka.NotUsed +import akka.stream.Attributes +import akka.stream.Attributes.CancellationStrategy +import akka.stream.Attributes.CancellationStrategy.FailStage +import akka.stream.BidiShape +import akka.stream.ClosedShape +import akka.stream.Inlet +import akka.stream.Materializer +import akka.stream.Outlet +import akka.stream.SharedKillSwitch +import akka.stream.SubscriptionWithCancelException +import akka.stream.UniformFanOutShape +import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage +import akka.stream.stage.GraphStage +import akka.stream.stage.GraphStageLogic +import akka.stream.stage.InHandler +import akka.stream.stage.OutHandler +import akka.stream.stage.StageLogging +import akka.stream.testkit.StreamSpec +import akka.stream.testkit.TestPublisher +import akka.stream.testkit.TestSubscriber +import akka.stream.testkit.Utils.TE +import akka.testkit.WithLogCapturing +import akka.testkit._ + +import scala.concurrent.duration._ + +class CancellationStrategySpec extends StreamSpec("""akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"]""") with WithLogCapturing { + "CancellationStrategyAttribute" should { + "support strategies" should { + "CompleteStage" should { + "complete if no failure cancellation" in new TestSetup(CancellationStrategy.CompleteStage) { + out1Probe.cancel() + inProbe.expectCancellation() + out2Probe.expectComplete() + } + "complete and propagate cause if failure cancellation" in new TestSetup(CancellationStrategy.CompleteStage) { + val theError = TE("This is a TestException") + out1Probe.cancel(theError) + inProbe.expectCancellationWithCause(theError) + out2Probe.expectComplete() + } + } + "FailStage" should { + "fail if no failure cancellation" in new TestSetup(CancellationStrategy.FailStage) { + out1Probe.cancel() + inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded) + out2Probe.expectError(SubscriptionWithCancelException.NoMoreElementsNeeded) + } + "fail if failure cancellation" in new TestSetup(CancellationStrategy.FailStage) { + val theError = TE("This is a TestException") + out1Probe.cancel(theError) + inProbe.expectCancellationWithCause(theError) + out2Probe.expectError(theError) + } + } + "PropagateFailure" should { + "complete if no failure" in new TestSetup(CancellationStrategy.PropagateFailure) { + out1Probe.cancel() + inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded) + out2Probe.expectComplete() + } + "propagate failure" in new TestSetup(CancellationStrategy.PropagateFailure) { + val theError = TE("This is a TestException") + out1Probe.cancel(theError) + inProbe.expectCancellationWithCause(theError) + out2Probe.expectError(theError) + } + } + "AfterDelay" should { + "apply given strategy after delay" in new TestSetup(CancellationStrategy.AfterDelay(500.millis, FailStage)) { + out1Probe.cancel() + inProbe.expectNoMessage(200.millis) + out2Probe.expectNoMessage(200.millis) + + inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded) + out2Probe.expectError(SubscriptionWithCancelException.NoMoreElementsNeeded) + } + "prevent further elements from coming through" in new TestSetup( + CancellationStrategy.AfterDelay(500.millis, FailStage)) { + out1Probe.request(1) + out2Probe.request(1) + out1Probe.cancel() + inProbe.sendNext(B(123)) + inProbe.expectNoMessage(200.millis) // cancellation should not have propagated yet + out2Probe.expectNext(B(123)) // so the element still goes to out2 + out1Probe.expectNoMessage(200.millis) // but not to out1 which has already cancelled + + // after delay cancellation and error should have propagated + inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded) + out2Probe.expectError(SubscriptionWithCancelException.NoMoreElementsNeeded) + } + } + } + + "cancellation races with BidiStacks" should { + "accidentally convert errors to completions when CompleteStage strategy is chosen (2.5 default)" in new RaceTestSetup( + CancellationStrategy.CompleteStage) { + val theError = TE("Duck meowed") + killSwitch.abort(theError) + toStream.expectCancellationWithCause(theError) + + // this asserts the previous broken behavior (which can still be seen with CompleteStage strategy) + fromStream.expectComplete() + } + "be prevented by PropagateFailure strategy (default in 2.6)" in new RaceTestSetup( + CancellationStrategy.PropagateFailure) { + val theError = TE("Duck meowed") + killSwitch.abort(theError) + toStream.expectCancellationWithCause(theError) + fromStream.expectError(theError) + } + "be prevented by AfterDelay strategy" in new RaceTestSetup( + CancellationStrategy.AfterDelay(500.millis.dilated, CancellationStrategy.CompleteStage)) { + val theError = TE("Duck meowed") + killSwitch.abort(theError) + toStream.expectCancellationWithCause(theError) + fromStream.expectError(theError) + } + + class RaceTestSetup(cancellationStrategy: CancellationStrategy.Strategy) { + val toStream = TestPublisher.probe[A]() + val fromStream = TestSubscriber.probe[B]() + + val bidi: BidiFlow[A, A, B, B, NotUsed] = BidiFlow.fromGraph(new NaiveBidiStage) + + val killSwitch = new SharedKillSwitch("test") + def errorPropagationDelay: FiniteDuration = 200.millis.dilated + + Source + .fromPublisher(toStream) + .via( + bidi + .atop(BidiFlow.fromFlows( + new DelayCompletionSignal[A](errorPropagationDelay), + new DelayCompletionSignal[B](errorPropagationDelay))) + .join(Flow[A].via(killSwitch.flow).map(_.toB))) + .to(Sink.fromSubscriber(fromStream)) + .addAttributes(Attributes(CancellationStrategy(cancellationStrategy))) // fails for `CompleteStage` + .run() + + fromStream.request(1) + toStream.sendNext(A("125")) + fromStream.expectNext(B(125)) + } + } + } + + case class A(str: String) { + def toB: B = B(str.toInt) + } + case class B(i: Int) + + class TestSetup(cancellationStrategy: Option[CancellationStrategy.Strategy]) { + def this(strategy: CancellationStrategy.Strategy) = this(Some(strategy)) + + val inProbe = TestPublisher.probe[B]() + val out1Probe = TestSubscriber.probe[B]() + val out2Probe = TestSubscriber.probe[B]() + + def materializer: Materializer = Materializer.matFromSystem(system) + + RunnableGraph + .fromGraph { + GraphDSL.create() { implicit b => + import GraphDSL.Implicits._ + val fanOut = b.add(new TestFanOut) + + Source.fromPublisher(inProbe) ~> fanOut.in + fanOut.out(0) ~> Sink.fromSubscriber(out1Probe) + fanOut.out(1) ~> Sink.fromSubscriber(out2Probe) + + ClosedShape + } + } + .addAttributes(Attributes(cancellationStrategy.toList.map(CancellationStrategy(_)))) + .run()(materializer) + + // some basic testing that data flow + out1Probe.request(1) + out2Probe.request(1) + + inProbe.expectRequest() + inProbe.sendNext(B(42)) + out1Probe.expectNext(B(42)) + out2Probe.expectNext(B(42)) + } + + // a simple broadcast stage + class TestFanOut extends GraphStage[UniformFanOutShape[B, B]] { + val in = Inlet[B]("in") + val out1 = Outlet[B]("out1") + val out2 = Outlet[B]("out2") + + val shape = UniformFanOutShape(in, out1, out2) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler with StageLogging { + setHandler(in, this) + setHandler(out1, this) + setHandler(out2, this) + + var waitingForPulls = 2 + override def onPush(): Unit = { + val el = grab(in) + push(out1, el) + push(out2, el) + waitingForPulls = 2 + } + + override def onPull(): Unit = { + waitingForPulls -= 1 + require(waitingForPulls >= 0) + if (waitingForPulls == 0) + pull(in) + } + } + } + class NaiveBidiStage extends GraphStage[BidiShape[A, A, B, B]] { + val upIn = Inlet[A]("upIn") + val upOut = Outlet[A]("upOut") + + val downIn = Inlet[B]("downIn") + val downOut = Outlet[B]("downOut") + + val shape = BidiShape(upIn, upOut, downIn, downOut) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with StageLogging { + def connect[T](in: Inlet[T], out: Outlet[T]): Unit = { + val handler = new InHandler with OutHandler { + override def onPull(): Unit = pull(in) + override def onPush(): Unit = push(out, grab(in)) + } + setHandlers(in, out, handler) + } + connect(upIn, upOut) + connect(downIn, downOut) + } + } + + /** A simple stage that delays completion signals */ + class DelayCompletionSignal[T](delay: FiniteDuration) extends SimpleLinearGraphStage[T] { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler with StageLogging { + setHandlers(in, out, this) + + override def onPull(): Unit = pull(in) + override def onPush(): Unit = push(out, grab(in)) + + val callback = getAsyncCallback[Option[Throwable]] { signal => + log.debug(s"Now executing delayed action $signal") + signal match { + case Some(ex) => failStage(ex) + case None => completeStage() + } + } + + override def onUpstreamFinish(): Unit = { + log.debug(s"delaying completion") + materializer.scheduleOnce(delay, () => callback.invoke(None)) + } + override def onUpstreamFailure(ex: Throwable): Unit = { + log.debug(s"delaying error $ex") + materializer.scheduleOnce(delay, () => callback.invoke(Some(ex))) + } + } + } +} diff --git a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala index ae23ecb98a..bd5f3a184b 100644 --- a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala @@ -743,6 +743,7 @@ final class ActorMaterializerSettings @InternalApi private ( // these are the core stream/materializer settings, ad hoc handling of defaults for the stage specific ones // for stream refs and io live with the respective stages Attributes.InputBuffer(initialInputBufferSize, maxInputBufferSize) :: + Attributes.CancellationStrategy.Default :: // FIXME: make configurable, see https://github.com/akka/akka/issues/28000 ActorAttributes.Dispatcher(dispatcher) :: ActorAttributes.SupervisionStrategy(supervisionDecider) :: ActorAttributes.DebugLogging(debugLogging) :: @@ -751,7 +752,9 @@ final class ActorMaterializerSettings @InternalApi private ( ActorAttributes.OutputBurstLimit(outputBurstLimit) :: ActorAttributes.FuzzingMode(fuzzingMode) :: ActorAttributes.MaxFixedBufferSize(maxFixedBufferSize) :: - ActorAttributes.SyncProcessingLimit(syncProcessingLimit) :: Nil) + ActorAttributes.SyncProcessingLimit(syncProcessingLimit) :: + + Nil) override def toString: String = s"ActorMaterializerSettings($initialInputBufferSize,$maxInputBufferSize," + diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index 55218e62bc..42d9f68fcd 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -14,6 +14,7 @@ import akka.japi.function import java.net.URLEncoder import java.time.Duration +import akka.annotation.ApiMayChange import akka.annotation.DoNotInherit import akka.annotation.InternalApi import akka.stream.impl.TraversalBuilder @@ -311,6 +312,136 @@ object Attributes { extends Attribute final case object AsyncBoundary extends Attribute + /** + * Cancellation strategies provide a way to configure the behavior of a stage when `cancelStage` is called. + * + * It is only relevant for stream components that have more than one output and do not define a custom cancellation + * behavior by overriding `onDownstreamFinish`. In those cases, if the first output is cancelled, the default behavior + * is to call `cancelStage` which shuts down the stage completely. The given strategy will allow customization of how + * the shutdown procedure should be done precisely. + */ + @ApiMayChange + final case class CancellationStrategy(strategy: CancellationStrategy.Strategy) extends MandatoryAttribute + @ApiMayChange + object CancellationStrategy { + private[stream] val Default: CancellationStrategy = CancellationStrategy(PropagateFailure) + + sealed trait Strategy + + /** + * Strategy that treats `cancelStage` the same as `completeStage`, i.e. all inlets are cancelled (propagating the + * cancellation cause) and all outlets are regularly completed. + * + * This used to be the default behavior before Akka 2.6. + * + * This behavior can be problematic in stacks of BidiFlows where different layers of the stack are both connected + * through inputs and outputs. In this case, an error in a doubly connected component triggers both a cancellation + * going upstream and an error going downstream. Since the stack might be connected to those components with inlets and + * outlets, a race starts whether the cancellation or the error arrives first. If the error arrives first, that's usually + * good because then the error can be propagated both on inlets and outlets. However, if the cancellation arrives first, + * the previous default behavior to complete the stage will lead other outputs to be completed regularly. The error + * which arrive late at the other hand will just be ignored (that connection will have been cancelled already and also + * the paths through which the error could propagates are already shut down). + */ + @ApiMayChange + case object CompleteStage extends Strategy + + /** + * Strategy that treats `cancelStage` the same as `failStage`, i.e. all inlets are cancelled (propagating the + * cancellation cause) and all outlets are failed propagating the cause from cancellation. + */ + @ApiMayChange + case object FailStage extends Strategy + + /** + * Strategy that treats `cancelStage` in different ways depending on the cause that was given to the cancellation. + * + * If the cause was a regular, active cancellation (`SubscriptionWithCancelException.NoMoreElementsNeeded`), the stage + * receiving this cancellation is completed regularly. + * + * If another cause was given, this is treated as an error and the behavior is the same as with `failStage`. + * + * This is a good default strategy. + */ + @ApiMayChange + case object PropagateFailure extends Strategy + + /** + * Strategy that allows to delay any action when `cancelStage` is invoked. + * + * The idea of this strategy is to delay any action on cancellation because it is expected that the stage is completed + * through another path in the meantime. The downside is that a stage and a stream may live longer than expected if no + * such signal is received and cancellation is invoked later on. In streams with many stages that all apply this strategy, + * this strategy might significantly delay the propagation of a cancellation signal because each upstream stage might impose + * such a delay. During this time, the stream will be mostly "silent", i.e. it cannot make progress because of backpressure, + * but you might still be able observe a long delay at the ultimate source. + */ + @ApiMayChange + final case class AfterDelay(delay: FiniteDuration, strategy: Strategy) extends Strategy + } + + /** + * Java API + * + * Strategy that treats `cancelStage` the same as `completeStage`, i.e. all inlets are cancelled (propagating the + * cancellation cause) and all outlets are regularly completed. + * + * This used to be the default behavior before Akka 2.6. + * + * This behavior can be problematic in stacks of BidiFlows where different layers of the stack are both connected + * through inputs and outputs. In this case, an error in a doubly connected component triggers both a cancellation + * going upstream and an error going downstream. Since the stack might be connected to those components with inlets and + * outlets, a race starts whether the cancellation or the error arrives first. If the error arrives first, that's usually + * good because then the error can be propagated both on inlets and outlets. However, if the cancellation arrives first, + * the previous default behavior to complete the stage will lead other outputs to be completed regularly. The error + * which arrive late at the other hand will just be ignored (that connection will have been cancelled already and also + * the paths through which the error could propagates are already shut down). + */ + @ApiMayChange + def cancellationStrategyCompleteState: CancellationStrategy.Strategy = CancellationStrategy.CompleteStage + + /** + * Java API + * + * Strategy that treats `cancelStage` the same as `failStage`, i.e. all inlets are cancelled (propagating the + * cancellation cause) and all outlets are failed propagating the cause from cancellation. + */ + @ApiMayChange + def cancellationStrategyFailStage: CancellationStrategy.Strategy = CancellationStrategy.FailStage + + /** + * Java API + * + * Strategy that treats `cancelStage` in different ways depending on the cause that was given to the cancellation. + * + * If the cause was a regular, active cancellation (`SubscriptionWithCancelException.NoMoreElementsNeeded`), the stage + * receiving this cancellation is completed regularly. + * + * If another cause was given, this is treated as an error and the behavior is the same as with `failStage`. + * + * This is a good default strategy. + */ + @ApiMayChange + def cancellationStrategyPropagateFailure: CancellationStrategy.Strategy = CancellationStrategy.PropagateFailure + + /** + * Java API + * + * Strategy that allows to delay any action when `cancelStage` is invoked. + * + * The idea of this strategy is to delay any action on cancellation because it is expected that the stage is completed + * through another path in the meantime. The downside is that a stage and a stream may live longer than expected if no + * such signal is received and cancellation is invoked later on. In streams with many stages that all apply this strategy, + * this strategy might significantly delay the propagation of a cancellation signal because each upstream stage might impose + * such a delay. During this time, the stream will be mostly "silent", i.e. it cannot make progress because of backpressure, + * but you might still be able observe a long delay at the ultimate source. + */ + @ApiMayChange + def cancellationStrategyAfterDelay( + delay: FiniteDuration, + strategy: CancellationStrategy.Strategy): CancellationStrategy.Strategy = + CancellationStrategy.AfterDelay(delay, strategy) + object LogLevels { /** Use to disable logging on certain operations when configuring [[Attributes#logLevels]] */ @@ -332,16 +463,16 @@ object Attributes { /** Java API: Use to disable logging on certain operations when configuring [[Attributes#createLogLevels]] */ def logLevelOff: Logging.LogLevel = LogLevels.Off - /** Use to enable logging at ERROR level for certain operations when configuring [[Attributes#createLogLevels]] */ + /** Java API: Use to enable logging at ERROR level for certain operations when configuring [[Attributes#createLogLevels]] */ def logLevelError: Logging.LogLevel = LogLevels.Error - /** Use to enable logging at WARNING level for certain operations when configuring [[Attributes#createLogLevels]] */ + /** Java API: Use to enable logging at WARNING level for certain operations when configuring [[Attributes#createLogLevels]] */ def logLevelWarning: Logging.LogLevel = LogLevels.Warning - /** Use to enable logging at INFO level for certain operations when configuring [[Attributes#createLogLevels]] */ + /** Java API: Use to enable logging at INFO level for certain operations when configuring [[Attributes#createLogLevels]] */ def logLevelInfo: Logging.LogLevel = LogLevels.Info - /** Use to enable logging at DEBUG level for certain operations when configuring [[Attributes#createLogLevels]] */ + /** Java API: Use to enable logging at DEBUG level for certain operations when configuring [[Attributes#createLogLevels]] */ def logLevelDebug: Logging.LogLevel = LogLevels.Debug /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/PhasedFusingActorMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl/PhasedFusingActorMaterializer.scala index 58efa98328..a10f6880fc 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/PhasedFusingActorMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/PhasedFusingActorMaterializer.scala @@ -750,6 +750,7 @@ private final case class SavedIslandData( val boundary = new ActorOutputBoundary(shell, out.toString) logics.add(boundary) boundary.stageId = logics.size() - 1 + boundary.attributes = logic.attributes val connection = outConn() boundary.portToConn(boundary.in.id) = connection diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 65f501d305..37e6086ca9 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -305,9 +305,6 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * INTERNAL API - * - * Input handlers followed by output handlers, use `inHandler(id)` and `outHandler(id)` to access the respective - * handlers. */ private[stream] var attributes: Attributes = Attributes.none @@ -320,8 +317,10 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * INTERNAL API + * + * Input handlers followed by output handlers, use `inHandler(id)` and `outHandler(id)` to access the respective + * handlers. */ - // Using common array to reduce overhead for small port counts private[stream] val handlers = new Array[Any](inCount + outCount) /** @@ -515,7 +514,26 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Requests to stop receiving events from a given input port. Cancelling clears any ungrabbed elements from the port. */ - final protected def cancel[T](in: Inlet[T], cause: Throwable): Unit = interpreter.cancel(conn(in), cause) + final protected def cancel[T](in: Inlet[T], cause: Throwable): Unit = cancel(conn(in), cause) + + private def cancel[T](connection: Connection, cause: Throwable): Unit = + attributes.mandatoryAttribute[Attributes.CancellationStrategy].strategy match { + case Attributes.CancellationStrategy.AfterDelay(delay, _) => + // since the port is not actually cancelled, we install a handler to ignore upcoming elements + connection.inHandler = new InHandler { + // ignore pushs now, since the stage wanted it cancelled already + override def onPush(): Unit = () + // do not ignore termination signals + } + val callback = getAsyncCallback[(Connection, Throwable)] { + case (connection, cause) => doCancel(connection, cause) + } + materializer.scheduleOnce(delay, () => callback.invoke((connection, cause))) + case _ => + doCancel(connection, cause) + } + + private def doCancel[T](connection: Connection, cause: Throwable): Unit = interpreter.cancel(connection, cause) /** * Once the callback [[InHandler.onPush]] for an input port has been invoked, the element that has been pushed @@ -662,7 +680,25 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Automatically invokes [[cancel]] or [[complete]] on all the input or output ports that have been called, * then marks the stage as stopped. */ - final def cancelStage(cause: Throwable): Unit = internalCompleteStage(cause, OptionVal.None) + final def cancelStage(cause: Throwable): Unit = + internalCancelStage(cause, attributes.mandatoryAttribute[Attributes.CancellationStrategy].strategy) + + private def internalCancelStage(cause: Throwable, strategy: Attributes.CancellationStrategy.Strategy): Unit = { + import Attributes.CancellationStrategy._ + import SubscriptionWithCancelException._ + strategy match { + case CompleteStage => internalCompleteStage(cause, OptionVal.None) + case FailStage => internalCompleteStage(cause, OptionVal.Some(cause)) + case PropagateFailure => + cause match { + case NoMoreElementsNeeded | StageWasCompleted => internalCompleteStage(cause, OptionVal.None) + case _ => internalCompleteStage(cause, OptionVal.Some(cause)) + } + case AfterDelay(_, andThen) => + // delay handled at the stage that sends the delay. See `def cancel(in, cause)`. + internalCancelStage(cause, andThen) + } + } /** * Automatically invokes [[cancel]] or [[fail]] on all the input or output ports that have been called, @@ -678,7 +714,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: var i = 0 while (i < portToConn.length) { if (i < inCount) - interpreter.cancel(portToConn(i), cancelCause) + cancel(portToConn(i), cancelCause) // call through GraphStage.cancel to apply delay if applicable else if (optionalFailureCause.isDefined) interpreter.fail(portToConn(i), optionalFailureCause.get) else