diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala index 9a005bef44..8c806e17ab 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala @@ -239,11 +239,11 @@ class ActorGraphInterpreterSpec extends AkkaSpec { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - setHandler(shape.outlet, new OutHandler { + setHandler(shape.out, new OutHandler { override def onPull(): Unit = { completeStage() // This cannot be propagated now since the stage is already closed - push(shape.outlet, -1) + push(shape.out, -1) } }) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala index cddc73939f..c7c37eaa28 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala @@ -8,7 +8,7 @@ import scala.concurrent.{ Future, Await } import scala.concurrent.duration._ import scala.util.Try import scala.util.control.NoStackTrace -import akka.stream.{ Attributes, ActorMaterializer, ActorMaterializerSettings } +import akka.stream._ import org.reactivestreams.Subscriber import akka.stream.testkit._ import akka.stream.testkit.Utils._ @@ -90,6 +90,48 @@ class FlowPrefixAndTailSpec extends AkkaSpec { subscriber.expectSubscriptionAndComplete() } + "throw if tail is attempted to be materialized twice" in assertAllStagesStopped { + val futureSink = newHeadSink + val fut = Source(1 to 2).prefixAndTail(1).runWith(futureSink) + val (takes, tail) = Await.result(fut, 3.seconds) + takes should be(Seq(1)) + + val subscriber1 = TestSubscriber.probe[Int]() + tail.to(Sink(subscriber1)).run() + + val subscriber2 = TestSubscriber.probe[Int]() + tail.to(Sink(subscriber2)).run() + subscriber2.expectSubscriptionAndError().getMessage should ===("Tail Source cannot be materialized more than once.") + + subscriber1.requestNext(2).expectComplete() + + } + + "signal error if substream has been not subscribed in time" in assertAllStagesStopped { + val tightTimeoutMaterializer = + ActorMaterializer(ActorMaterializerSettings(system) + .withSubscriptionTimeoutSettings( + StreamSubscriptionTimeoutSettings(StreamSubscriptionTimeoutTerminationMode.cancel, 500.millisecond))) + + val futureSink = newHeadSink + val fut = Source(1 to 2).prefixAndTail(1).runWith(futureSink)(tightTimeoutMaterializer) + val (takes, tail) = Await.result(fut, 3.seconds) + takes should be(Seq(1)) + + val subscriber = TestSubscriber.probe[Int]() + Thread.sleep(1000) + + tail.to(Sink(subscriber)).run()(tightTimeoutMaterializer) + subscriber.expectSubscriptionAndError().getMessage should ===("Tail Source has not been materialized in 500 milliseconds.") + } + + "shut down main stage if substream is empty, even when not subscribed" in assertAllStagesStopped { + val futureSink = newHeadSink + val fut = Source.single(1).prefixAndTail(1).runWith(futureSink) + val (takes, tail) = Await.result(fut, 3.seconds) + takes should be(Seq(1)) + } + "handle onError when no substream open" in assertAllStagesStopped { val publisher = TestPublisher.manualProbe[Int]() val subscriber = TestSubscriber.manualProbe[(immutable.Seq[Int], Source[Int, _])]() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala index 559ae5b3ee..600eb758a8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala @@ -110,7 +110,7 @@ class GraphMatValueSpec extends AkkaSpec { val foldFlow: Flow[Int, Int, Future[Int]] = Flow.fromGraph(GraphDSL.create(Sink.fold[Int, Int](0)(_ + _)) { implicit builder ⇒ fold ⇒ - FlowShape(fold.inlet, builder.materializedValue.mapAsync(4)(identity).outlet) + FlowShape(fold.in, builder.materializedValue.mapAsync(4)(identity).outlet) }) Await.result(Source(1 to 10).via(foldFlow).runWith(Sink.head), 3.seconds) should ===(55) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala index d1f36de2c9..6771dbde22 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala @@ -38,22 +38,20 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { implicit val dispatcher = system.dispatcher implicit val materializer = ActorMaterializer(settings) - "groupBy" must { + "groupBy and splitwhen" must { "timeout and cancel substream publishers when no-one subscribes to them after some time (time them out)" in assertAllStagesStopped { - val publisherProbe = TestPublisher.manualProbe[Int]() + val publisherProbe = TestPublisher.probe[Int]() val publisher = Source(publisherProbe).groupBy(3, _ % 3).lift(_ % 3).runWith(Sink.publisher(false)) val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]() publisher.subscribe(subscriber) - val upstreamSubscription = publisherProbe.expectSubscription() - val downstreamSubscription = subscriber.expectSubscription() downstreamSubscription.request(100) - upstreamSubscription.sendNext(1) - upstreamSubscription.sendNext(2) - upstreamSubscription.sendNext(3) + publisherProbe.sendNext(1) + publisherProbe.sendNext(2) + publisherProbe.sendNext(3) val (_, s1) = subscriber.expectNext() // should not break normal usage @@ -79,42 +77,38 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { val f = s3.runWith(Sink.head).recover { case _: SubscriptionTimeoutException ⇒ "expected" } Await.result(f, 300.millis) should equal("expected") - upstreamSubscription.sendComplete() + publisherProbe.sendComplete() } "timeout and stop groupBy parent actor if none of the substreams are actually consumed" in assertAllStagesStopped { - val publisherProbe = TestPublisher.manualProbe[Int]() + val publisherProbe = TestPublisher.probe[Int]() val publisher = Source(publisherProbe).groupBy(2, _ % 2).lift(_ % 2).runWith(Sink.publisher(false)) val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]() publisher.subscribe(subscriber) - val upstreamSubscription = publisherProbe.expectSubscription() - val downstreamSubscription = subscriber.expectSubscription() downstreamSubscription.request(100) - upstreamSubscription.sendNext(1) - upstreamSubscription.sendNext(2) - upstreamSubscription.sendNext(3) - upstreamSubscription.sendComplete() + publisherProbe.sendNext(1) + publisherProbe.sendNext(2) + publisherProbe.sendNext(3) + publisherProbe.sendComplete() val (_, s1) = subscriber.expectNext() val (_, s2) = subscriber.expectNext() } "not timeout and cancel substream publishers when they have been subscribed to" in { - val publisherProbe = TestPublisher.manualProbe[Int]() + val publisherProbe = TestPublisher.probe[Int]() val publisher = Source(publisherProbe).groupBy(2, _ % 2).lift(_ % 2).runWith(Sink.publisher(false)) val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]() publisher.subscribe(subscriber) - val upstreamSubscription = publisherProbe.expectSubscription() - val downstreamSubscription = subscriber.expectSubscription() downstreamSubscription.request(100) - upstreamSubscription.sendNext(1) - upstreamSubscription.sendNext(2) + publisherProbe.sendNext(1) + publisherProbe.sendNext(2) val (_, s1) = subscriber.expectNext() // should not break normal usage @@ -136,8 +130,8 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { s2Sub.request(100) s2SubscriberProbe.expectNext(2) s1Sub.request(100) - upstreamSubscription.sendNext(3) - upstreamSubscription.sendNext(4) + publisherProbe.sendNext(3) + publisherProbe.sendNext(4) s1SubscriberProbe.expectNext(3) s2SubscriberProbe.expectNext(4) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index c3f8e047a5..dee5691de6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -277,7 +277,6 @@ private[akka] object ActorProcessorFactory { val settings = materializer.effectiveSettings(att) op match { case GroupBy(maxSubstreams, f, _) ⇒ (GroupByProcessorImpl.props(settings, maxSubstreams, f), ()) - case PrefixAndTail(n, _) ⇒ (PrefixAndTailImpl.props(settings, n), ()) case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) case DirectProcessor(p, m) ⇒ throw new AssertionError("DirectProcessor cannot end up in ActorProcessorFactory") } diff --git a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala deleted file mode 100644 index 553edc5e03..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl - -import scala.collection.immutable -import akka.stream.ActorMaterializerSettings -import akka.stream.scaladsl.Source -import akka.actor.{ Deploy, Props } - -/** - * INTERNAL API - */ -private[akka] object PrefixAndTailImpl { - def props(settings: ActorMaterializerSettings, takeMax: Int): Props = - Props(new PrefixAndTailImpl(settings, takeMax)).withDeploy(Deploy.local) -} - -/** - * INTERNAL API - */ -private[akka] class PrefixAndTailImpl(_settings: ActorMaterializerSettings, val takeMax: Int) - extends MultiStreamOutputProcessor(_settings) { - - import MultiStreamOutputProcessor._ - - var taken = immutable.Vector.empty[Any] - var left = takeMax - - val take = TransferPhase(primaryInputs.NeedsInputOrComplete && primaryOutputs.NeedsDemand) { () ⇒ - if (primaryInputs.inputsDepleted) emitEmptyTail() - else { - val elem = primaryInputs.dequeueInputElement() - taken :+= elem - left -= 1 - if (left <= 0) { - if (primaryInputs.inputsDepleted) emitEmptyTail() - else emitNonEmptyTail() - } - } - } - - def streamTailPhase(substream: SubstreamOutput) = TransferPhase(primaryInputs.NeedsInput && substream.NeedsDemand) { () ⇒ - substream.enqueueOutputElement(primaryInputs.dequeueInputElement()) - } - - val takeEmpty = TransferPhase(primaryOutputs.NeedsDemand) { () ⇒ - if (primaryInputs.inputsDepleted) emitEmptyTail() - else emitNonEmptyTail() - } - - def emitEmptyTail(): Unit = { - primaryOutputs.enqueueOutputElement((taken, Source.empty)) - nextPhase(completedPhase) - } - - def emitNonEmptyTail(): Unit = { - val substreamOutput = createSubstreamOutput() - val substreamFlow = Source(substreamOutput) // substreamOutput is a Publisher - primaryOutputs.enqueueOutputElement((taken, substreamFlow)) - primaryOutputs.complete() - nextPhase(streamTailPhase(substreamOutput)) - } - - if (takeMax > 0) initialPhase(1, take) else initialPhase(1, takeEmpty) -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index ed5e79fc9d..e5875c9878 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -208,10 +208,6 @@ private[stream] object Stages { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } - final case class PrefixAndTail(n: Int, attributes: Attributes = prefixAndTail) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - final case class Split(p: Any ⇒ SplitDecision, attributes: Attributes = split) extends StageModule { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index dc2c7900ed..7de5f225ee 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -214,6 +214,7 @@ private[stream] object ActorGraphInterpreter { private var downstreamCompleted = false // when upstream failed before we got the exposed publisher private var upstreamFailed: Option[Throwable] = None + private var upstreamCompleted: Boolean = false private def onNext(elem: Any): Unit = { downstreamDemand -= 1 @@ -221,21 +222,21 @@ private[stream] object ActorGraphInterpreter { } private def complete(): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true + // No need to complete if had already been cancelled, or we closed earlier + if (!(upstreamCompleted || downstreamCompleted)) { + upstreamCompleted = true if (exposedPublisher ne null) exposedPublisher.shutdown(None) if (subscriber ne null) tryOnComplete(subscriber) } } def fail(e: Throwable): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true + // No need to fail if had already been cancelled, or we closed earlier + if (!(downstreamCompleted || upstreamCompleted)) { + upstreamCompleted = true + upstreamFailed = Some(e) if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e)) if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e) - } else if (exposedPublisher == null && upstreamFailed.isEmpty) { - // fail called before the exposed publisher arrived, we must store it and fail when we're first able to - upstreamFailed = Some(e) } } @@ -258,6 +259,7 @@ private[stream] object ActorGraphInterpreter { if (subscriber eq null) { subscriber = sub tryOnSubscribe(subscriber, new BoundarySubscription(actor, id)) + if (GraphInterpreter.Debug) println(s"${interpreter.Name} subscribe subscriber=$sub") } else rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}") } @@ -267,7 +269,8 @@ private[stream] object ActorGraphInterpreter { case _: Some[_] ⇒ publisher.shutdown(upstreamFailed) case _ ⇒ - exposedPublisher = publisher + if (upstreamCompleted) publisher.shutdown(None) + else exposedPublisher = publisher } } @@ -322,6 +325,7 @@ private[stream] class ActorGraphInterpreter( private val outputs = Array.tabulate(shape.outlets.size)(new ActorOutputBoundary(self, _)) private var subscribesPending = inputs.length + private var publishersPending = outputs.length /* * Limits the number of events processed by the interpreter before scheduling @@ -398,21 +402,28 @@ private[stream] class ActorGraphInterpreter( case SubscribePending(id: Int) ⇒ outputs(id).subscribePending() case ExposedPublisher(id, publisher) ⇒ + publishersPending -= 1 outputs(id).exposedPublisher(publisher) } private def waitShutdown: Receive = { + case ExposedPublisher(id, publisher) ⇒ + outputs(id).exposedPublisher(publisher) + publishersPending -= 1 + if (canShutDown) context.stop(self) case OnSubscribe(_, sub) ⇒ tryCancel(sub) subscribesPending -= 1 - if (subscribesPending == 0) context.stop(self) + if (canShutDown) context.stop(self) case ReceiveTimeout ⇒ tryAbort(new TimeoutException("Streaming actor has been already stopped processing (normally), but not all of its " + - s"inputs have been subscribed in [${settings.subscriptionTimeoutSettings.timeout}}]. Aborting actor now.")) + s"inputs or outputs have been subscribed in [${settings.subscriptionTimeoutSettings.timeout}}]. Aborting actor now.")) case _ ⇒ // Ignore, there is nothing to do anyway } + private def canShutDown: Boolean = subscribesPending + publishersPending == 0 + private def runBatch(): Unit = { try { val effectiveLimit = { @@ -425,7 +436,7 @@ private[stream] class ActorGraphInterpreter( interpreter.execute(effectiveLimit) if (interpreter.isCompleted) { // Cannot stop right away if not completely subscribed - if (subscribesPending == 0) context.stop(self) + if (canShutDown) context.stop(self) else { context.become(waitShutdown) context.setReceiveTimeout(settings.subscriptionTimeoutSettings.timeout) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala index a182cb4fe0..bf03be2c71 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala @@ -3,7 +3,10 @@ */ package akka.stream.impl.fusing +import java.util.concurrent.atomic.AtomicReference + import akka.stream._ +import akka.stream.impl.SubscriptionTimeoutException import akka.stream.stage._ import akka.stream.scaladsl._ import akka.stream.actor.ActorSubscriberMessage @@ -11,8 +14,13 @@ import akka.stream.actor.ActorSubscriberMessage._ import akka.stream.actor.ActorPublisherMessage import akka.stream.actor.ActorPublisherMessage._ import java.{ util ⇒ ju } +import scala.collection.immutable import scala.concurrent._ +import scala.concurrent.duration.FiniteDuration +/** + * INTERNAL API + */ final class FlattenMerge[T, M](breadth: Int) extends GraphStage[FlowShape[Graph[SourceShape[T], M], T]] { private val in = Inlet[Graph[SourceShape[T], M]]("flatten.in") private val out = Outlet[T]("flatten.out") @@ -218,3 +226,191 @@ private[fusing] object StreamOfStreams { } } } + +/** + * INTERNAL API + */ +object PrefixAndTail { + + sealed trait MaterializationState + case object NotMaterialized extends MaterializationState + case object AlreadyMaterialized extends MaterializationState + case object TimedOut extends MaterializationState + + case object NormalCompletion extends MaterializationState + case class FailureCompletion(ex: Throwable) extends MaterializationState + + trait TailInterface[T] { + def pushSubstream(elem: T): Unit + def completeSubstream(): Unit + def failSubstream(ex: Throwable) + } + + final class TailSource[T]( + timeout: FiniteDuration, + register: TailInterface[T] ⇒ Unit, + pullParent: Unit ⇒ Unit, + cancelParent: Unit ⇒ Unit) extends GraphStage[SourceShape[T]] { + val out: Outlet[T] = Outlet("Tail.out") + val materializationState = new AtomicReference[MaterializationState](NotMaterialized) + override val shape: SourceShape[T] = SourceShape(out) + + private final class TailSourceLogic(_shape: Shape) extends GraphStageLogic(_shape) with OutHandler with TailInterface[T] { + setHandler(out, this) + + override def preStart(): Unit = { + materializationState.getAndSet(AlreadyMaterialized) match { + case AlreadyMaterialized ⇒ + failStage(new IllegalStateException("Tail Source cannot be materialized more than once.")) + case TimedOut ⇒ + // Already detached from parent + failStage(new SubscriptionTimeoutException(s"Tail Source has not been materialized in $timeout.")) + case NormalCompletion ⇒ + // Already detached from parent + completeStage() + case FailureCompletion(ex) ⇒ + // Already detached from parent + failStage(ex) + case NotMaterialized ⇒ + register(this) + } + + } + + private val onParentPush = getAsyncCallback[T](push(out, _)) + private val onParentFinish = getAsyncCallback[Unit](_ ⇒ completeStage()) + private val onParentFailure = getAsyncCallback[Throwable](failStage) + + override def pushSubstream(elem: T): Unit = onParentPush.invoke(elem) + override def completeSubstream(): Unit = onParentFinish.invoke(()) + override def failSubstream(ex: Throwable): Unit = onParentFailure.invoke(ex) + + override def onPull(): Unit = pullParent() + override def onDownstreamFinish(): Unit = cancelParent() + } + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TailSourceLogic(shape) + } + +} + +/** + * INTERNAL API + */ +final class PrefixAndTail[T](n: Int) extends GraphStage[FlowShape[T, (immutable.Seq[T], Source[T, Unit])]] { + val in: Inlet[T] = Inlet("PrefixAndTail.in") + val out: Outlet[(immutable.Seq[T], Source[T, Unit])] = Outlet("PrefixAndTail.out") + override val shape: FlowShape[T, (immutable.Seq[T], Source[T, Unit])] = FlowShape(in, out) + + override def initialAttributes = Attributes.name("PrefixAndTail") + + private final class PrefixAndTailLogic(_shape: Shape) extends TimerGraphStageLogic(_shape) with OutHandler with InHandler { + import PrefixAndTail._ + + private var left = if (n < 0) 0 else n + private var builder = Vector.newBuilder[T] + private var tailSource: TailSource[T] = null + private var tail: TailInterface[T] = null + builder.sizeHint(left) + private var pendingCompletion: MaterializationState = null + + private val SubscriptionTimer = "SubstreamSubscriptionTimer" + + private val onSubstreamPull = getAsyncCallback[Unit](_ ⇒ pull(in)) + private val onSubstreamFinish = getAsyncCallback[Unit](_ ⇒ completeStage()) + private val onSubstreamRegister = getAsyncCallback[TailInterface[T]] { tailIf ⇒ + tail = tailIf + cancelTimer(SubscriptionTimer) + pendingCompletion match { + case NormalCompletion ⇒ + tail.completeSubstream() + completeStage() + case FailureCompletion(ex) ⇒ + tail.failSubstream(ex) + completeStage() + case _ ⇒ + } + } + + override protected def onTimer(timerKey: Any): Unit = + if (tailSource.materializationState.compareAndSet(NotMaterialized, TimedOut)) completeStage() + + private def prefixComplete = builder eq null + private def waitingSubstreamRegistration = tail eq null + + private def openSubstream(): Source[T, Unit] = { + val timeout = ActorMaterializer.downcast(interpreter.materializer).settings.subscriptionTimeoutSettings.timeout + tailSource = new TailSource[T](timeout, onSubstreamRegister.invoke, onSubstreamPull.invoke, onSubstreamFinish.invoke) + scheduleOnce(SubscriptionTimer, timeout) + builder = null + Source.fromGraph(tailSource) + } + + // Needs to keep alive if upstream completes but substream has been not yet materialized + override def keepGoingAfterAllPortsClosed: Boolean = true + + override def onPush(): Unit = { + if (prefixComplete) { + tail.pushSubstream(grab(in)) + } else { + builder += grab(in) + left -= 1 + if (left == 0) { + push(out, (builder.result(), openSubstream())) + complete(out) + } else pull(in) + } + } + override def onPull(): Unit = { + if (left == 0) { + push(out, (Nil, openSubstream())) + complete(out) + } else pull(in) + } + + override def onUpstreamFinish(): Unit = { + if (!prefixComplete) { + // This handles the unpulled out case as well + emit(out, (builder.result, Source.empty), () ⇒ completeStage()) + } else { + if (waitingSubstreamRegistration) { + // Detach if possible. + // This allows this stage to complete without waiting for the substream to be materialized, since that + // is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be + // able to signal completion normally soon. + if (tailSource.materializationState.compareAndSet(NotMaterialized, NormalCompletion)) completeStage() + else pendingCompletion = NormalCompletion + } else { + tail.completeSubstream() + completeStage() + } + } + } + + override def onUpstreamFailure(ex: Throwable): Unit = { + if (prefixComplete) { + if (waitingSubstreamRegistration) { + // Detach if possible. + // This allows this stage to complete without waiting for the substream to be materialized, since that + // is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be + // able to signal completion normally soon. + if (tailSource.materializationState.compareAndSet(NotMaterialized, FailureCompletion(ex))) failStage(ex) + else pendingCompletion = FailureCompletion(ex) + } else { + tail.failSubstream(ex) + completeStage() + } + } else failStage(ex) + } + + override def onDownstreamFinish(): Unit = { + if (!prefixComplete) completeStage() + // Otherwise substream is open, ignore + } + + setHandler(in, this) + setHandler(out, this) + } + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new PrefixAndTailLogic(shape) +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index abe38b1c5e..e76e34c6d1 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -910,7 +910,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels or substream cancels */ def prefixAndTail[U >: Out](n: Int): Repr[(immutable.Seq[Out], Source[U, Unit])] = - deprecatedAndThen(PrefixAndTail(n)) + via(new PrefixAndTail[Out](n)) /** * This operation demultiplexes the incoming stream into separate output