+str #19176 Added Partition Graph Stage with tests
This commit is contained in:
parent
4de3f63b93
commit
c15f4d15f4
3 changed files with 313 additions and 1 deletions
|
|
@ -464,6 +464,109 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
|
|||
|
||||
}
|
||||
|
||||
object Partition {
|
||||
|
||||
/**
|
||||
* Create a new `Partition` stage with the specified input type.
|
||||
*
|
||||
* @param outputPorts number of output ports
|
||||
* @param partitioner function deciding which output each element will be targeted
|
||||
*/
|
||||
def apply[T](outputPorts: Int, partitioner: T ⇒ Int): Partition[T] = new Partition(outputPorts, partitioner)
|
||||
}
|
||||
|
||||
/**
|
||||
* Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according
|
||||
* to the partitioner function applied to the element
|
||||
*
|
||||
* '''Emits when''' emits when an element is available from the input and the chosen output has demand
|
||||
*
|
||||
* '''Backpressures when''' the currently chosen output back-pressures
|
||||
*
|
||||
* '''Completes when''' upstream completes and no output is pending
|
||||
*
|
||||
* '''Cancels when'''
|
||||
* when all downstreams cancel
|
||||
*/
|
||||
|
||||
final class Partition[T](outputPorts: Int, partitioner: T ⇒ Int) extends GraphStage[UniformFanOutShape[T, T]] {
|
||||
|
||||
val in: Inlet[T] = Inlet[T]("Partition.in")
|
||||
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i ⇒ Outlet[T]("Partition.out" + i))
|
||||
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
|
||||
private var outPendingElem: Any = null
|
||||
private var outPendingIdx: Int = _
|
||||
private var downstreamRunning = outputPorts
|
||||
|
||||
setHandler(in, new InHandler {
|
||||
override def onPush() = {
|
||||
val elem = grab(in)
|
||||
val idx = partitioner(elem)
|
||||
if (idx < 0 || idx >= outputPorts)
|
||||
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
|
||||
else if (!isClosed(out(idx))) {
|
||||
if (isAvailable(out(idx))) {
|
||||
push(out(idx), elem)
|
||||
if (out.exists(isAvailable(_)))
|
||||
pull(in)
|
||||
} else {
|
||||
outPendingElem = elem
|
||||
outPendingIdx = idx
|
||||
}
|
||||
|
||||
} else if (out.exists(isAvailable(_)))
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (outPendingElem == null)
|
||||
completeStage()
|
||||
}
|
||||
})
|
||||
|
||||
out.zipWithIndex.foreach {
|
||||
case (o, idx) ⇒
|
||||
setHandler(o, new OutHandler {
|
||||
override def onPull() = {
|
||||
|
||||
if (outPendingElem != null) {
|
||||
val elem = outPendingElem.asInstanceOf[T]
|
||||
if (idx == outPendingIdx) {
|
||||
push(o, elem)
|
||||
outPendingElem = null
|
||||
if (!isClosed(in)) {
|
||||
if (!hasBeenPulled(in)) {
|
||||
pull(in)
|
||||
}
|
||||
} else
|
||||
completeStage()
|
||||
}
|
||||
} else if (!hasBeenPulled(in))
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
downstreamRunning -= 1
|
||||
if (downstreamRunning == 0)
|
||||
completeStage()
|
||||
else if (outPendingElem != null) {
|
||||
if (idx == outPendingIdx) {
|
||||
outPendingElem = null
|
||||
if (!hasBeenPulled(in))
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
override def toString = s"Partition($outputPorts)"
|
||||
|
||||
}
|
||||
|
||||
object Balance {
|
||||
/**
|
||||
* Create a new `Balance` with the specified number of output ports.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue