diff --git a/akka-docs/src/main/paradox/stream/stages-overview.md b/akka-docs/src/main/paradox/stream/stages-overview.md index 722424625a..f798ce4e7c 100644 --- a/akka-docs/src/main/paradox/stream/stages-overview.md +++ b/akka-docs/src/main/paradox/stream/stages-overview.md @@ -1085,6 +1085,8 @@ Each upstream element will either be diverted to the given sink, or the downstre **completes** when upstream completes and no output is pending +**cancels** when any of the downstreams cancel + ---------------------------------------------------------------
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala index 43679ea969..74e8fa7bb8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala @@ -89,13 +89,13 @@ class GraphPartitionSpec extends StreamSpec { c2.expectComplete() } - "cancel upstream when downstreams cancel" in assertAllStagesStopped { + "cancel upstream when all downstreams cancel if eagerCancel is false" in assertAllStagesStopped { val p1 = TestPublisher.probe[Int]() val c1 = TestSubscriber.probe[Int]() val c2 = TestSubscriber.probe[Int]() RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ - val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 })) + val partition = b.add(new Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }, false)) Source.fromPublisher(p1.getPublisher) ~> partition.in partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1) partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2) @@ -114,10 +114,38 @@ class GraphPartitionSpec extends StreamSpec { p1Sub.sendNext(2) c1.expectNext(2) sub1.cancel() + p1Sub.sendNext(9) + c2.expectNext(9) sub2.cancel() p1Sub.expectCancellation() } + "cancel upstream when any downstream cancel if eagerCancel is true" in assertAllStagesStopped { + val p1 = TestPublisher.probe[Int]() + val c1 = TestSubscriber.probe[Int]() + val c2 = TestSubscriber.probe[Int]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(new Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }, true)) + Source.fromPublisher(p1.getPublisher) ~> partition.in + partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2) + ClosedShape + }).run() + + val p1Sub = p1.expectSubscription() + val sub1 = c1.expectSubscription() + val sub2 = c2.expectSubscription() + sub1.request(3) + sub2.request(3) + p1Sub.sendNext(1) + p1Sub.sendNext(8) + c1.expectNext(1) + c2.expectNext(8) + sub1.cancel() + p1Sub.expectCancellation() + } + "work with merge" in assertAllStagesStopped { val s = Sink.seq[Int] val input = Set(5, 2, 9, 1, 1, 1, 10) @@ -190,4 +218,16 @@ class GraphPartitionSpec extends StreamSpec { odd.expectComplete() even.expectComplete() } + + "divertTo must cancel when any of the downstreams cancel" in assertAllStagesStopped { + val pub = TestPublisher.probe[Int]() + val odd = TestSubscriber.probe[Int]() + val even = TestSubscriber.probe[Int]() + Source.fromPublisher(pub.getPublisher).divertTo(Sink.fromSubscriber(odd), _ % 2 != 0).to(Sink.fromSubscriber(even)).run() + even.request(1) + pub.sendNext(2) + even.expectNext(2) + odd.cancel() + pub.expectCancellation() + } } diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 64d5474382..14d1f009bb 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -1705,7 +1705,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' when all downstreams cancel + * '''Cancels when''' any of the downstreams cancel */ def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): javadsl.Flow[In, Out, Mat] = new Flow(delegate.divertTo(that, when.test)) diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala index d0ecd72815..159b520fb2 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala @@ -197,26 +197,48 @@ object Broadcast { * '''Completes when''' upstream completes * * '''Cancels when''' - * when one of the downstreams cancel + * when any (eagerCancel=true) or all (eagerCancel=false) of the downstreams cancel */ object Partition { /** - * Create a new `Partition` stage with the specified input type. + * Create a new `Partition` stage with the specified input type, `eagerCancel` is `false`. * * @param outputCount number of output ports * @param partitioner function deciding which output each element will be targeted */ def create[T](outputCount: Int, partitioner: function.Function[T, Integer]): Graph[UniformFanOutShape[T, T], NotUsed] = - scaladsl.Partition(outputCount, partitioner = (t: T) ⇒ partitioner.apply(t)) + new scaladsl.Partition(outputCount, partitioner.apply) /** * Create a new `Partition` stage with the specified input type. * * @param outputCount number of output ports * @param partitioner function deciding which output each element will be targeted + * @param eagerCancel this stage cancels, when any (true) or all (false) of the downstreams cancel + */ + def create[T](outputCount: Int, partitioner: function.Function[T, Integer], eagerCancel: Boolean): Graph[UniformFanOutShape[T, T], NotUsed] = + new scaladsl.Partition(outputCount, partitioner.apply, eagerCancel) + + /** + * Create a new `Partition` stage with the specified input type, `eagerCancel` is `false`. + * + * @param clazz a type hint for this method + * @param outputCount number of output ports + * @param partitioner function deciding which output each element will be targeted */ def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Integer]): Graph[UniformFanOutShape[T, T], NotUsed] = - create(outputCount, partitioner) + new scaladsl.Partition(outputCount, partitioner.apply) + + /** + * Create a new `Partition` stage with the specified input type. + * + * @param clazz a type hint for this method + * @param outputCount number of output ports + * @param partitioner function deciding which output each element will be targeted + * @param eagerCancel this stage cancels, when any (true) or all (false) of the downstreams cancel + */ + def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Integer], eagerCancel: Boolean): Graph[UniformFanOutShape[T, T], NotUsed] = + new scaladsl.Partition(outputCount, partitioner.apply, eagerCancel) } diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index c6d94dad0e..a5251d7cc7 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -761,7 +761,7 @@ final class Source[+Out, +Mat](delegate: scaladsl.Source[Out, Mat]) extends Grap * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' when all downstreams cancel + * '''Cancels when''' any of the downstreams cancel */ def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): javadsl.Source[Out, Mat] = new Source(delegate.divertTo(that, when.test)) diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala index 9dfe6eb20b..fe8a0909a2 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala @@ -1160,7 +1160,7 @@ class SubFlow[-In, +Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Flo * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' when all downstreams cancel + * '''Cancels when''' any of the downstreams cancel */ def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): SubFlow[In, Out, Mat] = new SubFlow(delegate.divertTo(that, when.test)) diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala index a87787b607..196c25d737 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala @@ -1152,7 +1152,7 @@ class SubSource[+Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Source * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' when all downstreams cancel + * '''Cancels when''' any of the downstreams cancel */ def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): SubSource[Out, Mat] = new SubSource(delegate.divertTo(that, when.test)) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 4300b50569..2eff5dcb0d 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -2308,14 +2308,14 @@ trait FlowOps[+Out, +Mat] { * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' when all downstreams cancel + * '''Cancels when''' any of the downstreams cancel */ def divertTo(that: Graph[SinkShape[Out], _], when: Out ⇒ Boolean): Repr[Out] = via(divertToGraph(that, when)) protected def divertToGraph[M](that: Graph[SinkShape[Out], M], when: Out ⇒ Boolean): Graph[FlowShape[Out @uncheckedVariance, Out], M] = GraphDSL.create(that) { implicit b ⇒ r ⇒ import GraphDSL.Implicits._ - val partition = b.add(Partition[Out](2, out ⇒ if (when(out)) 1 else 0)) + val partition = b.add(new Partition[Out](2, out ⇒ if (when(out)) 1 else 0, true)) partition.out(1) ~> r FlowShape(partition.in, partition.out(0)) } @@ -2335,7 +2335,6 @@ trait FlowOps[+Out, +Mat] { * asynchronously. */ def async: Repr[Out] - } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 82ffa83e3a..a37aadce93 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -628,16 +628,19 @@ final class Broadcast[T](val outputPorts: Int, val eagerCancel: Boolean) extends } object Partition { - + // FIXME make `PartitionOutOfBoundsException` a `final` class when possible case class PartitionOutOfBoundsException(msg: String) extends IndexOutOfBoundsException(msg) with NoStackTrace /** - * Create a new `Partition` stage with the specified input type. + * Create a new `Partition` stage with the specified input type. This method sets `eagerCancel` to `false`. + * To specify a different value for the `eagerCancel` parameter, then instantiate Partition using the constructor. + * + * If `eagerCancel` is true, partition cancels upstream if any of its downstreams cancel, if false, when all have cancelled. * * @param outputPorts number of output ports * @param partitioner function deciding which output each element will be targeted - */ - def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner) + */ // FIXME BC add `eagerCancel: Boolean = false` parameter + def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner, false) } /** @@ -650,14 +653,19 @@ object Partition { * * '''Completes when''' upstream completes and no output is pending * - * '''Cancels when''' - * when all downstreams cancel + * '''Cancels when''' all downstreams have cancelled (eagerCancel=false) or one downstream cancels (eagerCancel=true) */ -final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int) extends GraphStage[UniformFanOutShape[T, T]] { +final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int, val eagerCancel: Boolean) extends GraphStage[UniformFanOutShape[T, T]] { + + /** + * Sets `eagerCancel` to `false`. + */ + @deprecated("Use the constructor which also specifies the `eagerCancel` parameter") + def this(outputPorts: Int, partitioner: T ⇒ Int) = this(outputPorts, partitioner, false) val in: Inlet[T] = Inlet[T]("Partition.in") - val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i)) + val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i)) // FIXME BC make this immutable.IndexedSeq as type + Vector as concret impl override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler { @@ -690,11 +698,10 @@ final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int) exten setHandler(in, this) - out.zipWithIndex.foreach { + out.iterator.zipWithIndex.foreach { case (o, idx) ⇒ setHandler(o, new OutHandler { override def onPull() = { - if (outPendingElem != null) { val elem = outPendingElem.asInstanceOf[T] if (idx == outPendingIdx) { @@ -711,24 +718,25 @@ final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int) exten pull(in) } - override def onDownstreamFinish(): Unit = { - downstreamRunning -= 1 - if (downstreamRunning == 0) - completeStage() - else if (outPendingElem != null) { - if (idx == outPendingIdx) { - outPendingElem = null - if (!hasBeenPulled(in)) - pull(in) + override def onDownstreamFinish(): Unit = + if (eagerCancel) completeStage() + else { + downstreamRunning -= 1 + if (downstreamRunning == 0) + completeStage() + else if (outPendingElem != null) { + if (idx == outPendingIdx) { + outPendingElem = null + if (!hasBeenPulled(in)) + pull(in) + } } } - } }) } } override def toString = s"Partition($outputPorts)" - } object Balance {