+str #22711 adding GroupedWeightedWithin
This commit is contained in:
parent
87b28d0dc5
commit
edee4ba409
11 changed files with 332 additions and 48 deletions
|
|
@ -1399,52 +1399,90 @@ private[stream] object Collect {
|
|||
|
||||
}
|
||||
|
||||
@InternalApi private[akka] object GroupedWeightedWithin {
|
||||
val groupedWeightedWithinTimer = "GroupedWeightedWithinTimer"
|
||||
}
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
@InternalApi private[akka] final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
require(d > Duration.Zero)
|
||||
@InternalApi private[akka] final class GroupedWeightedWithin[T](val maxWeight: Long, costFn: T ⇒ Long, val interval: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(maxWeight > 0, "maxWeight must be greater than 0")
|
||||
require(interval > Duration.Zero)
|
||||
|
||||
val in = Inlet[T]("in")
|
||||
val out = Outlet[immutable.Seq[T]]("out")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.groupedWithin
|
||||
override def initialAttributes = DefaultAttributes.groupedWeightedWithin
|
||||
|
||||
val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
|
||||
|
||||
private val buf: VectorBuilder[T] = new VectorBuilder
|
||||
private var pending: T = null.asInstanceOf[T]
|
||||
private var pendingWeight: Long = 0L
|
||||
// True if:
|
||||
// - buf is nonEmpty
|
||||
// AND
|
||||
// - timer fired OR group is full
|
||||
private var groupClosed = false
|
||||
// - (timer fired
|
||||
// OR
|
||||
// totalWeight >= maxWeight
|
||||
// OR
|
||||
// pending != null
|
||||
// OR
|
||||
// upstream completed)
|
||||
private var pushEagerly = false
|
||||
private var groupEmitted = true
|
||||
private var finished = false
|
||||
private var elements = 0
|
||||
|
||||
private val GroupedWithinTimer = "GroupedWithinTimer"
|
||||
private var totalWeight = 0L
|
||||
|
||||
override def preStart() = {
|
||||
schedulePeriodically(GroupedWithinTimer, d)
|
||||
schedulePeriodically(GroupedWeightedWithin.groupedWeightedWithinTimer, interval)
|
||||
pull(in)
|
||||
}
|
||||
|
||||
private def nextElement(elem: T): Unit = {
|
||||
groupEmitted = false
|
||||
buf += elem
|
||||
elements += 1
|
||||
if (elements == n) {
|
||||
schedulePeriodically(GroupedWithinTimer, d)
|
||||
closeGroup()
|
||||
} else pull(in)
|
||||
val cost = costFn(elem)
|
||||
if (cost < 0) failStage(new IllegalArgumentException(s"Negative weight [$cost] for element [$elem] is not allowed"))
|
||||
else {
|
||||
if (totalWeight + cost <= maxWeight) {
|
||||
buf += elem
|
||||
totalWeight += cost
|
||||
|
||||
if (totalWeight < maxWeight) pull(in)
|
||||
else {
|
||||
// `totalWeight >= maxWeight` which means that downstream can get the next group.
|
||||
if (!isAvailable(out)) {
|
||||
// We should emit group when downstream becomes available
|
||||
pushEagerly = true
|
||||
// we want to pull anyway, since we allow for zero weight elements
|
||||
// but since `emitGroup()` will pull internally (by calling `startNewGroup()`)
|
||||
// we also have to pull if downstream hasn't yet requested an element.
|
||||
pull(in)
|
||||
} else {
|
||||
schedulePeriodically(GroupedWeightedWithin.groupedWeightedWithinTimer, interval)
|
||||
emitGroup()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (totalWeight == 0L) {
|
||||
buf += elem
|
||||
totalWeight += cost
|
||||
pushEagerly = true
|
||||
} else {
|
||||
pending = elem
|
||||
pendingWeight = cost
|
||||
}
|
||||
schedulePeriodically(GroupedWeightedWithin.groupedWeightedWithinTimer, interval)
|
||||
tryCloseGroup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def closeGroup(): Unit = {
|
||||
groupClosed = true
|
||||
private def tryCloseGroup(): Unit = {
|
||||
if (isAvailable(out)) emitGroup()
|
||||
else if (pending != null || finished) pushEagerly = true
|
||||
}
|
||||
|
||||
private def emitGroup(): Unit = {
|
||||
|
|
@ -1452,30 +1490,41 @@ private[stream] object Collect {
|
|||
push(out, buf.result())
|
||||
buf.clear()
|
||||
if (!finished) startNewGroup()
|
||||
else if (pending != null) emit(out, Vector(pending), () ⇒ completeStage())
|
||||
else completeStage()
|
||||
}
|
||||
|
||||
private def startNewGroup(): Unit = {
|
||||
elements = 0
|
||||
groupClosed = false
|
||||
if (pending != null) {
|
||||
totalWeight = pendingWeight
|
||||
pendingWeight = 0L
|
||||
buf += pending
|
||||
pending = null.asInstanceOf[T]
|
||||
groupEmitted = false
|
||||
} else {
|
||||
totalWeight = 0
|
||||
}
|
||||
pushEagerly = false
|
||||
if (isAvailable(in)) nextElement(grab(in))
|
||||
else if (!hasBeenPulled(in)) pull(in)
|
||||
}
|
||||
|
||||
override def onPush(): Unit = {
|
||||
if (!groupClosed) nextElement(grab(in)) // otherwise keep the element for next round
|
||||
if (pending == null) nextElement(grab(in)) // otherwise keep the element for next round
|
||||
}
|
||||
|
||||
override def onPull(): Unit = if (groupClosed) emitGroup()
|
||||
override def onPull(): Unit = if (pushEagerly) emitGroup()
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
finished = true
|
||||
if (groupEmitted) completeStage()
|
||||
else closeGroup()
|
||||
else tryCloseGroup()
|
||||
}
|
||||
|
||||
override protected def onTimer(timerKey: Any) = if (elements > 0) closeGroup()
|
||||
|
||||
override protected def onTimer(timerKey: Any) = if (totalWeight > 0) {
|
||||
if (isAvailable(out)) emitGroup()
|
||||
else pushEagerly = true
|
||||
}
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue