diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBalanceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBalanceSpec.scala index 0ce2501492..3e224574e6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBalanceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBalanceSpec.scala @@ -44,6 +44,76 @@ class GraphBalanceSpec extends AkkaSpec { c2.expectComplete() } + "support waiting for demand from all downstream subscriptions" in { + val s1 = StreamTestKit.SubscriberProbe[Int]() + val p2Sink = Sink.publisher[Int] + + val m = FlowGraph { implicit b ⇒ + val balance = Balance[Int]("balance", waitForAllDownstreams = true) + Source(List(1, 2, 3)) ~> balance + balance ~> Sink(s1) + balance ~> p2Sink + }.run() + + val p2 = m.get(p2Sink) + + val sub1 = s1.expectSubscription() + sub1.request(1) + s1.expectNoMsg(200.millis) + + val s2 = StreamTestKit.SubscriberProbe[Int]() + p2.subscribe(s2) + val sub2 = s2.expectSubscription() + + // still no demand from s2 + s1.expectNoMsg(200.millis) + + sub2.request(2) + s1.expectNext(1) + s2.expectNext(2) + s2.expectNext(3) + s1.expectComplete() + s2.expectComplete() + } + + "support waiting for demand from all non-cancelled downstream subscriptions" in { + val s1 = StreamTestKit.SubscriberProbe[Int]() + val p2Sink = Sink.publisher[Int] + val p3Sink = Sink.publisher[Int] + + val m = FlowGraph { implicit b ⇒ + val balance = Balance[Int]("balance", waitForAllDownstreams = true) + Source(List(1, 2, 3)) ~> balance + balance ~> Sink(s1) + balance ~> p2Sink + balance ~> p3Sink + }.run() + + val p2 = m.get(p2Sink) + val p3 = m.get(p3Sink) + + val sub1 = s1.expectSubscription() + sub1.request(1) + + val s2 = StreamTestKit.SubscriberProbe[Int]() + p2.subscribe(s2) + val sub2 = s2.expectSubscription() + + val s3 = StreamTestKit.SubscriberProbe[Int]() + p3.subscribe(s3) + val sub3 = s3.expectSubscription() + + sub2.request(2) + s1.expectNoMsg(200.millis) + sub3.cancel() + + s1.expectNext(1) + s2.expectNext(2) + s2.expectNext(3) + s1.expectComplete() + s2.expectComplete() + } + "work with 5-way balance" in { val f1 = Sink.future[Seq[Int]] val f2 = Sink.future[Seq[Int]] @@ -52,7 +122,7 @@ class GraphBalanceSpec extends AkkaSpec { val f5 = Sink.future[Seq[Int]] val g = FlowGraph { implicit b ⇒ - val balance = Balance[Int]("balance") + val balance = Balance[Int]("balance", waitForAllDownstreams = true) Source(0 to 14) ~> balance balance ~> Flow[Int].grouped(15) ~> f1 balance ~> Flow[Int].grouped(15) ~> f2 @@ -68,7 +138,7 @@ class GraphBalanceSpec extends AkkaSpec { val numElementsForSink = 10000 val f1, f2, f3 = Sink.fold[Int, Int](0)(_ + _) val g = FlowGraph { implicit b ⇒ - val balance = Balance[Int]("balance") + val balance = Balance[Int]("balance", waitForAllDownstreams = true) Source(Stream.fill(10000 * 3)(1)) ~> balance ~> f1 balance ~> f2 balance ~> f3 diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala index 8bf8e87f0e..1fc693feb5 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala @@ -97,7 +97,7 @@ private[akka] object Ast { override def name = "broadcast" } - case object Balance extends FanOutAstNode { + case class Balance(waitForAllDownstreams: Boolean) extends FanOutAstNode { override def name = "balance" } @@ -255,8 +255,8 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting val impl = op match { case Ast.Broadcast ⇒ actorOf(Broadcast.props(settings, outputCount).withDispatcher(settings.dispatcher), actorName) - case Ast.Balance ⇒ - actorOf(Balance.props(settings, outputCount).withDispatcher(settings.dispatcher), actorName) + case Ast.Balance(waitForAllDownstreams) ⇒ + actorOf(Balance.props(settings, outputCount, waitForAllDownstreams).withDispatcher(settings.dispatcher), actorName) case Ast.Unzip ⇒ actorOf(Unzip.props(settings).withDispatcher(settings.dispatcher), actorName) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala index d285367964..6924d12e4c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala @@ -162,7 +162,8 @@ private[akka] object FanOut { if (marked(id) && !cancelled(id)) markedCancelled += 1 cancelled(id) = true outputs(id).subreceive(Cancel(null)) - case SubstreamSubscribePending(id) ⇒ outputs(id).subreceive(SubscribePending) + case SubstreamSubscribePending(id) ⇒ + outputs(id).subreceive(SubscribePending) }) } @@ -233,20 +234,27 @@ private[akka] class Broadcast(_settings: MaterializerSettings, _outputPorts: Int * INTERNAL API */ private[akka] object Balance { - def props(settings: MaterializerSettings, outputPorts: Int): Props = - Props(new Balance(settings, outputPorts)) + def props(settings: MaterializerSettings, outputPorts: Int, waitForAllDownstreams: Boolean): Props = + Props(new Balance(settings, outputPorts, waitForAllDownstreams)) } /** * INTERNAL API */ -private[akka] class Balance(_settings: MaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) { +private[akka] class Balance(_settings: MaterializerSettings, _outputPorts: Int, waitForAllDownstreams: Boolean) extends FanOut(_settings, _outputPorts) { (0 until outputPorts) foreach outputBunch.markOutput - nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AnyOfMarkedOutputs) { () ⇒ + val runningPhase = TransferPhase(primaryInputs.NeedsInput && outputBunch.AnyOfMarkedOutputs) { () ⇒ val elem = primaryInputs.dequeueInputElement() outputBunch.enqueueAndYield(elem) - }) + } + + if (waitForAllDownstreams) + nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒ + nextPhase(runningPhase) + }) + else + nextPhase(runningPhase) } /** @@ -275,7 +283,7 @@ private[akka] class Unzip(_settings: MaterializerSettings) extends FanOut(_setti case t ⇒ throw new IllegalArgumentException( - s"Unable to unzip elements of type {t.getClass.getName}, " + + s"Unable to unzip elements of type ${t.getClass.getName}, " + s"can only handle Tuple2 and akka.japi.Pair!") } }) diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/FlowGraph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/FlowGraph.scala index 1fc0f7a5fb..c520f411ae 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/FlowGraph.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/FlowGraph.scala @@ -40,7 +40,7 @@ object Merge { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](): Merge[T] = new Merge(new scaladsl.Merge[T](None)) + def create[T](): Merge[T] = create(name = null) /** * Create a new anonymous `Merge` vertex with the specified output type. @@ -56,7 +56,7 @@ object Merge { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): Merge[T] = new Merge(new scaladsl.Merge[T](Some(name))) + def create[T](name: String): Merge[T] = new Merge(new scaladsl.Merge[T](Option(name))) /** * Create a named `Merge` vertex with the specified output type. @@ -85,7 +85,7 @@ object MergePreferred { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](): MergePreferred[T] = new MergePreferred(new scaladsl.MergePreferred[T](None)) + def create[T](): MergePreferred[T] = create(name = null) /** * Create a new anonymous `MergePreferred` vertex with the specified output type. @@ -93,7 +93,7 @@ object MergePreferred { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](clazz: Class[T]): MergePreferred[T] = new MergePreferred(new scaladsl.MergePreferred[T](None)) + def create[T](clazz: Class[T]): MergePreferred[T] = create[T]() /** * Create a named `MergePreferred` vertex with the specified output type. @@ -101,7 +101,7 @@ object MergePreferred { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): MergePreferred[T] = new MergePreferred(new scaladsl.MergePreferred[T](Some(name))) + def create[T](name: String): MergePreferred[T] = new MergePreferred(new scaladsl.MergePreferred[T](Option(name))) /** * Create a named `MergePreferred` vertex with the specified output type. @@ -109,7 +109,7 @@ object MergePreferred { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](clazz: Class[T], name: String): MergePreferred[T] = new MergePreferred(new scaladsl.MergePreferred[T](Some(name))) + def create[T](clazz: Class[T], name: String): MergePreferred[T] = create[T](name) } /** @@ -130,7 +130,7 @@ object Broadcast { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](): Broadcast[T] = new Broadcast(new scaladsl.Broadcast(None)) + def create[T](): Broadcast[T] = create(name = null) /** * Create a new anonymous `Broadcast` vertex with the specified input type. @@ -146,7 +146,7 @@ object Broadcast { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): Broadcast[T] = new Broadcast(new scaladsl.Broadcast(Some(name))) + def create[T](name: String): Broadcast[T] = new Broadcast(new scaladsl.Broadcast(Option(name))) /** * Create a named `Broadcast` vertex with the specified input type. @@ -173,7 +173,7 @@ object Balance { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](): Balance[T] = new Balance(new scaladsl.Balance(None)) + def create[T](): Balance[T] = create(name = null) /** * Create a new anonymous `Balance` vertex with the specified input type. @@ -189,7 +189,8 @@ object Balance { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): Balance[T] = new Balance(new scaladsl.Balance(Some(name))) + def create[T](name: String): Balance[T] = + new Balance(new scaladsl.Balance(Option(name), waitForAllDownstreams = false)) /** * Create a named `Balance` vertex with the specified input type. @@ -207,6 +208,13 @@ object Balance { */ class Balance[T](delegate: scaladsl.Balance[T]) extends javadsl.Junction[T] { override def asScala: scaladsl.Balance[T] = delegate + + /** + * If you use `withWaitForAllDownstreams(true)` the returned `Balance` will not start emitting + * elements to downstream outputs until all of them have requested at least one element. + */ + def withWaitForAllDowstreams(enabled: Boolean): Balance[T] = + new Balance(new scaladsl.Balance(delegate.name, delegate.waitForAllDownstreams)) } object Zip { @@ -225,7 +233,7 @@ object Zip { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[A, B](left: Class[A], right: Class[B]): Zip[A, B] = create[A, B](name = null) + def create[A, B](left: Class[A], right: Class[B]): Zip[A, B] = create[A, B]() /** * Create a named `Zip` vertex with the specified input types. @@ -276,8 +284,7 @@ final class Zip[A, B] private (delegate: scaladsl.Zip[A, B]) { } object Unzip { - def create[A, B](): Unzip[A, B] = - create(null) + def create[A, B](): Unzip[A, B] = create(name = null) def create[A, B](name: String): Unzip[A, B] = new Unzip[A, B](new scaladsl.Unzip[A, B](Option(name))) @@ -390,7 +397,7 @@ object UndefinedSource { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](clazz: Class[T]): UndefinedSource[T] = new UndefinedSource[T](new scaladsl.UndefinedSource[T](None)) + def create[T](clazz: Class[T]): UndefinedSource[T] = create[T]() /** * Create a named `Undefinedsource` vertex with the specified input type. @@ -398,7 +405,7 @@ object UndefinedSource { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): UndefinedSource[T] = new UndefinedSource[T](new scaladsl.UndefinedSource[T](Some(name))) + def create[T](name: String): UndefinedSource[T] = new UndefinedSource[T](new scaladsl.UndefinedSource[T](Option(name))) /** * Create a named `Undefinedsource` vertex with the specified input type. @@ -406,7 +413,7 @@ object UndefinedSource { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](clazz: Class[T], name: String): UndefinedSource[T] = new UndefinedSource[T](new scaladsl.UndefinedSource[T](Some(name))) + def create[T](clazz: Class[T], name: String): UndefinedSource[T] = create[T](name) } /** @@ -425,7 +432,7 @@ object UndefinedSink { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](): UndefinedSink[T] = new UndefinedSink[T](new scaladsl.UndefinedSink[T](None)) + def create[T](): UndefinedSink[T] = create(name = null) /** * Create a new anonymous `Undefinedsink` vertex with the specified input type. @@ -433,7 +440,7 @@ object UndefinedSink { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def create[T](clazz: Class[T]): UndefinedSink[T] = new UndefinedSink[T](new scaladsl.UndefinedSink[T](None)) + def create[T](clazz: Class[T]): UndefinedSink[T] = create[T]() /** * Create a named `Undefinedsink` vertex with the specified input type. @@ -441,7 +448,7 @@ object UndefinedSink { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](name: String): UndefinedSink[T] = new UndefinedSink[T](new scaladsl.UndefinedSink[T](Some(name))) + def create[T](name: String): UndefinedSink[T] = new UndefinedSink[T](new scaladsl.UndefinedSink[T](Option(name))) /** * Create a named `Undefinedsink` vertex with the specified input type. @@ -449,7 +456,7 @@ object UndefinedSink { * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. */ - def create[T](clazz: Class[T], name: String): UndefinedSink[T] = new UndefinedSink[T](new scaladsl.UndefinedSink[T](Some(name))) + def create[T](clazz: Class[T], name: String): UndefinedSink[T] = create[T](name) } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala index 941b576de3..b64f815204 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala @@ -186,14 +186,28 @@ object Balance { * in the `FlowGraph`. This method creates a new instance every time it * is called and those instances are not `equal`. */ - def apply[T]: Balance[T] = new Balance[T](None) + def apply[T]: Balance[T] = new Balance[T](None, waitForAllDownstreams = false) /** * Create a named `Balance` vertex with the specified input type. * Note that a `Balance` with a specific name can only be used at one place (one vertex) * in the `FlowGraph`. Calling this method several times with the same name * returns instances that are `equal`. + * + * If you use `waitForAllDownstreams = true` it will not start emitting + * elements to downstream outputs until all of them have requested at least one element. */ - def apply[T](name: String): Balance[T] = new Balance[T](Some(name)) + def apply[T](name: String, waitForAllDownstreams: Boolean = false): Balance[T] = new Balance[T](Some(name), waitForAllDownstreams) + + /** + * Create a new anonymous `Balance` vertex with the specified input type. + * Note that a `Balance` instance can only be used at one place (one vertex) + * in the `FlowGraph`. This method creates a new instance every time it + * is called and those instances are not `equal`. + * + * If you use `waitForAllDownstreams = true` it will not start emitting + * elements to downstream outputs until all of them have requested at least one element. + */ + def apply[T](waitForAllDownstreams: Boolean): Balance[T] = new Balance[T](None, waitForAllDownstreams) } /** @@ -201,16 +215,16 @@ object Balance { * one of the other streams. It will not shutdown until the subscriptions for at least * two downstream subscribers have been established. */ -final class Balance[T](override val name: Option[String]) extends FlowGraphInternal.InternalVertex with Junction[T] { +final class Balance[T](override val name: Option[String], val waitForAllDownstreams: Boolean) extends FlowGraphInternal.InternalVertex with Junction[T] { override private[akka] def vertex = this override def minimumInputCount: Int = 1 override def maximumInputCount: Int = 1 override def minimumOutputCount: Int = 2 override def maximumOutputCount: Int = Int.MaxValue - override private[akka] def astNode = Ast.Balance + override private[akka] val astNode = Ast.Balance(waitForAllDownstreams) - final override private[scaladsl] def newInstance() = new Balance[T](None) + final override private[scaladsl] def newInstance() = new Balance[T](None, waitForAllDownstreams) } object Zip {