diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala new file mode 100644 index 0000000000..93520e13e0 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala @@ -0,0 +1,175 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings, ClosedShape } +import akka.stream.testkit.Utils._ +import scala.concurrent.Await +import scala.concurrent.duration._ + +class GraphPartitionSpec extends AkkaSpec { + + val settings = ActorMaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 16) + + implicit val materializer = ActorMaterializer(settings) + + "A partition" must { + import GraphDSL.Implicits._ + + "partition to three subscribers" in assertAllStagesStopped { + val c1 = TestSubscriber.probe[Int]() + val c2 = TestSubscriber.probe[Int]() + val c3 = TestSubscriber.probe[Int]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(Partition[Int](3, { + case g if (g > 3) ⇒ 0 + case l if (l < 3) ⇒ 1 + case e if (e == 3) ⇒ 2 + })) + Source(List(1, 2, 3, 4, 5)) ~> partition.in + partition.out(0) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Sink.fromSubscriber(c2) + partition.out(2) ~> Sink.fromSubscriber(c3) + ClosedShape + }).run() + + c2.request(2) + c1.request(2) + c3.request(1) + c2.expectNext(1) + c2.expectNext(2) + c3.expectNext(3) + c1.expectNext(4) + c1.expectNext(5) + c1.expectComplete() + c2.expectComplete() + c3.expectComplete() + } + + "complete stage after upstream completes" in assertAllStagesStopped { + val c1 = TestSubscriber.probe[String]() + val c2 = TestSubscriber.probe[String]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(Partition[String](2, { + case s if (s.length > 4) ⇒ 0 + case _ ⇒ 1 + })) + Source(List("this", "is", "just", "another", "test")) ~> partition.in + partition.out(0) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Sink.fromSubscriber(c2) + ClosedShape + }).run() + + c1.request(1) + c2.request(4) + c1.expectNext("another") + c2.expectNext("this") + c2.expectNext("is") + c2.expectNext("just") + c2.expectNext("test") + c1.expectComplete() + c2.expectComplete() + + } + + "remember first pull even though first element targeted another out" in assertAllStagesStopped { + val c1 = TestSubscriber.probe[Int]() + val c2 = TestSubscriber.probe[Int]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 })) + Source(List(6, 3)) ~> partition.in + partition.out(0) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Sink.fromSubscriber(c2) + ClosedShape + }).run() + + c1.request(1) + c1.expectNoMsg(1.seconds) + c2.request(1) + c2.expectNext(6) + c1.expectNext(3) + c1.expectComplete() + c2.expectComplete() + } + + "cancel upstream when downstreams cancel" in assertAllStagesStopped { + val p1 = TestPublisher.probe[Int]() + val c1 = TestSubscriber.probe[Int]() + val c2 = TestSubscriber.probe[Int]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 })) + Source.fromPublisher(p1.getPublisher) ~> partition.in + partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2) + ClosedShape + }).run() + + val p1Sub = p1.expectSubscription() + val sub1 = c1.expectSubscription() + val sub2 = c2.expectSubscription() + sub1.request(3) + sub2.request(3) + p1Sub.sendNext(1) + p1Sub.sendNext(8) + c1.expectNext(1) + c2.expectNext(8) + p1Sub.sendNext(2) + c1.expectNext(2) + sub1.cancel() + sub2.cancel() + p1Sub.expectCancellation() + } + + "work with merge" in assertAllStagesStopped { + val s = Sink.seq[Int] + val input = Set(5, 2, 9, 1, 1, 1, 10) + + val g = RunnableGraph.fromGraph(GraphDSL.create(s) { implicit b ⇒ + sink ⇒ + val partition = b.add(Partition[Int](2, { case l if l < 4 ⇒ 0; case _ ⇒ 1 })) + val merge = b.add(Merge[Int](2)) + Source(input) ~> partition.in + partition.out(0) ~> merge.in(0) + partition.out(1) ~> merge.in(1) + merge.out ~> sink.in + + ClosedShape + }) + + val result = Await.result(g.run(), 300.millis) + + result.toSet should be(input) + + } + + "stage completion is waiting for pending output" in assertAllStagesStopped { + + val c1 = TestSubscriber.probe[Int]() + val c2 = TestSubscriber.probe[Int]() + + RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒ + val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 })) + Source(List(6)) ~> partition.in + partition.out(0) ~> Sink.fromSubscriber(c1) + partition.out(1) ~> Sink.fromSubscriber(c2) + ClosedShape + }).run() + + c1.request(1) + c1.expectNoMsg(1.second) + c2.request(1) + c2.expectNext(6) + c1.expectComplete() + c2.expectComplete() + } + + } +} diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala index e8a49ab2a8..fa11bcba2e 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala @@ -4,7 +4,7 @@ package akka.stream.javadsl import akka.stream._ -import akka.japi.Pair +import akka.japi.{ Pair, function } import scala.annotation.unchecked.uncheckedVariance import akka.stream.impl.ConstantFun @@ -135,6 +135,40 @@ object Broadcast { } +/** + * Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according + * to the partitioner function applied to the element + * + * '''Emits when''' all of the outputs stops backpressuring and there is an input element available + * + * '''Backpressures when''' one of the outputs backpressure + * + * '''Completes when''' upstream completes + * + * '''Cancels when''' + * when one of the downstreams cancel + */ +object Partition { + /** + * Create a new `Partition` stage with the specified input type. + * + * @param outputCount number of output ports + * @param partitioner function deciding which output each element will be targeted + */ + def create[T](outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] = + scaladsl.Partition(outputCount, partitioner = (t: T) ⇒ partitioner.apply(t)) + + /** + * Create a new `Partition` stage with the specified input type. + * + * @param outputCount number of output ports + * @param partitioner function deciding which output each element will be targeted + */ + def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] = + create(outputCount, partitioner) + +} + /** * Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer. * It will not shutdown until the subscriptions for at least diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index d5ec09c67f..12449984ad 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext } +object Partition { + + /** + * Create a new `Partition` stage with the specified input type. + * + * @param outputPorts number of output ports + * @param partitioner function deciding which output each element will be targeted + */ + def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner) +} + +/** + * Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according + * to the partitioner function applied to the element + * + * '''Emits when''' emits when an element is available from the input and the chosen output has demand + * + * '''Backpressures when''' the currently chosen output back-pressures + * + * '''Completes when''' upstream completes and no output is pending + * + * '''Cancels when''' + * when all downstreams cancel + */ + +final class Partition[T](outputPorts: Int, partitioner: T ⇒ Int) extends GraphStage[UniformFanOutShape[T, T]] { + + val in: Inlet[T] = Inlet[T]("Partition.in") + val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i)) + override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + private var outPendingElem: Any = null + private var outPendingIdx: Int = _ + private var downstreamRunning = outputPorts + + setHandler(in, new InHandler { + override def onPush() = { + val elem = grab(in) + val idx = partitioner(elem) + if (idx < 0 || idx >= outputPorts) + failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem].")) + else if (!isClosed(out(idx))) { + if (isAvailable(out(idx))) { + push(out(idx), elem) + if (out.exists(isAvailable(_))) + pull(in) + } else { + outPendingElem = elem + outPendingIdx = idx + } + + } else if (out.exists(isAvailable(_))) + pull(in) + } + + override def onUpstreamFinish(): Unit = { + if (outPendingElem == null) + completeStage() + } + }) + + out.zipWithIndex.foreach { + case (o, idx) ⇒ + setHandler(o, new OutHandler { + override def onPull() = { + + if (outPendingElem != null) { + val elem = outPendingElem.asInstanceOf[T] + if (idx == outPendingIdx) { + push(o, elem) + outPendingElem = null + if (!isClosed(in)) { + if (!hasBeenPulled(in)) { + pull(in) + } + } else + completeStage() + } + } else if (!hasBeenPulled(in)) + pull(in) + } + + override def onDownstreamFinish(): Unit = { + downstreamRunning -= 1 + if (downstreamRunning == 0) + completeStage() + else if (outPendingElem != null) { + if (idx == outPendingIdx) { + outPendingElem = null + if (!hasBeenPulled(in)) + pull(in) + } + } + } + }) + } + } + + override def toString = s"Partition($outputPorts)" + +} + object Balance { /** * Create a new `Balance` with the specified number of output ports.