From 16033eaf5e0f2b49f3ca57b57bae25b1856da191 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Fri, 16 Aug 2019 10:53:14 +0200 Subject: [PATCH] Propagate stream cancellation causes (#27266) * base functionality * fix-restart-flow * Fix subSource / subSink cancellation handling * GraphStage-fix * Fix ambiguity between complete and cancellation (for isAvailable / grab) * rename lastCancellationCause * add mima * fix cancellation cause propagation in OutputBoundary * Fix cancellation cause propagation in SubSink * Add cancellation cause logging to Flow.log * add more comments about GraphStage portState internals * Add some assertions in onDownstreamFinish to prevent wrong usage * Also deprecate onDownstreamFinish() so that no one calls the wrong one accidentally * add SubSinkInlet.cancel(cause) * Propagate causes in two other places * Suggest to use `cancel(in, cause)` but don't deprecate old one --- .../akka/stream/testkit/TestGraphStage.scala | 4 +- .../stream/impl/GraphStageLogicSpec.scala | 2 +- .../mima-filters/2.5.x.backwards.excludes | 8 +- .../main/scala/akka/stream/KillSwitch.scala | 4 +- .../SubscriptionWithCancelException.scala | 24 ++++ .../impl/ReactiveStreamsCompliance.scala | 9 +- .../main/scala/akka/stream/impl/Sinks.scala | 7 +- .../scala/akka/stream/impl/StreamLayout.scala | 10 +- .../main/scala/akka/stream/impl/Timers.scala | 2 +- .../impl/fusing/ActorGraphInterpreter.scala | 53 +++++---- .../stream/impl/fusing/GraphInterpreter.scala | 22 +++- .../akka/stream/impl/fusing/GraphStages.scala | 11 +- .../scala/akka/stream/impl/fusing/Ops.scala | 15 ++- .../stream/impl/fusing/StreamOfStreams.scala | 16 +-- .../scala/akka/stream/scaladsl/Graph.scala | 4 +- .../akka/stream/scaladsl/RestartFlow.scala | 4 +- .../scala/akka/stream/stage/GraphStage.scala | 111 +++++++++++++----- 17 files changed, 210 insertions(+), 96 deletions(-) create mode 100644 akka-stream/src/main/scala/akka/stream/SubscriptionWithCancelException.scala diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala index 785b8badc4..1fd1ea2ba3 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala @@ -133,9 +133,9 @@ private[testkit] class TestSourceStage[T, M]( throw ex } } - override def onDownstreamFinish(): Unit = { + override def onDownstreamFinish(cause: Throwable): Unit = { try { - outHandler.onDownstreamFinish() + outHandler.onDownstreamFinish(cause) probe.ref ! GraphStageMessages.DownstreamFinish } catch { case NonFatal(ex) => diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala index f617871138..b7a590cdac 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala @@ -227,7 +227,7 @@ class GraphStageLogicSpec extends StreamSpec with GraphInterpreterSpecKit with S // note: a bit dangerous assumptions about connection and logic positions here // if anything around creating the logics and connections in the builder changes this may fail interpreter.complete(interpreter.connections(0)) - interpreter.cancel(interpreter.connections(1)) + interpreter.cancel(interpreter.connections(1), SubscriptionWithCancelException.NoMoreElementsNeeded) interpreter.execute(2) expectMsg("postStop2") diff --git a/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes index e538b33bf0..f4791643a3 100644 --- a/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes +++ b/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes @@ -146,7 +146,6 @@ ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSub ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSubscriber$") ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSink") - # #19980 subscription timeouts for streams ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.impl.ActorProcessorImpl.subTimeoutHandling") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.FanoutOutputs.this") @@ -157,3 +156,10 @@ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.Reducer # Protobuf 3 ProblemFilters.exclude[Problem]("akka.stream.StreamRefMessages*") + +# #27266 changes to streams internals +ProblemFilters.exclude[Problem]("akka.stream.impl.*") + +# added private[this] field to public class, shouldn't have more impact than a potential naming clash +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.stage.OutHandler.akka$stream$stage$OutHandler$$_lastCancellationCause") +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.stage.OutHandler.akka$stream$stage$OutHandler$$_lastCancellationCause_=") diff --git a/akka-stream/src/main/scala/akka/stream/KillSwitch.scala b/akka-stream/src/main/scala/akka/stream/KillSwitch.scala index df8ae76b75..b4d0169cb7 100644 --- a/akka-stream/src/main/scala/akka/stream/KillSwitch.scala +++ b/akka-stream/src/main/scala/akka/stream/KillSwitch.scala @@ -121,11 +121,11 @@ object KillSwitches { }) setHandler(shape.out1, new OutHandler { override def onPull(): Unit = pull(shape.in1) - override def onDownstreamFinish(): Unit = cancel(shape.in1) + override def onDownstreamFinish(cause: Throwable): Unit = cancel(shape.in1, cause) }) setHandler(shape.out2, new OutHandler { override def onPull(): Unit = pull(shape.in2) - override def onDownstreamFinish(): Unit = cancel(shape.in2) + override def onDownstreamFinish(cause: Throwable): Unit = cancel(shape.in2, cause) }) } diff --git a/akka-stream/src/main/scala/akka/stream/SubscriptionWithCancelException.scala b/akka-stream/src/main/scala/akka/stream/SubscriptionWithCancelException.scala new file mode 100644 index 0000000000..29500561b7 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/SubscriptionWithCancelException.scala @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2019 Lightbend Inc. + */ + +package akka.stream + +import org.reactivestreams.Subscription + +import scala.util.control.NoStackTrace + +/** + * Extension of Subscription that allows to pass a cause when a subscription is cancelled. + * + * Subscribers can check for this trait and use its `cancel(cause)` method instead of the regular + * cancel method to pass a cancellation cause. + */ +trait SubscriptionWithCancelException extends Subscription { + final override def cancel() = cancel(SubscriptionWithCancelException.NoMoreElementsNeeded) + def cancel(cause: Throwable): Unit +} +object SubscriptionWithCancelException { + case object NoMoreElementsNeeded extends RuntimeException with NoStackTrace + case object StageWasCompleted extends RuntimeException with NoStackTrace +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala b/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala index 7b667e5062..6fa6a38d1f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala @@ -5,6 +5,7 @@ package akka.stream.impl import akka.annotation.InternalApi +import akka.stream.SubscriptionWithCancelException import scala.util.control.NonFatal import org.reactivestreams.{ Subscriber, Subscription } @@ -125,11 +126,13 @@ import org.reactivestreams.{ Subscriber, Subscription } } } - final def tryCancel(subscription: Subscription): Unit = { + final def tryCancel(subscription: Subscription, cause: Throwable): Unit = { if (subscription eq null) throw new IllegalStateException("Subscription must be not null on cancel() call, rule 1.3") - try subscription.cancel() - catch { + try subscription match { + case s: SubscriptionWithCancelException => s.cancel(cause) + case s => s.cancel() + } catch { case NonFatal(t) => throw new SignalThrewException("It is illegal to throw exceptions from cancel(), rule 3.15", t) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index d6a9f292d7..ff4c5d3273 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -674,10 +674,9 @@ import org.reactivestreams.Subscriber } } } - override def onDownstreamFinish(): Unit = { - if (!isClosed(in)) { - cancel(in) - } + + override def onDownstreamFinish(cause: Throwable): Unit = { + if (!isClosed(in)) cancel(in, cause) maybeCompleteStage() } }) diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index e71c35a91f..078d461cf6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -176,10 +176,10 @@ import scala.util.control.NonFatal case _ => pub.subscribe(subscriber.asInstanceOf[Subscriber[Any]]) } } - case _ => + case state @ _ => if (VirtualProcessor.Debug) println(s"VirtualPublisher#$hashCode(_).onSubscribe.rec($s) spec violation") // spec violation - tryCancel(s) + tryCancel(s, new IllegalStateException(s"VirtualProcessor in wrong state [$state]. Spec violation")) } } @@ -223,7 +223,7 @@ import scala.util.control.NonFatal set(Inert) case Inert => - tryCancel(subscription) + tryCancel(subscription, new IllegalStateException("VirtualProcessor was already subscribed to.")) case other => throw new IllegalStateException( @@ -234,7 +234,7 @@ import scala.util.control.NonFatal } catch { case NonFatal(ex) => set(Inert) - tryCancel(subscription) + tryCancel(subscription, ex) tryOnError(establishing.subscriber, ex) } } @@ -397,7 +397,7 @@ import scala.util.control.NonFatal if (n < 1) { if (VirtualProcessor.Debug) println(s"VirtualPublisher#${VirtualProcessor.this.hashCode}.WrappedSubscription($real).request($n)") - tryCancel(real) + tryCancel(real, new IllegalArgumentException(s"Demand must not be < 1 but was $n")) VirtualProcessor.this.getAndSet(Inert) match { case Both(subscriber) => rejectDueToNonPositiveDemand(subscriber) case est: Establishing => rejectDueToNonPositiveDemand(est.subscriber) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Timers.scala b/akka-stream/src/main/scala/akka/stream/impl/Timers.scala index bd20c2e49b..3bb9ee670a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Timers.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Timers.scala @@ -182,7 +182,7 @@ import scala.concurrent.duration.{ Duration, FiniteDuration } override def onPull(): Unit = pull(in) override def onUpstreamFinish(): Unit = complete(out) - override def onDownstreamFinish(): Unit = cancel(in) + override def onDownstreamFinish(cause: Throwable): Unit = cancel(in, cause) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 393fb884b3..c44348e9ba 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -111,7 +111,7 @@ import scala.util.control.NonFatal private var inputBufferElements = 0 private var nextInputElementCursor = 0 private var upstreamCompleted = false - private var downstreamCanceled = false + private var downstreamCanceled: Option[Throwable] = None private val IndexMask = size - 1 private def requestBatchSize = math.max(1, inputBuffer.length / 2) @@ -168,11 +168,11 @@ import scala.util.control.NonFatal inputBufferElements = 0 } - def cancel(): Unit = { - downstreamCanceled = true + def cancel(cause: Throwable): Unit = { + downstreamCanceled = Some(cause) if (!upstreamCompleted) { upstreamCompleted = true - if (upstream ne null) tryCancel(upstream) + if (upstream ne null) tryCancel(upstream, cause) clear() } } @@ -188,7 +188,7 @@ import scala.util.control.NonFatal } def onError(e: Throwable): Unit = - if (!upstreamCompleted || !downstreamCanceled) { + if (!upstreamCompleted || downstreamCanceled.isEmpty) { upstreamCompleted = true clear() fail(out, e) @@ -197,7 +197,7 @@ import scala.util.control.NonFatal // Call this when an error happens that does not come from the usual onError channel // (exceptions while calling RS interfaces, abrupt termination etc) def onInternalError(e: Throwable): Unit = { - if (!(upstreamCompleted || downstreamCanceled) && (upstream ne null)) { + if (!(upstreamCompleted || downstreamCanceled.isDefined) && (upstream ne null)) { upstream.cancel() } if (!isClosed(out)) onError(e) @@ -212,12 +212,13 @@ import scala.util.control.NonFatal def onSubscribe(subscription: Subscription): Unit = { ReactiveStreamsCompliance.requireNonNullSubscription(subscription) if (upstreamCompleted) { - tryCancel(subscription) - } else if (downstreamCanceled) { + // onComplete or onError has been called before OnSubscribe + tryCancel(subscription, SubscriptionWithCancelException.NoMoreElementsNeeded) + } else if (downstreamCanceled.isDefined) { upstreamCompleted = true - tryCancel(subscription) + tryCancel(subscription, downstreamCanceled.get) } else if (upstream != null) { // reactive streams spec 2.5 - tryCancel(subscription) + tryCancel(subscription, new IllegalStateException("Publisher can only be subscribed once.")) } else { upstream = subscription // Prefetch @@ -243,8 +244,8 @@ import scala.util.control.NonFatal } } - override def onDownstreamFinish(): Unit = - try cancel() + override def onDownstreamFinish(cause: Throwable): Unit = + try cancel(cause) catch { case s: SpecViolation => shell.tryAbort(s) } @@ -270,11 +271,12 @@ import scala.util.control.NonFatal override def shell: GraphInterpreterShell = boundary.shell override def logic: GraphStageLogic = boundary } - final case class Cancel(boundary: ActorOutputBoundary) extends SimpleBoundaryEvent { + final case class Cancel(boundary: ActorOutputBoundary, cause: Throwable) extends SimpleBoundaryEvent { override def execute(): Unit = { if (GraphInterpreter.Debug) - println(s"${boundary.shell.interpreter.Name} cancel port=${boundary.internalPortName}") - boundary.cancel() + println( + s"${boundary.shell.interpreter.Name} cancel port=${boundary.internalPortName} cause=${cause.getMessage}") + boundary.cancel(cause) } override def shell: GraphInterpreterShell = boundary.shell @@ -360,7 +362,9 @@ import scala.util.control.NonFatal private var downstreamDemand: Long = 0L // This flag is only used if complete/fail is called externally since this op turns into a Finished one inside the // interpreter (i.e. inside this op this flag has no effects since if it is completed the op will not be invoked) - private var downstreamCompleted = false + private[this] var downstreamCompletionCause: Option[Throwable] = None + def downstreamCompleted: Boolean = downstreamCompletionCause.isDefined + // when upstream failed before we got the exposed publisher private var upstreamCompleted: Boolean = false @@ -392,7 +396,7 @@ import scala.util.control.NonFatal override def onPush(): Unit = { try { onNext(grab(in)) - if (downstreamCompleted) cancel(in) + if (downstreamCompleted) cancel(in, downstreamCompletionCause.get) else if (downstreamDemand > 0) pull(in) } catch { case s: SpecViolation => shell.tryAbort(s) @@ -415,9 +419,10 @@ import scala.util.control.NonFatal publisher.takePendingSubscribers().foreach { sub => if (subscriber eq null) { subscriber = sub - val subscription = new Subscription { + val subscription = new Subscription with SubscriptionWithCancelException { override def request(elements: Long): Unit = actor ! RequestMore(ActorOutputBoundary.this, elements) - override def cancel(): Unit = actor ! Cancel(ActorOutputBoundary.this) + override def cancel(cause: Throwable): Unit = actor ! Cancel(ActorOutputBoundary.this, cause) + override def toString = s"BoundarySubscription[$actor, $internalPortName]" } @@ -430,7 +435,7 @@ import scala.util.control.NonFatal def requestMore(elements: Long): Unit = { if (elements < 1) { - cancel(in) + cancel(in, ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException) fail(ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException) } else { downstreamDemand += elements @@ -440,11 +445,11 @@ import scala.util.control.NonFatal } } - def cancel(): Unit = { - downstreamCompleted = true + def cancel(cause: Throwable): Unit = { + downstreamCompletionCause = Some(cause) subscriber = null publisher.shutdown(Some(new ActorPublisher.NormalShutdownException)) - cancel(in) + cancel(in, cause) } override def toString: String = @@ -655,7 +660,7 @@ import scala.util.control.NonFatal // Will only have an effect if the above call to the interpreter failed to emit a proper failure to the downstream // otherwise this will have no effect outputs.foreach(_.fail(reason)) - inputs.foreach(_.cancel()) + inputs.foreach(_.cancel(reason)) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index 75342a7a35..00a7078015 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -54,8 +54,13 @@ import akka.stream.snapshot._ * but there is no more element to grab. */ case object Empty + + /** Marker class that indicates that a port was failed with a given cause and a potential outstanding element */ final case class Failed(ex: Throwable, previousElem: Any) + /** Marker class that indicates that a port was cancelled with a given cause */ + final case class Cancelled(cause: Throwable) + abstract class UpstreamBoundaryStageLogic[T] extends GraphStageLogic(inCount = 0, outCount = 1) { def out: Outlet[T] } @@ -85,7 +90,16 @@ import akka.stream.snapshot._ var outOwner: GraphStageLogic, var inHandler: InHandler, var outHandler: OutHandler) { + + /** See [[GraphInterpreter]] about possible states */ var portState: Int = InReady + + /** + * Can either be + * * an in-flight element + * * a failure (with an optional in-flight element), if elem.isInstanceOf[Failed] + * * a cancellation cause, if elem.isInstanceOf[Cancelled] + */ var slot: Any = Empty } @@ -493,7 +507,9 @@ import akka.stream.snapshot._ s"$Name CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${connection.outHandler}) [${outLogicName(connection)}]") connection.portState |= OutClosed completeConnection(connection.outOwner.stageId) - connection.outHandler.onDownstreamFinish() + val cause = connection.slot.asInstanceOf[Cancelled].cause + connection.slot = Empty + connection.outHandler.onDownstreamFinish(cause) } else if ((code & (OutClosed | InClosed)) == OutClosed) { // COMPLETIONS @@ -637,12 +653,12 @@ import akka.stream.snapshot._ } @InternalStableApi - private[stream] def cancel(connection: Connection): Unit = { + private[stream] def cancel(connection: Connection, cause: Throwable): Unit = { val currentState = connection.portState if (Debug) println(s"$Name cancel($connection) [$currentState]") connection.portState = currentState | InClosed if ((currentState & OutClosed) == 0) { - connection.slot = Empty + connection.slot = Cancelled(cause) if ((currentState & (Pulling | Pushing | InClosed)) == 0) enqueue(connection) else if (chasedPull eq connection) { // Abort chasing so Cancel is not lost (chasing does NOT decode the event but assumes it to be a PULL diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index b3ad7a2fb6..014d8653b0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -191,8 +191,8 @@ import scala.concurrent.{ Future, Promise } def onPull(): Unit = pull(in) - override def onDownstreamFinish(): Unit = { - super.onDownstreamFinish() + override def onDownstreamFinish(cause: Throwable): Unit = { + super.onDownstreamFinish(cause) monitor.set(Finished) } @@ -307,16 +307,17 @@ import scala.concurrent.{ Future, Promise } new OutHandler { def onPull(): Unit = {} - override def onDownstreamFinish(): Unit = { + override def onDownstreamFinish(cause: Throwable): Unit = { if (!materialized.isCompleted) { // we used to try to materialize the "inner" source here just to get // the materialized value, but that is not safe and may cause the graph shell // to leak/stay alive after the stage completes - materialized.tryFailure(new StreamDetachedException("Stream cancelled before Source Future completed")) + materialized.tryFailure( + new StreamDetachedException("Stream cancelled before Source Future completed").initCause(cause)) } - super.onDownstreamFinish() + super.onDownstreamFinish(cause) } }) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index a1bcea8b12..4d9a9e9d4d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -1505,11 +1505,16 @@ private[stream] object Collect { super.onUpstreamFinish() } - override def onDownstreamFinish(): Unit = { + override def onDownstreamFinish(cause: Throwable): Unit = { if (isEnabled(logLevels.onFinish)) - log.log(logLevels.onFinish, "[{}] Downstream finished.", name) + log.log( + logLevels.onFinish, + "[{}] Downstream finished, cause: {}: {}", + name, + Logging.simpleName(cause.getClass), + cause.getMessage) - super.onDownstreamFinish() + super.onDownstreamFinish(cause: Throwable) } private def isEnabled(l: LogLevel): Boolean = l.asInt != OffInt @@ -2150,9 +2155,9 @@ private[stream] object Collect { super.onUpstreamFailure(ex) } - override def onDownstreamFinish(): Unit = { + override def onDownstreamFinish(cause: Throwable): Unit = { matPromise.success(None) - super.onDownstreamFinish() + super.onDownstreamFinish(cause) } override def onPull(): Unit = { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala index aca0db4726..94b79400f9 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala @@ -661,7 +661,8 @@ import akka.stream.impl.fusing.GraphStages.SingleSource case object RequestOneScheduledBeforeMaterialization extends CommandScheduledBeforeMaterialization(RequestOne) /** A Cancel command was scheduled before materialization */ - case object CancelScheduledBeforeMaterialization extends CommandScheduledBeforeMaterialization(Cancel) + case class CancelScheduledBeforeMaterialization(cause: Throwable) + extends CommandScheduledBeforeMaterialization(Cancel(cause)) /** Steady state: sink has been materialized, commands can be delivered through the callback */ // Represented in unwrapped form as AsyncCallback[Command] directly to prevent a level of indirection @@ -669,7 +670,7 @@ import akka.stream.impl.fusing.GraphStages.SingleSource sealed trait Command case object RequestOne extends Command - case object Cancel extends Command + case class Cancel(cause: Throwable) extends Command } /** @@ -687,7 +688,8 @@ import akka.stream.impl.fusing.GraphStages.SingleSource private val status = new AtomicReference[ /* State */ AnyRef](Uninitialized) def pullSubstream(): Unit = dispatchCommand(RequestOneScheduledBeforeMaterialization) - def cancelSubstream(): Unit = dispatchCommand(CancelScheduledBeforeMaterialization) + def cancelSubstream(): Unit = cancelSubstream(SubscriptionWithCancelException.NoMoreElementsNeeded) + def cancelSubstream(cause: Throwable): Unit = dispatchCommand(CancelScheduledBeforeMaterialization(cause)) @tailrec private def dispatchCommand(newState: CommandScheduledBeforeMaterialization): Unit = @@ -697,7 +699,7 @@ import akka.stream.impl.fusing.GraphStages.SingleSource if (!status.compareAndSet(Uninitialized, newState)) dispatchCommand(newState) // changed to materialized in the meantime - case RequestOneScheduledBeforeMaterialization if newState == CancelScheduledBeforeMaterialization => + case RequestOneScheduledBeforeMaterialization if newState.isInstanceOf[CancelScheduledBeforeMaterialization] => // cancellation is allowed to replace pull if (!status.compareAndSet(RequestOneScheduledBeforeMaterialization, newState)) dispatchCommand(RequestOneScheduledBeforeMaterialization) @@ -735,8 +737,8 @@ import akka.stream.impl.fusing.GraphStages.SingleSource override def preStart(): Unit = setCallback { - case RequestOne => tryPull(in) - case Cancel => completeStage() + case RequestOne => tryPull(in) + case Cancel(cause) => cancelStage(cause) } } @@ -807,7 +809,7 @@ import akka.stream.impl.fusing.GraphStages.SingleSource } override def onPull(): Unit = externalCallback.invoke(RequestOne) - override def onDownstreamFinish(): Unit = externalCallback.invoke(Cancel) + override def onDownstreamFinish(cause: Throwable): Unit = externalCallback.invoke(Cancel(cause)) } override def toString: String = name diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index e3f712bf9b..f64683c846 100755 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -15,12 +15,12 @@ import akka.stream.impl.fusing.GraphStages import akka.stream.scaladsl.Partition.PartitionOutOfBoundsException import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } import akka.util.ConstantFun + import scala.annotation.tailrec import scala.annotation.unchecked.uncheckedVariance import scala.collection.{ immutable, mutable } import scala.concurrent.Promise import scala.util.control.{ NoStackTrace, NonFatal } - import akka.stream.ActorAttributes.SupervisionStrategy /** @@ -359,7 +359,7 @@ final class MergePrioritized[T] private (val priorities: Seq[Int], val eagerComp override def onUpstreamFinish(): Unit = { if (eagerComplete) { - in.foreach(cancel) + in.foreach(cancel(_)) runningUpstreams = 0 if (!hasPending) completeStage() } else { diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/RestartFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/RestartFlow.scala index 699d9baefd..8e874eab59 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/RestartFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/RestartFlow.scala @@ -303,9 +303,9 @@ private abstract class RestartWithBackoffLogic[S <: Shape]( * Can either be a failure or a cancel in the wrapped state. * onlyOnFailures is thus racy so a delay to cancellation is added in the case of a flow. */ - override def onDownstreamFinish() = { + override def onDownstreamFinish(cause: Throwable) = { if (finishing || maxRestartsReached() || onlyOnFailures) { - cancel(in) + cancel(in, cause) } else { scheduleRestartTimer() } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 4852aa0f0c..bbb3528f47 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -6,6 +6,7 @@ package akka.stream.stage import java.util.concurrent.atomic.AtomicReference +import scala.deprecated import akka.actor._ import akka.annotation.InternalApi import akka.japi.function.{ Effect, Procedure } @@ -18,12 +19,13 @@ import akka.stream.scaladsl.GenericGraphWithChangedAttributes import akka.util.OptionVal import akka.util.unused import akka.{ Done, NotUsed } + import scala.annotation.tailrec import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ Await, Future, Promise } - import akka.stream.impl.StreamSupervisor +import com.github.ghik.silencer.silent /** * Scala API: A GraphStage represents a reusable graph stream processing operator. @@ -170,7 +172,7 @@ object GraphStageLogic { */ object IgnoreTerminateOutput extends OutHandler { override def onPull(): Unit = () - override def onDownstreamFinish(): Unit = () + override def onDownstreamFinish(cause: Throwable): Unit = () override def toString = "IgnoreTerminateOutput" } @@ -180,8 +182,8 @@ object GraphStageLogic { */ class ConditionalTerminateOutput(predicate: () => Boolean) extends OutHandler { override def onPull(): Unit = () - override def onDownstreamFinish(): Unit = - if (predicate()) GraphInterpreter.currentInterpreter.activeStage.completeStage() + override def onDownstreamFinish(cause: Throwable): Unit = + if (predicate()) GraphInterpreter.currentInterpreter.activeStage.cancelStage(cause) } private object DoNothing extends (() => Unit) { @@ -533,8 +535,16 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Requests to stop receiving events from a given input port. Cancelling clears any ungrabbed elements from the port. + * + * If cancellation is due to an error, use `cancel(in, cause)` instead to propagate that cause upstream. This overload + * is a shortcut for `cancel(in, SubscriptionWithCancelException.NoMoreElementsNeeded)` */ - final protected def cancel[T](in: Inlet[T]): Unit = interpreter.cancel(conn(in)) + final protected def cancel[T](in: Inlet[T]): Unit = cancel(in, SubscriptionWithCancelException.NoMoreElementsNeeded) + + /** + * Requests to stop receiving events from a given input port. Cancelling clears any ungrabbed elements from the port. + */ + final protected def cancel[T](in: Inlet[T], cause: Throwable): Unit = interpreter.cancel(conn(in), cause) /** * Once the callback [[InHandler.onPush()]] for an input port has been invoked, the element that has been pushed @@ -547,18 +557,27 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: val connection = conn(in) val elem = connection.slot - // Fast path - if ((connection.portState & (InReady | InFailed)) == InReady && (elem.asInstanceOf[AnyRef] ne Empty)) { + // Fast path for active connections + if ((connection.portState & (InReady | InFailed | InClosed)) == InReady && (elem.asInstanceOf[AnyRef] ne Empty)) { connection.slot = Empty elem.asInstanceOf[T] } else { - // Slow path + // Slow path for grabbing element from already failed or completed connections if (!isAvailable(in)) throw new IllegalArgumentException(s"Cannot get element from already empty input port ($in)") - val failed = connection.slot.asInstanceOf[Failed] - val elem = failed.previousElem.asInstanceOf[T] - connection.slot = Failed(failed.ex, Empty) - elem + + if ((connection.portState & (InReady | InFailed)) == (InReady | InFailed)) { + // failed + val failed = connection.slot.asInstanceOf[Failed] + val elem = failed.previousElem.asInstanceOf[T] + connection.slot = Failed(failed.ex, Empty) + elem + } else { + // completed + val elem = connection.slot.asInstanceOf[T] + connection.slot = Empty + elem + } } } @@ -577,18 +596,21 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: final protected def isAvailable[T](in: Inlet[T]): Boolean = { val connection = conn(in) - val normalArrived = (conn(in).portState & (InReady | InFailed)) == InReady + val normalArrived = (conn(in).portState & (InReady | InFailed | InClosed)) == InReady - // Fast path + // Fast path for active connection if (normalArrived) connection.slot.asInstanceOf[AnyRef] ne Empty else { - // Slow path on failure - if ((connection.portState & (InReady | InFailed)) == (InReady | InFailed)) { + // slow path on failure, closure, and cancellation + if ((connection.portState & (InReady | InClosed | InFailed)) == (InReady | InClosed)) connection.slot match { - case Failed(_, elem) => elem.asInstanceOf[AnyRef] ne Empty - case _ => false // This can only be Empty actually (if a cancel was concurrent with a failure) - } - } else false + case Empty | _ @(_: Cancelled) => false // cancelled (element is discarded when cancelled) + case _ => true // completed but element still there to grab + } else if ((connection.portState & (InReady | InFailed)) == (InReady | InFailed)) + connection.slot match { + case Failed(_, elem) => elem.asInstanceOf[AnyRef] ne Empty // failed but element still there to grab + case _ => false + } else false } } @@ -655,11 +677,21 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Automatically invokes [[cancel()]] or [[complete()]] on all the input or output ports that have been called, * then marks the operator as stopped. */ - final def completeStage(): Unit = { + final def completeStage(): Unit = cancelStage(SubscriptionWithCancelException.StageWasCompleted) + + /** + * Automatically invokes [[cancel()]] or [[complete()]] on all the input or output ports that have been called, + * then marks the stage as stopped. + */ + final def cancelStage(cause: Throwable): Unit = { + // TODO: It's debatable if completing the stage if one output is cancelled is the right way to do things. + // At least optionally it might be more reasonable to fail the stage with the given cause. That + // would mean that all other *outputs* are failed, i.e. it would only concern stages with more that one + // output anyway. var i = 0 while (i < portToConn.length) { if (i < inCount) - interpreter.cancel(portToConn(i)) + interpreter.cancel(portToConn(i), cause) else handlers(i) match { case e: Emitting[_] => e.addFollowUp(new EmittingCompletion(e.out, e.previous)) @@ -678,7 +710,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: var i = 0 while (i < portToConn.length) { if (i < inCount) - interpreter.cancel(portToConn(i)) + interpreter.cancel(portToConn(i), ex) else interpreter.fail(portToConn(i), ex) i += 1 @@ -1004,7 +1036,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: ret } - override def onDownstreamFinish(): Unit = previous.onDownstreamFinish() + override def onDownstreamFinish(cause: Throwable): Unit = previous.onDownstreamFinish(cause) } private class EmittingSingle[T](_out: Outlet[T], elem: T, _previous: OutHandler, _andThen: () => Unit) @@ -1379,9 +1411,10 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: _sink.pullSubstream() } - def cancel(): Unit = { + def cancel(): Unit = cancel(SubscriptionWithCancelException.NoMoreElementsNeeded) + def cancel(cause: Throwable): Unit = { closed = true - _sink.cancelSubstream() + _sink.cancelSubstream(cause) } override def toString = s"SubSinkInlet($name)" @@ -1410,11 +1443,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: available = true handler.onPull() } - case SubSink.Cancel => + case SubSink.Cancel(cause) => if (!closed) { available = false closed = true - handler.onDownstreamFinish() + handler.onDownstreamFinish(cause) } } @@ -1772,14 +1805,34 @@ trait OutHandler { @throws(classOf[Exception]) def onPull(): Unit + // Hack to make sure that old `onDownstreamFinish` can be called without losing the cause in the default implementation + private[this] var _lastCancellationCause: Throwable = _ + /** * Called when the output port will no longer accept any new elements. After this callback no other callbacks will * be called for this port. */ @throws(classOf[Exception]) + // FIXME: add this after fixing our own usages, https://github.com/akka/akka/issues/27472 + // @deprecatedOverriding("Override `def onDownstreamFinish(cause: Throwable)`, instead.", since = "2.6.0") // warns when overriding + @deprecated("Call onDownstreamFinish with a cancellation cause.", since = "2.6.0") // warns when calling def onDownstreamFinish(): Unit = { - GraphInterpreter.currentInterpreter.activeStage.completeStage() + require(_lastCancellationCause ne null, "onDownstreamFinish() must not be called without a cancellation cause") + GraphInterpreter.currentInterpreter.activeStage.cancelStage(_lastCancellationCause) } + + /** + * Called when the output port will no longer accept any new elements. After this callback no other callbacks will + * be called for this port. + */ + @throws(classOf[Exception]) + def onDownstreamFinish(cause: Throwable): Unit = + try { + require(cause ne null, "Cancellation cause must not be null") + require(_lastCancellationCause eq null, "onDownstreamFinish(cause) must not be called recursively") + _lastCancellationCause = cause + (onDownstreamFinish(): @silent("deprecated")) // if not overridden, call old deprecated variant + } finally _lastCancellationCause = null } /**