+str #19176 Added Partition Graph Stage with tests
This commit is contained in:
parent
4de3f63b93
commit
c15f4d15f4
3 changed files with 313 additions and 1 deletions
|
|
@ -0,0 +1,175 @@
|
||||||
|
/**
|
||||||
|
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||||
|
*/
|
||||||
|
package akka.stream.scaladsl
|
||||||
|
|
||||||
|
import akka.stream.testkit._
|
||||||
|
import akka.stream.testkit.scaladsl.TestSink
|
||||||
|
import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings, ClosedShape }
|
||||||
|
import akka.stream.testkit.Utils._
|
||||||
|
import scala.concurrent.Await
|
||||||
|
import scala.concurrent.duration._
|
||||||
|
|
||||||
|
class GraphPartitionSpec extends AkkaSpec {
|
||||||
|
|
||||||
|
val settings = ActorMaterializerSettings(system)
|
||||||
|
.withInputBuffer(initialSize = 2, maxSize = 16)
|
||||||
|
|
||||||
|
implicit val materializer = ActorMaterializer(settings)
|
||||||
|
|
||||||
|
"A partition" must {
|
||||||
|
import GraphDSL.Implicits._
|
||||||
|
|
||||||
|
"partition to three subscribers" in assertAllStagesStopped {
|
||||||
|
val c1 = TestSubscriber.probe[Int]()
|
||||||
|
val c2 = TestSubscriber.probe[Int]()
|
||||||
|
val c3 = TestSubscriber.probe[Int]()
|
||||||
|
|
||||||
|
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val partition = b.add(Partition[Int](3, {
|
||||||
|
case g if (g > 3) ⇒ 0
|
||||||
|
case l if (l < 3) ⇒ 1
|
||||||
|
case e if (e == 3) ⇒ 2
|
||||||
|
}))
|
||||||
|
Source(List(1, 2, 3, 4, 5)) ~> partition.in
|
||||||
|
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||||
|
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||||
|
partition.out(2) ~> Sink.fromSubscriber(c3)
|
||||||
|
ClosedShape
|
||||||
|
}).run()
|
||||||
|
|
||||||
|
c2.request(2)
|
||||||
|
c1.request(2)
|
||||||
|
c3.request(1)
|
||||||
|
c2.expectNext(1)
|
||||||
|
c2.expectNext(2)
|
||||||
|
c3.expectNext(3)
|
||||||
|
c1.expectNext(4)
|
||||||
|
c1.expectNext(5)
|
||||||
|
c1.expectComplete()
|
||||||
|
c2.expectComplete()
|
||||||
|
c3.expectComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
"complete stage after upstream completes" in assertAllStagesStopped {
|
||||||
|
val c1 = TestSubscriber.probe[String]()
|
||||||
|
val c2 = TestSubscriber.probe[String]()
|
||||||
|
|
||||||
|
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val partition = b.add(Partition[String](2, {
|
||||||
|
case s if (s.length > 4) ⇒ 0
|
||||||
|
case _ ⇒ 1
|
||||||
|
}))
|
||||||
|
Source(List("this", "is", "just", "another", "test")) ~> partition.in
|
||||||
|
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||||
|
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||||
|
ClosedShape
|
||||||
|
}).run()
|
||||||
|
|
||||||
|
c1.request(1)
|
||||||
|
c2.request(4)
|
||||||
|
c1.expectNext("another")
|
||||||
|
c2.expectNext("this")
|
||||||
|
c2.expectNext("is")
|
||||||
|
c2.expectNext("just")
|
||||||
|
c2.expectNext("test")
|
||||||
|
c1.expectComplete()
|
||||||
|
c2.expectComplete()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
"remember first pull even though first element targeted another out" in assertAllStagesStopped {
|
||||||
|
val c1 = TestSubscriber.probe[Int]()
|
||||||
|
val c2 = TestSubscriber.probe[Int]()
|
||||||
|
|
||||||
|
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||||
|
Source(List(6, 3)) ~> partition.in
|
||||||
|
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||||
|
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||||
|
ClosedShape
|
||||||
|
}).run()
|
||||||
|
|
||||||
|
c1.request(1)
|
||||||
|
c1.expectNoMsg(1.seconds)
|
||||||
|
c2.request(1)
|
||||||
|
c2.expectNext(6)
|
||||||
|
c1.expectNext(3)
|
||||||
|
c1.expectComplete()
|
||||||
|
c2.expectComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
"cancel upstream when downstreams cancel" in assertAllStagesStopped {
|
||||||
|
val p1 = TestPublisher.probe[Int]()
|
||||||
|
val c1 = TestSubscriber.probe[Int]()
|
||||||
|
val c2 = TestSubscriber.probe[Int]()
|
||||||
|
|
||||||
|
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||||
|
Source.fromPublisher(p1.getPublisher) ~> partition.in
|
||||||
|
partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1)
|
||||||
|
partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2)
|
||||||
|
ClosedShape
|
||||||
|
}).run()
|
||||||
|
|
||||||
|
val p1Sub = p1.expectSubscription()
|
||||||
|
val sub1 = c1.expectSubscription()
|
||||||
|
val sub2 = c2.expectSubscription()
|
||||||
|
sub1.request(3)
|
||||||
|
sub2.request(3)
|
||||||
|
p1Sub.sendNext(1)
|
||||||
|
p1Sub.sendNext(8)
|
||||||
|
c1.expectNext(1)
|
||||||
|
c2.expectNext(8)
|
||||||
|
p1Sub.sendNext(2)
|
||||||
|
c1.expectNext(2)
|
||||||
|
sub1.cancel()
|
||||||
|
sub2.cancel()
|
||||||
|
p1Sub.expectCancellation()
|
||||||
|
}
|
||||||
|
|
||||||
|
"work with merge" in assertAllStagesStopped {
|
||||||
|
val s = Sink.seq[Int]
|
||||||
|
val input = Set(5, 2, 9, 1, 1, 1, 10)
|
||||||
|
|
||||||
|
val g = RunnableGraph.fromGraph(GraphDSL.create(s) { implicit b ⇒
|
||||||
|
sink ⇒
|
||||||
|
val partition = b.add(Partition[Int](2, { case l if l < 4 ⇒ 0; case _ ⇒ 1 }))
|
||||||
|
val merge = b.add(Merge[Int](2))
|
||||||
|
Source(input) ~> partition.in
|
||||||
|
partition.out(0) ~> merge.in(0)
|
||||||
|
partition.out(1) ~> merge.in(1)
|
||||||
|
merge.out ~> sink.in
|
||||||
|
|
||||||
|
ClosedShape
|
||||||
|
})
|
||||||
|
|
||||||
|
val result = Await.result(g.run(), 300.millis)
|
||||||
|
|
||||||
|
result.toSet should be(input)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
"stage completion is waiting for pending output" in assertAllStagesStopped {
|
||||||
|
|
||||||
|
val c1 = TestSubscriber.probe[Int]()
|
||||||
|
val c2 = TestSubscriber.probe[Int]()
|
||||||
|
|
||||||
|
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||||
|
Source(List(6)) ~> partition.in
|
||||||
|
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||||
|
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||||
|
ClosedShape
|
||||||
|
}).run()
|
||||||
|
|
||||||
|
c1.request(1)
|
||||||
|
c1.expectNoMsg(1.second)
|
||||||
|
c2.request(1)
|
||||||
|
c2.expectNext(6)
|
||||||
|
c1.expectComplete()
|
||||||
|
c2.expectComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
package akka.stream.javadsl
|
package akka.stream.javadsl
|
||||||
|
|
||||||
import akka.stream._
|
import akka.stream._
|
||||||
import akka.japi.Pair
|
import akka.japi.{ Pair, function }
|
||||||
import scala.annotation.unchecked.uncheckedVariance
|
import scala.annotation.unchecked.uncheckedVariance
|
||||||
import akka.stream.impl.ConstantFun
|
import akka.stream.impl.ConstantFun
|
||||||
|
|
||||||
|
|
@ -135,6 +135,40 @@ object Broadcast {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
|
||||||
|
* to the partitioner function applied to the element
|
||||||
|
*
|
||||||
|
* '''Emits when''' all of the outputs stops backpressuring and there is an input element available
|
||||||
|
*
|
||||||
|
* '''Backpressures when''' one of the outputs backpressure
|
||||||
|
*
|
||||||
|
* '''Completes when''' upstream completes
|
||||||
|
*
|
||||||
|
* '''Cancels when'''
|
||||||
|
* when one of the downstreams cancel
|
||||||
|
*/
|
||||||
|
object Partition {
|
||||||
|
/**
|
||||||
|
* Create a new `Partition` stage with the specified input type.
|
||||||
|
*
|
||||||
|
* @param outputCount number of output ports
|
||||||
|
* @param partitioner function deciding which output each element will be targeted
|
||||||
|
*/
|
||||||
|
def create[T](outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
|
||||||
|
scaladsl.Partition(outputCount, partitioner = (t: T) ⇒ partitioner.apply(t))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new `Partition` stage with the specified input type.
|
||||||
|
*
|
||||||
|
* @param outputCount number of output ports
|
||||||
|
* @param partitioner function deciding which output each element will be targeted
|
||||||
|
*/
|
||||||
|
def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
|
||||||
|
create(outputCount, partitioner)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer.
|
* Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer.
|
||||||
* It will not shutdown until the subscriptions for at least
|
* It will not shutdown until the subscriptions for at least
|
||||||
|
|
|
||||||
|
|
@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object Partition {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new `Partition` stage with the specified input type.
|
||||||
|
*
|
||||||
|
* @param outputPorts number of output ports
|
||||||
|
* @param partitioner function deciding which output each element will be targeted
|
||||||
|
*/
|
||||||
|
def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
|
||||||
|
* to the partitioner function applied to the element
|
||||||
|
*
|
||||||
|
* '''Emits when''' emits when an element is available from the input and the chosen output has demand
|
||||||
|
*
|
||||||
|
* '''Backpressures when''' the currently chosen output back-pressures
|
||||||
|
*
|
||||||
|
* '''Completes when''' upstream completes and no output is pending
|
||||||
|
*
|
||||||
|
* '''Cancels when'''
|
||||||
|
* when all downstreams cancel
|
||||||
|
*/
|
||||||
|
|
||||||
|
final class Partition[T](outputPorts: Int, partitioner: T ⇒ Int) extends GraphStage[UniformFanOutShape[T, T]] {
|
||||||
|
|
||||||
|
val in: Inlet[T] = Inlet[T]("Partition.in")
|
||||||
|
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i))
|
||||||
|
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
|
||||||
|
|
||||||
|
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
|
||||||
|
private var outPendingElem: Any = null
|
||||||
|
private var outPendingIdx: Int = _
|
||||||
|
private var downstreamRunning = outputPorts
|
||||||
|
|
||||||
|
setHandler(in, new InHandler {
|
||||||
|
override def onPush() = {
|
||||||
|
val elem = grab(in)
|
||||||
|
val idx = partitioner(elem)
|
||||||
|
if (idx < 0 || idx >= outputPorts)
|
||||||
|
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
|
||||||
|
else if (!isClosed(out(idx))) {
|
||||||
|
if (isAvailable(out(idx))) {
|
||||||
|
push(out(idx), elem)
|
||||||
|
if (out.exists(isAvailable(_)))
|
||||||
|
pull(in)
|
||||||
|
} else {
|
||||||
|
outPendingElem = elem
|
||||||
|
outPendingIdx = idx
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (out.exists(isAvailable(_)))
|
||||||
|
pull(in)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def onUpstreamFinish(): Unit = {
|
||||||
|
if (outPendingElem == null)
|
||||||
|
completeStage()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
out.zipWithIndex.foreach {
|
||||||
|
case (o, idx) ⇒
|
||||||
|
setHandler(o, new OutHandler {
|
||||||
|
override def onPull() = {
|
||||||
|
|
||||||
|
if (outPendingElem != null) {
|
||||||
|
val elem = outPendingElem.asInstanceOf[T]
|
||||||
|
if (idx == outPendingIdx) {
|
||||||
|
push(o, elem)
|
||||||
|
outPendingElem = null
|
||||||
|
if (!isClosed(in)) {
|
||||||
|
if (!hasBeenPulled(in)) {
|
||||||
|
pull(in)
|
||||||
|
}
|
||||||
|
} else
|
||||||
|
completeStage()
|
||||||
|
}
|
||||||
|
} else if (!hasBeenPulled(in))
|
||||||
|
pull(in)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def onDownstreamFinish(): Unit = {
|
||||||
|
downstreamRunning -= 1
|
||||||
|
if (downstreamRunning == 0)
|
||||||
|
completeStage()
|
||||||
|
else if (outPendingElem != null) {
|
||||||
|
if (idx == outPendingIdx) {
|
||||||
|
outPendingElem = null
|
||||||
|
if (!hasBeenPulled(in))
|
||||||
|
pull(in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString = s"Partition($outputPorts)"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
object Balance {
|
object Balance {
|
||||||
/**
|
/**
|
||||||
* Create a new `Balance` with the specified number of output ports.
|
* Create a new `Balance` with the specified number of output ports.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue