Add GroupedWeighted FlowOp and Make Grouped use GroupedWeighted #29066
This commit is contained in:
parent
dffd7099fd
commit
4d9b25579d
18 changed files with 351 additions and 30 deletions
|
|
@ -762,34 +762,37 @@ private[stream] object Collect {
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
@InternalApi private[akka] final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
@InternalApi private[akka] final case class GroupedWeighted[T](minWeight: Long, costFn: T => Long)
|
||||
extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(minWeight > 0, "minWeight must be greater than 0")
|
||||
|
||||
val in = Inlet[T]("Grouped.in")
|
||||
val out = Outlet[immutable.Seq[T]]("Grouped.out")
|
||||
val in = Inlet[T]("GroupedWeighted.in")
|
||||
val out = Outlet[immutable.Seq[T]]("GroupedWeighted.out")
|
||||
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
|
||||
|
||||
override protected val initialAttributes: Attributes = DefaultAttributes.grouped
|
||||
override def initialAttributes: Attributes = DefaultAttributes.groupedWeighted
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private val buf = {
|
||||
val b = Vector.newBuilder[T]
|
||||
b.sizeHint(n)
|
||||
b
|
||||
}
|
||||
var left = n
|
||||
private val buf = Vector.newBuilder[T]
|
||||
var left: Long = minWeight
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buf += grab(in)
|
||||
left -= 1
|
||||
if (left == 0) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
push(out, elements)
|
||||
} else {
|
||||
pull(in)
|
||||
val elem = grab(in)
|
||||
val cost = costFn(elem)
|
||||
if (cost < 0L)
|
||||
failStage(new IllegalArgumentException(s"Negative weight [$cost] for element [$elem] is not allowed"))
|
||||
else {
|
||||
buf += elem
|
||||
left -= cost
|
||||
if (left <= 0) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = minWeight
|
||||
push(out, elements)
|
||||
} else {
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -798,12 +801,11 @@ private[stream] object Collect {
|
|||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
// This means the buf is filled with some elements but not enough (left < n) to group together.
|
||||
// Since the upstream has finished we have to push them to downstream though.
|
||||
if (left < n) {
|
||||
val elements = buf.result()
|
||||
// Since the upstream has finished we have to push any buffered elements downstream.
|
||||
val elements = buf.result()
|
||||
if (elements.nonEmpty) {
|
||||
buf.clear()
|
||||
left = n
|
||||
left = minWeight
|
||||
push(out, elements)
|
||||
}
|
||||
completeStage()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue