diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala
index e0fa6e21ee..932eaf410b 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala
@@ -110,6 +110,19 @@ class FlowConcatAllSpec extends AkkaSpec {
       upstream.expectCancellation()
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val up = StreamTestKit.PublisherProbe[Source[Int, _]]()
+      val down = StreamTestKit.SubscriberProbe[Int]()
+
+      val flowSubscriber = Source.subscriber[Source[Int, _]].flatten(FlattenStrategy.concat).to(Sink(down)).run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      up.subscribe(flowSubscriber)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGroupBySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGroupBySpec.scala
index bcf5becfb4..207852e6fb 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGroupBySpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGroupBySpec.scala
@@ -267,6 +267,19 @@ class FlowGroupBySpec extends AkkaSpec {
       substreamPuppet2.expectComplete()
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val up = StreamTestKit.PublisherProbe[Int]()
+      val down = StreamTestKit.SubscriberProbe[(Int, Source[Int, Unit])]()
+
+      val flowSubscriber = Source.subscriber[Int].groupBy(_ % 2).to(Sink(down)).run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      up.subscribe(flowSubscriber)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala
index 63a5cb5e06..fff43f073b 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala
@@ -181,6 +181,19 @@ class FlowPrefixAndTailSpec extends AkkaSpec {
 
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val up = StreamTestKit.PublisherProbe[Int]()
+      val down = StreamTestKit.SubscriberProbe[(immutable.Seq[Int], Source[Int, _])]()
+
+      val flowSubscriber = Source.subscriber[Int].prefixAndTail(1).to(Sink(down)).run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      up.subscribe(flowSubscriber)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala
index d2f87a10f0..d9f561bd24 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala
@@ -192,6 +192,19 @@ class FlowSplitWhenSpec extends AkkaSpec {
       substreamPuppet2.expectComplete()
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val up = StreamTestKit.PublisherProbe[Int]()
+      val down = StreamTestKit.SubscriberProbe[Source[Int, Unit]]()
+
+      val flowSubscriber = Source.subscriber[Int].splitWhen(_ % 3 == 0).to(Sink(down)).run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      up.subscribe(flowSubscriber)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala
index 2981f5e578..d1804f2b2e 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala
@@ -394,6 +394,30 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug
       s2.expectNext(1, 2, 3)
       s2.expectComplete()
     }
+
+    "handle early cancelation" in assertAllStagesStopped {
+      val onDownstreamFinishProbe = TestProbe()
+      val down = StreamTestKit.SubscriberProbe[Int]()
+      val s = Source.subscriber[Int].
+        transform(() ⇒ new PushStage[Int, Int] {
+          override def onPush(elem: Int, ctx: Context[Int]) =
+            ctx.push(elem)
+          override def onDownstreamFinish(ctx: Context[Int]): TerminationDirective = {
+            onDownstreamFinishProbe.ref ! "onDownstreamFinish"
+            ctx.finish()
+          }
+        }).
+        to(Sink(down)).run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      onDownstreamFinishProbe.expectMsg("onDownstreamFinish")
+
+      val up = StreamTestKit.PublisherProbe[Int]
+      up.subscribe(s)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala
index fa6769a129..cb240404de 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala
@@ -146,8 +146,8 @@ class GraphBroadcastSpec extends AkkaSpec {
       FlowGraph.closed() { implicit b ⇒
         val bcast = b.add(Broadcast[Int](2))
         Source(List(1, 2, 3)) ~> bcast.in
-        bcast.out(0) ~> Flow[Int] ~> Sink(c1)
-        bcast.out(1) ~> Flow[Int] ~> Sink(c2)
+        bcast.out(0) ~> Flow[Int].named("identity-a") ~> Sink(c1)
+        bcast.out(1) ~> Flow[Int].named("identity-b") ~> Sink(c2)
       }.run()
 
       val sub1 = c1.expectSubscription()
@@ -189,6 +189,31 @@ class GraphBroadcastSpec extends AkkaSpec {
       bsub.expectCancellation()
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val c1 = StreamTestKit.SubscriberProbe[Int]()
+      val c2 = StreamTestKit.SubscriberProbe[Int]()
+
+      val sink = Sink() { implicit b ⇒
+        val bcast = b.add(Broadcast[Int](2))
+        bcast.out(0) ~> Sink(c1)
+        bcast.out(1) ~> Sink(c2)
+        bcast.in
+      }
+
+      val s = Source.subscriber[Int].to(sink).run()
+
+      val up = StreamTestKit.PublisherProbe[Int]()
+
+      val downsub1 = c1.expectSubscription()
+      val downsub2 = c2.expectSubscription()
+      downsub1.cancel()
+      downsub2.cancel()
+
+      up.subscribe(s)
+      val upsub = up.expectSubscription()
+      upsub.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala
index d3329f9358..1ded15508a 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala
@@ -144,6 +144,32 @@ class GraphMergeSpec extends TwoStreamsSetup {
       pending
     }
 
+    "pass along early cancellation" in assertAllStagesStopped {
+      val up1 = StreamTestKit.PublisherProbe[Int]
+      val up2 = StreamTestKit.PublisherProbe[Int]
+      val down = StreamTestKit.SubscriberProbe[Int]()
+
+      val src1 = Source.subscriber[Int]
+      val src2 = Source.subscriber[Int]
+
+      val (graphSubscriber1, graphSubscriber2) = FlowGraph.closed(src1, src2)((_, _)) { implicit b ⇒
+        (s1, s2) ⇒
+          val merge = b.add(Merge[Int](2))
+          s1.outlet ~> merge.in(0)
+          s2.outlet ~> merge.in(1)
+          merge.out ~> Sink(down)
+      }.run()
+
+      val downstream = down.expectSubscription()
+      downstream.cancel()
+      up1.subscribe(graphSubscriber1)
+      up2.subscribe(graphSubscriber2)
+      val upsub1 = up1.expectSubscription()
+      upsub1.expectCancellation()
+      val upsub2 = up2.expectSubscription()
+      upsub2.expectCancellation()
+    }
+
   }
 }
diff --git a/akka-stream/src/main/boilerplate/akka/stream/impl/ZipWith.scala.template b/akka-stream/src/main/boilerplate/akka/stream/impl/ZipWith.scala.template
index ae5ba58a89..b8fc479178 100644
--- a/akka-stream/src/main/boilerplate/akka/stream/impl/ZipWith.scala.template
+++ b/akka-stream/src/main/boilerplate/akka/stream/impl/ZipWith.scala.template
@@ -14,7 +14,7 @@ private[akka] final class Zip1With(_settings: ActorFlowMaterializerSettings, f:
 
   inputBunch.markAllInputs()
 
-  nextPhase(TransferPhase(inputBunch.AllOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
+  initialPhase(inputCount, TransferPhase(inputBunch.AllOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
     val elem##0 = inputBunch.dequeue(##0)
     [2..#val elem0 = inputBunch.dequeue(0)#
     ]
diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala
index 87f5051ea9..34ed5dee75 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala
@@ -206,19 +206,21 @@ private[akka] case class ActorFlowMaterializerImpl(
     actorOf(props, context.stageName, dispatcher)
   }
 
-  private[akka] def actorOf(props: Props, name: String, dispatcher: String): ActorRef = supervisor match {
-    case ref: LocalActorRef ⇒
-      ref.underlying.attachChild(props.withDispatcher(dispatcher), name, systemService = false)
-    case ref: RepointableActorRef ⇒
-      if (ref.isStarted)
-        ref.underlying.asInstanceOf[ActorCell].attachChild(props.withDispatcher(dispatcher), name, systemService = false)
-      else {
-        implicit val timeout = ref.system.settings.CreationTimeout
-        val f = (supervisor ? StreamSupervisor.Materialize(props.withDispatcher(dispatcher), name)).mapTo[ActorRef]
-        Await.result(f, timeout.duration)
-      }
-    case unknown ⇒
-      throw new IllegalStateException(s"Stream supervisor must be a local actor, was [${unknown.getClass.getName}]")
+  private[akka] def actorOf(props: Props, name: String, dispatcher: String): ActorRef = {
+    supervisor match {
+      case ref: LocalActorRef ⇒
+        ref.underlying.attachChild(props.withDispatcher(dispatcher), name, systemService = false)
+      case ref: RepointableActorRef ⇒
+        if (ref.isStarted)
+          ref.underlying.asInstanceOf[ActorCell].attachChild(props.withDispatcher(dispatcher), name, systemService = false)
+        else {
+          implicit val timeout = ref.system.settings.CreationTimeout
+          val f = (supervisor ? StreamSupervisor.Materialize(props.withDispatcher(dispatcher), name)).mapTo[ActorRef]
+          Await.result(f, timeout.duration)
+        }
+      case unknown ⇒
+        throw new IllegalStateException(s"Stream supervisor must be a local actor, was [${unknown.getClass.getName}]")
+    }
   }
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala
index 2de23fb189..acc4158d88 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala
@@ -123,6 +123,7 @@ private[akka] abstract class BatchingInputBuffer(val size: Int, val pump: Pump)
       upstream.request(inputBuffer.length)
       subreceive.become(upstreamRunning)
     }
+    pump.gotUpstreamSubscription()
   }
 
   protected def onError(e: Throwable): Unit = {
diff --git a/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala
index b31480ae88..b397659bd9 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala
@@ -37,7 +37,7 @@ private[akka] class ConcatAllImpl(materializer: ActorFlowMaterializer)
     else primaryOutputs.enqueueOutputElement(substream.dequeueInputElement())
   }
 
-  nextPhase(takeNextSubstream)
+  initialPhase(1, takeNextSubstream)
 
   override def invalidateSubstreamInput(substream: SubstreamKey, e: Throwable): Unit = fail(e)
diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala
index 2b1bbe348c..1e3cae43d7 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala
@@ -274,7 +274,7 @@ private[akka] object FairMerge {
 private[akka] final class FairMerge(_settings: ActorFlowMaterializerSettings, _inputPorts: Int) extends FanIn(_settings, _inputPorts) {
   inputBunch.markAllInputs()
 
-  nextPhase(TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
+  initialPhase(inputCount, TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
     val elem = inputBunch.dequeueAndYield()
     primaryOutputs.enqueueOutputElement(elem)
   })
@@ -299,7 +299,7 @@ private[akka] final class UnfairMerge(_settings: ActorFlowMaterializerSettings,
   val preferred: Int) extends FanIn(_settings, _inputPorts) {
   inputBunch.markAllInputs()
 
-  nextPhase(TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
+  initialPhase(inputCount, TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒
     val elem = inputBunch.dequeuePrefering(preferred)
     primaryOutputs.enqueueOutputElement(elem)
   })
@@ -341,5 +341,5 @@ private[akka] final class Concat(_settings: ActorFlowMaterializerSettings) exten
     primaryOutputs.enqueueOutputElement(elem)
   }
 
-  nextPhase(drainFirst)
+  initialPhase(inputCount, drainFirst)
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala
index 759a15c267..81de929af2 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala
@@ -272,7 +272,7 @@ private[akka] abstract class FanOut(val settings: ActorFlowMaterializerSettings,
     log.debug("fail due to: {}", e.getMessage)
     primaryInputs.cancel()
     outputBunch.cancel(e)
-    context.stop(self)
+    pump()
   }
 
   override def postStop(): Unit = {
@@ -302,7 +302,7 @@ private[akka] object Broadcast {
 private[akka] class Broadcast(_settings: ActorFlowMaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) {
   outputBunch.markAllOutputs()
 
-  nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
+  initialPhase(1, TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
     val elem = primaryInputs.dequeueInputElement()
     outputBunch.enqueueMarked(elem)
   })
@@ -328,11 +328,11 @@ private[akka] class Balance(_settings: ActorFlowMaterializerSettings, _outputPor
   }
 
   if (waitForAllDownstreams)
-    nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
+    initialPhase(1, TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
      nextPhase(runningPhase)
    })
  else
-    nextPhase(runningPhase)
+    initialPhase(1, runningPhase)
 }
 
 /**
@@ -349,7 +349,7 @@ private[akka] object Unzip {
 private[akka] class Unzip(_settings: ActorFlowMaterializerSettings) extends FanOut(_settings, outputCount = 2) {
   outputBunch.markAllOutputs()
 
-  nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
+  initialPhase(1, TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒
     primaryInputs.dequeueInputElement() match {
       case (a, b) ⇒
         outputBunch.enqueue(0, a)
diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanoutProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/FanoutProcessor.scala
index fb3ac2c8a2..95a4bc4468 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/FanoutProcessor.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/FanoutProcessor.scala
@@ -120,5 +120,5 @@ private[akka] class FanoutProcessorImpl(
 
   def afterFlush(): Unit = context.stop(self)
 
-  nextPhase(running)
+  initialPhase(1, running)
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala
index 49c8469ffa..063f70aaa4 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala
@@ -117,7 +117,7 @@ private[akka] class FlexiMergeImpl[T, S <: Shape](
   changeBehavior(mergeLogic.initialState)
   changeCompletionHandling(mergeLogic.initialCompletionHandling)
 
-  nextPhase(TransferPhase(precondition) { () ⇒
+  initialPhase(inputCount, TransferPhase(precondition) { () ⇒
     behavior.condition match {
       case read: ReadAny[t] ⇒
        suppressCompletion()
diff --git a/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala
index a41d848c9d..2c6a787b52 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala
@@ -122,7 +122,7 @@ private[akka] class FlexiRouteImpl[T, S <: Shape](_settings: ActorFlowMaterializ
   changeBehavior(routeLogic.initialState)
   changeCompletionHandling(routeLogic.initialCompletionHandling)
 
-  nextPhase(TransferPhase(precondition) { () ⇒
+  initialPhase(1, TransferPhase(precondition) { () ⇒
     val elem = primaryInputs.dequeueInputElement().asInstanceOf[T]
     behavior.condition match {
       case any: DemandFromAny ⇒
diff --git a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala
index 1a44b1d8f5..4a910f1637 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala
@@ -86,7 +86,7 @@ private[akka] class GroupByProcessorImpl(settings: ActorFlowMaterializerSettings
     }
   }
 
-  nextPhase(waitFirst)
+  initialPhase(1, waitFirst)
 
   override def invalidateSubstreamOutput(substream: SubstreamKey): Unit = {
     if ((pendingSubstreamOutput ne null) && substream == pendingSubstreamOutput.key) {
diff --git a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala
index 8028bebf29..dc70e5e092 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala
@@ -62,5 +62,5 @@ private[akka] class PrefixAndTailImpl(_settings: ActorFlowMaterializerSettings,
     nextPhase(streamTailPhase(substreamOutput))
   }
 
-  if (takeMax > 0) nextPhase(take) else nextPhase(takeEmpty)
+  if (takeMax > 0) initialPhase(1, take) else initialPhase(1, takeEmpty)
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala
index 6057147e60..c159c28a9a 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala
@@ -82,7 +82,7 @@ private[akka] class SplitWhenProcessorImpl(_settings: ActorFlowMaterializerSetti
       Drop
   }
 
-  nextPhase(waitFirst)
+  initialPhase(1, waitFirst)
 
   override def completeSubstreamOutput(substream: SubstreamKey): Unit = {
     if ((currentSubstream ne null) && substream == currentSubstream.key) nextPhase(ignoreUntilNewSubstream)
diff --git a/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala
index 44a3e69fc6..a9a4899d1f 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala
@@ -27,7 +27,7 @@ private[akka] class TimerTransformerProcessorsImpl(
 
   override def preStart(): Unit = {
     super.preStart()
-    nextPhase(running)
+    initialPhase(1, running)
     transformer.start(context)
   }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala
index cf278bfc50..d43ac5e877 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala
@@ -125,6 +125,14 @@ private[akka] object NotInitialized extends TransferState {
   def isCompleted = false
 }
 
+/**
+ * INTERNAL API
+ */
+private[akka] case class WaitingForUpstreamSubscription(remaining: Int, andThen: TransferPhase) extends TransferState {
+  def isReady = false
+  def isCompleted = false
+}
+
 /**
  * INTERNAL API
  */
@@ -146,9 +154,31 @@ private[akka] trait Pump {
   private var currentAction: () ⇒ Unit =
     () ⇒ throw new IllegalStateException("Pump has been not initialized with a phase")
 
-  final def nextPhase(phase: TransferPhase): Unit = {
-    transferState = phase.precondition
-    currentAction = phase.action
+  final def initialPhase(waitForUpstream: Int, andThen: TransferPhase): Unit = {
+    require(waitForUpstream >= 1, s"waitForUpstream must be >= 1 (was $waitForUpstream)")
+    if (transferState != NotInitialized)
+      throw new IllegalStateException(s"initialPhase expected NotInitialized, but was [$transferState]")
+    transferState = WaitingForUpstreamSubscription(waitForUpstream, andThen)
+  }
+
+  def gotUpstreamSubscription(): Unit = {
+    transferState match {
+      case WaitingForUpstreamSubscription(1, andThen) ⇒
+        transferState = andThen.precondition
+        currentAction = andThen.action
+      case WaitingForUpstreamSubscription(remaining, andThen) ⇒
+        transferState = WaitingForUpstreamSubscription(remaining - 1, andThen)
+      case _ ⇒ // ok, initial phase not used, or passed already
+    }
+    pump()
+  }
+
+  final def nextPhase(phase: TransferPhase): Unit = transferState match {
+    case WaitingForUpstreamSubscription(remaining, _) ⇒
+      transferState = WaitingForUpstreamSubscription(remaining, phase)
+    case _ ⇒
+      transferState = phase.precondition
+      currentAction = phase.action
   }
 
   final def isPumpFinished: Boolean = transferState.isCompleted
diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala
index a6852715a8..9e29b12400 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala
@@ -35,6 +35,7 @@ private[akka] class BatchingActorInputBoundary(val size: Int, val name: String)
   private var nextInputElementCursor = 0
   private var upstreamCompleted = false
   private var downstreamWaiting = false
+  private var downstreamCanceled = false
   private val IndexMask = size - 1
 
   private def requestBatchSize = math.max(1, inputBuffer.length / 2)
@@ -42,7 +43,9 @@ private[akka] class BatchingActorInputBoundary(val size: Int, val name: String)
 
   val subreceive: SubReceive = new SubReceive(waitingForUpstream)
 
-  def isFinished = (upstream ne null) && upstreamCompleted
+  def isFinished = upstreamCompleted && ((upstream ne null) || downstreamCanceled)
+
+  def setDownstreamCanceled(): Unit = downstreamCanceled = true
 
   private def dequeue(): Any = {
     val elem = inputBuffer(nextInputElementCursor)
@@ -113,8 +116,12 @@ private[akka] class BatchingActorInputBoundary(val size: Int, val name: String)
 
   private def onSubscribe(subscription: Subscription): Unit = {
     assert(subscription != null)
-    if (upstreamCompleted) subscription.cancel()
-    else {
+    if (upstreamCompleted)
+      subscription.cancel()
+    else if (downstreamCanceled) {
+      upstreamCompleted = true
+      subscription.cancel()
+    } else {
       upstream = subscription
       // Prefetch
       upstream.request(inputBuffer.length)
@@ -326,7 +333,11 @@ private[akka] class ActorInterpreter(val settings: ActorFlowMaterializerSettings
 
   override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = {
     super.aroundReceive(receive, msg)
-    if (interpreter.isFinished && upstream.isFinished) context.stop(self)
+
+    if (interpreter.isFinished) {
+      if (upstream.isFinished) context.stop(self)
+      else upstream.setDownstreamCanceled()
+    }
   }
 
   override def postStop(): Unit = {