Add GroupedWeighted FlowOp and Make Grouped use GroupedWeighted #29066

This commit is contained in:
Michael Marshall 2021-01-27 10:03:30 -07:00 committed by GitHub
parent dffd7099fd
commit 4d9b25579d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 351 additions and 30 deletions

View file

@ -762,34 +762,37 @@ private[stream] object Collect {
/**
* INTERNAL API
*/
@InternalApi private[akka] final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
require(n > 0, "n must be greater than 0")
@InternalApi private[akka] final case class GroupedWeighted[T](minWeight: Long, costFn: T => Long)
extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
require(minWeight > 0, "minWeight must be greater than 0")
val in = Inlet[T]("Grouped.in")
val out = Outlet[immutable.Seq[T]]("Grouped.out")
val in = Inlet[T]("GroupedWeighted.in")
val out = Outlet[immutable.Seq[T]]("GroupedWeighted.out")
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
override protected val initialAttributes: Attributes = DefaultAttributes.grouped
override def initialAttributes: Attributes = DefaultAttributes.groupedWeighted
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private val buf = {
val b = Vector.newBuilder[T]
b.sizeHint(n)
b
}
var left = n
private val buf = Vector.newBuilder[T]
var left: Long = minWeight
override def onPush(): Unit = {
buf += grab(in)
left -= 1
if (left == 0) {
val elements = buf.result()
buf.clear()
left = n
push(out, elements)
} else {
pull(in)
val elem = grab(in)
val cost = costFn(elem)
if (cost < 0L)
failStage(new IllegalArgumentException(s"Negative weight [$cost] for element [$elem] is not allowed"))
else {
buf += elem
left -= cost
if (left <= 0) {
val elements = buf.result()
buf.clear()
left = minWeight
push(out, elements)
} else {
pull(in)
}
}
}
@ -798,12 +801,11 @@ private[stream] object Collect {
}
override def onUpstreamFinish(): Unit = {
// This means the buf is filled with some elements but not enough (left < n) to group together.
// Since the upstream has finished we have to push them to downstream though.
if (left < n) {
val elements = buf.result()
// Since the upstream has finished we have to push any buffered elements downstream.
val elements = buf.result()
if (elements.nonEmpty) {
buf.clear()
left = n
left = minWeight
push(out, elements)
}
completeStage()