+str #19176 Added Partition Graph Stage with tests

This commit is contained in:
israel 2016-01-14 18:08:18 +02:00
parent 4de3f63b93
commit c15f4d15f4
3 changed files with 313 additions and 1 deletions

View file

@ -0,0 +1,175 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.scaladsl
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink
import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings, ClosedShape }
import akka.stream.testkit.Utils._
import scala.concurrent.Await
import scala.concurrent.duration._
class GraphPartitionSpec extends AkkaSpec {
val settings = ActorMaterializerSettings(system)
.withInputBuffer(initialSize = 2, maxSize = 16)
implicit val materializer = ActorMaterializer(settings)
"A partition" must {
import GraphDSL.Implicits._
"partition to three subscribers" in assertAllStagesStopped {
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
val c3 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[Int](3, {
case g if (g > 3) 0
case l if (l < 3) 1
case e if (e == 3) 2
}))
Source(List(1, 2, 3, 4, 5)) ~> partition.in
partition.out(0) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Sink.fromSubscriber(c2)
partition.out(2) ~> Sink.fromSubscriber(c3)
ClosedShape
}).run()
c2.request(2)
c1.request(2)
c3.request(1)
c2.expectNext(1)
c2.expectNext(2)
c3.expectNext(3)
c1.expectNext(4)
c1.expectNext(5)
c1.expectComplete()
c2.expectComplete()
c3.expectComplete()
}
"complete stage after upstream completes" in assertAllStagesStopped {
val c1 = TestSubscriber.probe[String]()
val c2 = TestSubscriber.probe[String]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[String](2, {
case s if (s.length > 4) 0
case _ 1
}))
Source(List("this", "is", "just", "another", "test")) ~> partition.in
partition.out(0) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Sink.fromSubscriber(c2)
ClosedShape
}).run()
c1.request(1)
c2.request(4)
c1.expectNext("another")
c2.expectNext("this")
c2.expectNext("is")
c2.expectNext("just")
c2.expectNext("test")
c1.expectComplete()
c2.expectComplete()
}
"remember first pull even though first element targeted another out" in assertAllStagesStopped {
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[Int](2, { case l if l < 6 0; case _ 1 }))
Source(List(6, 3)) ~> partition.in
partition.out(0) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Sink.fromSubscriber(c2)
ClosedShape
}).run()
c1.request(1)
c1.expectNoMsg(1.seconds)
c2.request(1)
c2.expectNext(6)
c1.expectNext(3)
c1.expectComplete()
c2.expectComplete()
}
"cancel upstream when downstreams cancel" in assertAllStagesStopped {
val p1 = TestPublisher.probe[Int]()
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[Int](2, { case l if l < 6 0; case _ 1 }))
Source.fromPublisher(p1.getPublisher) ~> partition.in
partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2)
ClosedShape
}).run()
val p1Sub = p1.expectSubscription()
val sub1 = c1.expectSubscription()
val sub2 = c2.expectSubscription()
sub1.request(3)
sub2.request(3)
p1Sub.sendNext(1)
p1Sub.sendNext(8)
c1.expectNext(1)
c2.expectNext(8)
p1Sub.sendNext(2)
c1.expectNext(2)
sub1.cancel()
sub2.cancel()
p1Sub.expectCancellation()
}
"work with merge" in assertAllStagesStopped {
val s = Sink.seq[Int]
val input = Set(5, 2, 9, 1, 1, 1, 10)
val g = RunnableGraph.fromGraph(GraphDSL.create(s) { implicit b
sink
val partition = b.add(Partition[Int](2, { case l if l < 4 0; case _ 1 }))
val merge = b.add(Merge[Int](2))
Source(input) ~> partition.in
partition.out(0) ~> merge.in(0)
partition.out(1) ~> merge.in(1)
merge.out ~> sink.in
ClosedShape
})
val result = Await.result(g.run(), 300.millis)
result.toSet should be(input)
}
"stage completion is waiting for pending output" in assertAllStagesStopped {
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[Int](2, { case l if l < 6 0; case _ 1 }))
Source(List(6)) ~> partition.in
partition.out(0) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Sink.fromSubscriber(c2)
ClosedShape
}).run()
c1.request(1)
c1.expectNoMsg(1.second)
c2.request(1)
c2.expectNext(6)
c1.expectComplete()
c2.expectComplete()
}
}
}

View file

@ -4,7 +4,7 @@
package akka.stream.javadsl package akka.stream.javadsl
import akka.stream._ import akka.stream._
import akka.japi.Pair import akka.japi.{ Pair, function }
import scala.annotation.unchecked.uncheckedVariance import scala.annotation.unchecked.uncheckedVariance
import akka.stream.impl.ConstantFun import akka.stream.impl.ConstantFun
@ -135,6 +135,40 @@ object Broadcast {
} }
/**
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
* to the partitioner function applied to the element
*
* '''Emits when''' all of the outputs stops backpressuring and there is an input element available
*
* '''Backpressures when''' one of the outputs backpressure
*
* '''Completes when''' upstream completes
*
* '''Cancels when'''
* when one of the downstreams cancel
*/
object Partition {
/**
* Create a new `Partition` stage with the specified input type.
*
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def create[T](outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
scaladsl.Partition(outputCount, partitioner = (t: T) partitioner.apply(t))
/**
* Create a new `Partition` stage with the specified input type.
*
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Int]): Graph[UniformFanOutShape[T, T], Unit] =
create(outputCount, partitioner)
}
/** /**
* Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer. * Fan-out the stream to several streams. Each upstream element is emitted to the first available downstream consumer.
* It will not shutdown until the subscriptions for at least * It will not shutdown until the subscriptions for at least

View file

@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
} }
object Partition {
/**
* Create a new `Partition` stage with the specified input type.
*
* @param outputPorts number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def apply[T](outputPorts: Int, partitioner: T Int): Partition[T] = new Partition(outputPorts, partitioner)
}
/**
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
* to the partitioner function applied to the element
*
* '''Emits when''' emits when an element is available from the input and the chosen output has demand
*
* '''Backpressures when''' the currently chosen output back-pressures
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when'''
* when all downstreams cancel
*/
final class Partition[T](outputPorts: Int, partitioner: T Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in: Inlet[T] = Inlet[T]("Partition.in")
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i Outlet[T]("Partition.out" + i))
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
private var outPendingElem: Any = null
private var outPendingIdx: Int = _
private var downstreamRunning = outputPorts
setHandler(in, new InHandler {
override def onPush() = {
val elem = grab(in)
val idx = partitioner(elem)
if (idx < 0 || idx >= outputPorts)
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
else if (!isClosed(out(idx))) {
if (isAvailable(out(idx))) {
push(out(idx), elem)
if (out.exists(isAvailable(_)))
pull(in)
} else {
outPendingElem = elem
outPendingIdx = idx
}
} else if (out.exists(isAvailable(_)))
pull(in)
}
override def onUpstreamFinish(): Unit = {
if (outPendingElem == null)
completeStage()
}
})
out.zipWithIndex.foreach {
case (o, idx)
setHandler(o, new OutHandler {
override def onPull() = {
if (outPendingElem != null) {
val elem = outPendingElem.asInstanceOf[T]
if (idx == outPendingIdx) {
push(o, elem)
outPendingElem = null
if (!isClosed(in)) {
if (!hasBeenPulled(in)) {
pull(in)
}
} else
completeStage()
}
} else if (!hasBeenPulled(in))
pull(in)
}
override def onDownstreamFinish(): Unit = {
downstreamRunning -= 1
if (downstreamRunning == 0)
completeStage()
else if (outPendingElem != null) {
if (idx == outPendingIdx) {
outPendingElem = null
if (!hasBeenPulled(in))
pull(in)
}
}
}
})
}
}
override def toString = s"Partition($outputPorts)"
}
object Balance { object Balance {
/** /**
* Create a new `Balance` with the specified number of output ports. * Create a new `Balance` with the specified number of output ports.