+str #19176 Added Partition Graph Stage with tests

This commit is contained in:
israel 2016-01-14 18:08:18 +02:00
parent 4de3f63b93
commit c15f4d15f4
3 changed files with 313 additions and 1 deletions

View file

@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
}
object Partition {
/**
* Create a new `Partition` stage with the specified input type.
*
* @param outputPorts number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def apply[T](outputPorts: Int, partitioner: T Int): Partition[T] = new Partition(outputPorts, partitioner)
}
/**
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
* to the partitioner function applied to the element
*
* '''Emits when''' emits when an element is available from the input and the chosen output has demand
*
* '''Backpressures when''' the currently chosen output back-pressures
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when'''
* when all downstreams cancel
*/
final class Partition[T](outputPorts: Int, partitioner: T Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in: Inlet[T] = Inlet[T]("Partition.in")
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i Outlet[T]("Partition.out" + i))
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
private var outPendingElem: Any = null
private var outPendingIdx: Int = _
private var downstreamRunning = outputPorts
setHandler(in, new InHandler {
override def onPush() = {
val elem = grab(in)
val idx = partitioner(elem)
if (idx < 0 || idx >= outputPorts)
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
else if (!isClosed(out(idx))) {
if (isAvailable(out(idx))) {
push(out(idx), elem)
if (out.exists(isAvailable(_)))
pull(in)
} else {
outPendingElem = elem
outPendingIdx = idx
}
} else if (out.exists(isAvailable(_)))
pull(in)
}
override def onUpstreamFinish(): Unit = {
if (outPendingElem == null)
completeStage()
}
})
out.zipWithIndex.foreach {
case (o, idx)
setHandler(o, new OutHandler {
override def onPull() = {
if (outPendingElem != null) {
val elem = outPendingElem.asInstanceOf[T]
if (idx == outPendingIdx) {
push(o, elem)
outPendingElem = null
if (!isClosed(in)) {
if (!hasBeenPulled(in)) {
pull(in)
}
} else
completeStage()
}
} else if (!hasBeenPulled(in))
pull(in)
}
override def onDownstreamFinish(): Unit = {
downstreamRunning -= 1
if (downstreamRunning == 0)
completeStage()
else if (outPendingElem != null) {
if (idx == outPendingIdx) {
outPendingElem = null
if (!hasBeenPulled(in))
pull(in)
}
}
}
})
}
}
override def toString = s"Partition($outputPorts)"
}
object Balance {
/**
* Create a new `Balance` with the specified number of output ports.