migrate Fold, Sliding, Grouped to GraphStage (#20914)
This commit is contained in:
parent
19f6c0c61c
commit
9683e4bc58
6 changed files with 154 additions and 100 deletions
|
|
@ -8,8 +8,8 @@ import akka.event.{ LogSource, Logging, LoggingAdapter }
|
|||
import akka.stream.Attributes.{ InputBuffer, LogLevels }
|
||||
import akka.stream.OverflowStrategies._
|
||||
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||
import akka.stream.impl.{ Buffer ⇒ BufferImpl, ReactiveStreamsCompliance }
|
||||
import akka.stream.scaladsl.Source
|
||||
import akka.stream.impl.{ Buffer ⇒ BufferImpl, Stages, ReactiveStreamsCompliance }
|
||||
import akka.stream.scaladsl.{ SourceQueue, Source }
|
||||
import akka.stream.stage._
|
||||
import akka.stream.{ Supervision, _ }
|
||||
import scala.annotation.tailrec
|
||||
|
|
@ -347,23 +347,48 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
|
||||
private[this] var aggregator: Out = zero
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphStage[FlowShape[In, Out]] {
|
||||
|
||||
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
|
||||
aggregator = f(aggregator, elem)
|
||||
ctx.pull()
|
||||
}
|
||||
val in = Inlet[In]("Fold.in")
|
||||
val out = Outlet[Out]("Fold.out")
|
||||
override val shape: FlowShape[In, Out] = FlowShape(in, out)
|
||||
|
||||
override def onPull(ctx: Context[Out]): SyncDirective =
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(aggregator)
|
||||
else ctx.pull()
|
||||
override val initialAttributes = DefaultAttributes.fold
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination()
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
private var aggregator: Out = zero
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
override def onResume(t: Throwable): Unit = {
|
||||
aggregator = zero
|
||||
}
|
||||
|
||||
override def restart(): Fold[In, Out] = copy()
|
||||
override def onPush(): Unit = withSupervision(() ⇒ grab(in)) match {
|
||||
case Some(elem) ⇒ {
|
||||
aggregator = f(aggregator, elem)
|
||||
pull(in)
|
||||
}
|
||||
case None ⇒ pull(in)
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
if (isClosed(in)) {
|
||||
push(out, aggregator)
|
||||
completeStage()
|
||||
} else {
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (isAvailable(out)) {
|
||||
push(out, aggregator)
|
||||
completeStage()
|
||||
}
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -415,36 +440,55 @@ final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) ext
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Grouped[T](n: Int) extends PushPullStage[T, immutable.Seq[T]] {
|
||||
private val buf = {
|
||||
val b = Vector.newBuilder[T]
|
||||
b.sizeHint(n)
|
||||
b
|
||||
}
|
||||
private var left = n
|
||||
final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
|
||||
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
|
||||
buf += elem
|
||||
left -= 1
|
||||
if (left == 0) {
|
||||
val emit = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
ctx.push(emit)
|
||||
} else ctx.pull()
|
||||
val in = Inlet[T]("Grouped.in")
|
||||
val out = Outlet[immutable.Seq[T]]("Grouped.out")
|
||||
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
|
||||
|
||||
override protected val initialAttributes: Attributes = DefaultAttributes.grouped
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private val buf = {
|
||||
val b = Vector.newBuilder[T]
|
||||
b.sizeHint(n)
|
||||
b
|
||||
}
|
||||
var left = n
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buf += grab(in)
|
||||
left -= 1
|
||||
if (left == 0) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
push(out, elements)
|
||||
} else {
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
// This means the buf is filled with some elements but not enough (left < n) to group together.
|
||||
// Since the upstream has finished we have to push them to downstream though.
|
||||
if (left < n) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
push(out, elements)
|
||||
}
|
||||
completeStage()
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
|
||||
if (ctx.isFinishing) {
|
||||
val elem = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
ctx.pushAndFinish(elem)
|
||||
} else ctx.pull()
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
|
||||
if (left == n) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -482,34 +526,59 @@ final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStag
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Sliding[T](n: Int, step: Int) extends PushPullStage[T, immutable.Seq[T]] {
|
||||
private var buf = Vector.empty[T]
|
||||
final case class Sliding[T](n: Int, step: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
require(step > 0, "step must be greater than 0")
|
||||
|
||||
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
|
||||
buf :+= elem
|
||||
if (buf.size < n) {
|
||||
ctx.pull()
|
||||
} else if (buf.size == n) {
|
||||
ctx.push(buf)
|
||||
} else if (step > n) {
|
||||
if (buf.size == step)
|
||||
buf = Vector.empty
|
||||
ctx.pull()
|
||||
} else {
|
||||
buf = buf.drop(step)
|
||||
if (buf.size == n) ctx.push(buf)
|
||||
else ctx.pull()
|
||||
val in = Inlet[T]("Sliding.in")
|
||||
val out = Outlet[immutable.Seq[T]]("Sliding.out")
|
||||
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
|
||||
|
||||
override protected val initialAttributes: Attributes = DefaultAttributes.sliding
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var buf = Vector.empty[T]
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buf :+= grab(in)
|
||||
if (buf.size < n) {
|
||||
pull(in)
|
||||
} else if (buf.size == n) {
|
||||
push(out, buf)
|
||||
} else if (step <= n) {
|
||||
buf = buf.drop(step)
|
||||
if (buf.size == n) {
|
||||
push(out, buf)
|
||||
} else pull(in)
|
||||
} else if (step > n) {
|
||||
if (buf.size == step) {
|
||||
buf = buf.drop(step)
|
||||
}
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
|
||||
// We can finish current stage directly if:
|
||||
// 1. the buf is empty or
|
||||
// 2. when the step size is greater than the sliding size (step > n) and current stage is in between
|
||||
// two sliding (ie. buf.size >= n && buf.size < step).
|
||||
//
|
||||
// Otherwise it means there is still a not finished sliding so we have to push them before finish current stage.
|
||||
if (buf.size < n && buf.size > 0) {
|
||||
push(out, buf)
|
||||
}
|
||||
completeStage()
|
||||
}
|
||||
|
||||
this.setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
|
||||
if (!ctx.isFinishing) ctx.pull()
|
||||
else if (buf.size >= n) ctx.finish()
|
||||
else ctx.pushAndFinish(buf)
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
|
||||
if (buf.isEmpty) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue