+str #19176 Added Partition Graph Stage with tests
This commit is contained in:
parent
4de3f63b93
commit
c15f4d15f4
3 changed files with 313 additions and 1 deletions
|
|
@ -0,0 +1,175 @@
|
|||
/**
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.stream.scaladsl
|
||||
|
||||
import akka.stream.testkit._
|
||||
import akka.stream.testkit.scaladsl.TestSink
|
||||
import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings, ClosedShape }
|
||||
import akka.stream.testkit.Utils._
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class GraphPartitionSpec extends AkkaSpec {
|
||||
|
||||
val settings = ActorMaterializerSettings(system)
|
||||
.withInputBuffer(initialSize = 2, maxSize = 16)
|
||||
|
||||
implicit val materializer = ActorMaterializer(settings)
|
||||
|
||||
"A partition" must {
|
||||
import GraphDSL.Implicits._
|
||||
|
||||
"partition to three subscribers" in assertAllStagesStopped {
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
val c2 = TestSubscriber.probe[Int]()
|
||||
val c3 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](3, {
|
||||
case g if (g > 3) ⇒ 0
|
||||
case l if (l < 3) ⇒ 1
|
||||
case e if (e == 3) ⇒ 2
|
||||
}))
|
||||
Source(List(1, 2, 3, 4, 5)) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||
partition.out(2) ~> Sink.fromSubscriber(c3)
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c2.request(2)
|
||||
c1.request(2)
|
||||
c3.request(1)
|
||||
c2.expectNext(1)
|
||||
c2.expectNext(2)
|
||||
c3.expectNext(3)
|
||||
c1.expectNext(4)
|
||||
c1.expectNext(5)
|
||||
c1.expectComplete()
|
||||
c2.expectComplete()
|
||||
c3.expectComplete()
|
||||
}
|
||||
|
||||
"complete stage after upstream completes" in assertAllStagesStopped {
|
||||
val c1 = TestSubscriber.probe[String]()
|
||||
val c2 = TestSubscriber.probe[String]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[String](2, {
|
||||
case s if (s.length > 4) ⇒ 0
|
||||
case _ ⇒ 1
|
||||
}))
|
||||
Source(List("this", "is", "just", "another", "test")) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c1.request(1)
|
||||
c2.request(4)
|
||||
c1.expectNext("another")
|
||||
c2.expectNext("this")
|
||||
c2.expectNext("is")
|
||||
c2.expectNext("just")
|
||||
c2.expectNext("test")
|
||||
c1.expectComplete()
|
||||
c2.expectComplete()
|
||||
|
||||
}
|
||||
|
||||
"remember first pull even though first element targeted another out" in assertAllStagesStopped {
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
val c2 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||
Source(List(6, 3)) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c1.request(1)
|
||||
c1.expectNoMsg(1.seconds)
|
||||
c2.request(1)
|
||||
c2.expectNext(6)
|
||||
c1.expectNext(3)
|
||||
c1.expectComplete()
|
||||
c2.expectComplete()
|
||||
}
|
||||
|
||||
"cancel upstream when downstreams cancel" in assertAllStagesStopped {
|
||||
val p1 = TestPublisher.probe[Int]()
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
val c2 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||
Source.fromPublisher(p1.getPublisher) ~> partition.in
|
||||
partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2)
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
val p1Sub = p1.expectSubscription()
|
||||
val sub1 = c1.expectSubscription()
|
||||
val sub2 = c2.expectSubscription()
|
||||
sub1.request(3)
|
||||
sub2.request(3)
|
||||
p1Sub.sendNext(1)
|
||||
p1Sub.sendNext(8)
|
||||
c1.expectNext(1)
|
||||
c2.expectNext(8)
|
||||
p1Sub.sendNext(2)
|
||||
c1.expectNext(2)
|
||||
sub1.cancel()
|
||||
sub2.cancel()
|
||||
p1Sub.expectCancellation()
|
||||
}
|
||||
|
||||
"work with merge" in assertAllStagesStopped {
|
||||
val s = Sink.seq[Int]
|
||||
val input = Set(5, 2, 9, 1, 1, 1, 10)
|
||||
|
||||
val g = RunnableGraph.fromGraph(GraphDSL.create(s) { implicit b ⇒
|
||||
sink ⇒
|
||||
val partition = b.add(Partition[Int](2, { case l if l < 4 ⇒ 0; case _ ⇒ 1 }))
|
||||
val merge = b.add(Merge[Int](2))
|
||||
Source(input) ~> partition.in
|
||||
partition.out(0) ~> merge.in(0)
|
||||
partition.out(1) ~> merge.in(1)
|
||||
merge.out ~> sink.in
|
||||
|
||||
ClosedShape
|
||||
})
|
||||
|
||||
val result = Await.result(g.run(), 300.millis)
|
||||
|
||||
result.toSet should be(input)
|
||||
|
||||
}
|
||||
|
||||
"stage completion is waiting for pending output" in assertAllStagesStopped {
|
||||
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
val c2 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](2, { case l if l < 6 ⇒ 0; case _ ⇒ 1 }))
|
||||
Source(List(6)) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c1.request(1)
|
||||
c1.expectNoMsg(1.second)
|
||||
c2.request(1)
|
||||
c2.expectNext(6)
|
||||
c1.expectComplete()
|
||||
c2.expectComplete()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
package akka.stream.javadsl
|
||||
|
||||
import akka.stream._
|
||||
import akka.japi.Pair
|
||||
import akka.japi.{ Pair, function }
|
||||
import scala.annotation.unchecked.uncheckedVariance
|
||||
import akka.stream.impl.ConstantFun
|
||||
|
||||
|
|
@ -135,6 +135,40 @@ object Broadcast {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
|
||||
* to the partitioner function applied to the element
|
||||
*
|
||||
* '''Emits when''' all of the outputs stops backpressuring and there is an input element available
|
||||
*
|
||||
* '''Backpressures when''' one of the outputs backpressure
|
||||
*
|
||||
* '''Completes when''' upstream completes
|
||||
*
|
||||
* '''Cancels when'''
|
||||
* when one of the downstreams cancel
|
||||
*/
|
||||
object Partition {
|
||||
/**
|
||||
* Create a new `Partition` stage with the specified input type.
|
||||
*
|
||||
* @param outputCount number of output ports
|
||||
* @param partitioner function deciding which output each element will be targeted
|
||||
*/
|
||||
def create[T](outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
|
||||
scaladsl.Partition(outputCount, partitioner = (t: T) ⇒ partitioner.apply(t))
|
||||
|
||||
/**
|
||||
* Create a new `Partition` stage with the specified input type.
|
||||
*
|
||||
* @param outputCount number of output ports
|
||||
* @param partitioner function deciding which output each element will be targeted
|
||||
*/
|
||||
def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
|
||||
create(outputCount, partitioner)
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer.
|
||||
* It will not shutdown until the subscriptions for at least
|
||||
|
|
|
|||
|
|
@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
|
|||
|
||||
}
|
||||
|
||||
object Partition {
|
||||
|
||||
/**
|
||||
* Create a new `Partition` stage with the specified input type.
|
||||
*
|
||||
* @param outputPorts number of output ports
|
||||
* @param partitioner function deciding which output each element will be targeted
|
||||
*/
|
||||
def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner)
|
||||
}
|
||||
|
||||
/**
|
||||
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
|
||||
* to the partitioner function applied to the element
|
||||
*
|
||||
* '''Emits when''' emits when an element is available from the input and the chosen output has demand
|
||||
*
|
||||
* '''Backpressures when''' the currently chosen output back-pressures
|
||||
*
|
||||
* '''Completes when''' upstream completes and no output is pending
|
||||
*
|
||||
* '''Cancels when'''
|
||||
* when all downstreams cancel
|
||||
*/
|
||||
|
||||
final class Partition[T](outputPorts: Int, partitioner: T ⇒ Int) extends GraphStage[UniformFanOutShape[T, T]] {
|
||||
|
||||
val in: Inlet[T] = Inlet[T]("Partition.in")
|
||||
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i))
|
||||
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
|
||||
private var outPendingElem: Any = null
|
||||
private var outPendingIdx: Int = _
|
||||
private var downstreamRunning = outputPorts
|
||||
|
||||
setHandler(in, new InHandler {
|
||||
override def onPush() = {
|
||||
val elem = grab(in)
|
||||
val idx = partitioner(elem)
|
||||
if (idx < 0 || idx >= outputPorts)
|
||||
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
|
||||
else if (!isClosed(out(idx))) {
|
||||
if (isAvailable(out(idx))) {
|
||||
push(out(idx), elem)
|
||||
if (out.exists(isAvailable(_)))
|
||||
pull(in)
|
||||
} else {
|
||||
outPendingElem = elem
|
||||
outPendingIdx = idx
|
||||
}
|
||||
|
||||
} else if (out.exists(isAvailable(_)))
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (outPendingElem == null)
|
||||
completeStage()
|
||||
}
|
||||
})
|
||||
|
||||
out.zipWithIndex.foreach {
|
||||
case (o, idx) ⇒
|
||||
setHandler(o, new OutHandler {
|
||||
override def onPull() = {
|
||||
|
||||
if (outPendingElem != null) {
|
||||
val elem = outPendingElem.asInstanceOf[T]
|
||||
if (idx == outPendingIdx) {
|
||||
push(o, elem)
|
||||
outPendingElem = null
|
||||
if (!isClosed(in)) {
|
||||
if (!hasBeenPulled(in)) {
|
||||
pull(in)
|
||||
}
|
||||
} else
|
||||
completeStage()
|
||||
}
|
||||
} else if (!hasBeenPulled(in))
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
downstreamRunning -= 1
|
||||
if (downstreamRunning == 0)
|
||||
completeStage()
|
||||
else if (outPendingElem != null) {
|
||||
if (idx == outPendingIdx) {
|
||||
outPendingElem = null
|
||||
if (!hasBeenPulled(in))
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
override def toString = s"Partition($outputPorts)"
|
||||
|
||||
}
|
||||
|
||||
object Balance {
|
||||
/**
|
||||
* Create a new `Balance` with the specified number of output ports.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue