From 26e2dcb8579c3e5f97a357061f047a2df7122b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andre=CC=81n?= Date: Thu, 21 Jan 2016 14:43:26 +0100 Subject: [PATCH] =str #19549 #19257 Allow Broadcast(1) and Merge(1) --- .../akka/stream/scaladsl/GraphBroadcastSpec.scala | 13 +++++++++++++ .../scala/akka/stream/scaladsl/GraphMergeSpec.scala | 13 +++++++++++++ .../src/main/scala/akka/stream/scaladsl/Graph.scala | 6 ++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala index 7f9bc8b454..528d853212 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphBroadcastSpec.scala @@ -49,6 +49,19 @@ class GraphBroadcastSpec extends AkkaSpec { c2.expectComplete() } + "work with one-way broadcast" in assertAllStagesStopped { + val result = Source.fromGraph(GraphDSL.create() { implicit b ⇒ + val broadcast = b.add(Broadcast[Int](1)) + val source = b.add(Source(1 to 3)) + + source ~> broadcast.in + + SourceShape(broadcast.out(0)) + }).runFold(Seq[Int]())(_ :+ _) + + Await.result(result, 3.seconds) should ===(Seq(1, 2, 3)) + } + "work with n-way broadcast" in assertAllStagesStopped { val headSink = Sink.head[Seq[Int]] diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala index 91eb669251..f0d9ee16ba 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala @@ -5,6 +5,7 @@ package akka.stream.scaladsl import akka.stream._ +import scala.concurrent.Await import scala.concurrent.duration._ import akka.stream.testkit._ @@ -58,6 +59,18 @@ class GraphMergeSpec extends TwoStreamsSetup { probe.expectComplete() } + "work with 1-way merge" in { + val result = Source.fromGraph(GraphDSL.create() { implicit b ⇒ + val merge = b.add(Merge[Int](1)) + val source = b.add(Source(1 to 3)) + + source ~> merge.in(0) + SourceShape(merge.out) + }).runFold(Seq[Int]())(_ :+ _) + + Await.result(result, 3.seconds) should ===(Seq(1, 2, 3)) + } + "work with n-way merge" in { val source1 = Source(List(1)) val source2 = Source(List(2)) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 9da490b32f..9fa3e5b684 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -39,7 +39,8 @@ object Merge { * '''Cancels when''' downstream cancels */ final class Merge[T] private (val inputPorts: Int, val eagerComplete: Boolean) extends GraphStage[UniformFanInShape[T, T]] { - require(inputPorts > 1, "A Merge must have more than 1 input port") + // one input might seem counter intuitive but saves us from special handling in other places + require(inputPorts >= 1, "A Merge must have one or more input ports") val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i ⇒ Inlet[T]("Merge.in" + i)) val out: Outlet[T] = Outlet[T]("Merge.out") @@ -395,7 +396,8 @@ object Broadcast { * */ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) extends GraphStage[UniformFanOutShape[T, T]] { - require(outputPorts > 1, "A Broadcast must have more than 1 output ports") + // one input might seem counter intuitive but saves us from special handling in other places + require(outputPorts >= 1, "A Broadcast must have one or more output ports") val in: Inlet[T] = Inlet[T]("Broadast.in") val out: immutable.IndexedSeq[Outlet[T]] = Vector.tabulate(outputPorts)(i ⇒ Outlet[T]("Broadcast.out" + i)) override def initialAttributes = DefaultAttributes.broadcast