migrate Fold, Sliding, Grouped to GraphStage (#20914)
This commit is contained in:
parent
19f6c0c61c
commit
9683e4bc58
6 changed files with 154 additions and 100 deletions
|
|
@ -169,7 +169,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(Cancel, OnComplete, OnNext(3)))
|
||||
}
|
||||
|
||||
"implement fold" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) {
|
||||
"implement fold" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) {
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
downstream.requestOne()
|
||||
|
|
@ -188,7 +188,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(OnNext(3), OnComplete))
|
||||
}
|
||||
|
||||
"implement fold with proper cancel" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) {
|
||||
"implement fold with proper cancel" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) {
|
||||
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(Cancel))
|
||||
}
|
||||
|
||||
"work if fold completes while not in a push position" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) {
|
||||
"work if fold completes while not in a push position" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) {
|
||||
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(OnComplete, OnNext(0)))
|
||||
}
|
||||
|
||||
"implement grouped" in new OneBoundedSetup[Int](Seq(Grouped(3))) {
|
||||
"implement grouped" in new OneBoundedSetup[Int](Grouped(3)) {
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
downstream.requestOne()
|
||||
|
|
@ -490,9 +490,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(OnNext(1), OnComplete))
|
||||
}
|
||||
|
||||
"work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](Seq(
|
||||
new PushFinishStage,
|
||||
Fold(0, (x: Int, y: Int) ⇒ x + y, stoppingDecider))) {
|
||||
"work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](
|
||||
(new PushFinishStage).toGS,
|
||||
Fold(0, (x: Int, y: Int) ⇒ x + y)) {
|
||||
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
|
|
@ -618,6 +618,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
push(out, lastElem)
|
||||
}
|
||||
|
||||
// note that the default value of lastElem will be always pushed if the upstream closed at the very begining without a pulling
|
||||
override def onPull(): Unit = {
|
||||
if (isClosed(in)) {
|
||||
push(out, lastElem)
|
||||
|
|
|
|||
|
|
@ -100,10 +100,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(OnNext(114)))
|
||||
}
|
||||
|
||||
"resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq(
|
||||
Map((x: Int) ⇒ x + 1, resumingDecider),
|
||||
Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider),
|
||||
Grouped(3))) {
|
||||
"resume when Map throws before Grouped" in new OneBoundedSetup[Int](
|
||||
Map((x: Int) ⇒ x + 1, resumingDecider).toGS,
|
||||
Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider).toGS,
|
||||
Grouped(3)) {
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
|
@ -120,10 +120,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
lastEvents() should be(Set(OnNext(Vector(13, 14, 15))))
|
||||
}
|
||||
|
||||
"complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq(
|
||||
Map((x: Int) ⇒ x + 1, resumingDecider),
|
||||
Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider),
|
||||
Grouped(1000))) {
|
||||
"complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](
|
||||
Map((x: Int) ⇒ x + 1, resumingDecider).toGS,
|
||||
Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider).toGS,
|
||||
Grouped(1000)) {
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class LifecycleInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
|
||||
"postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[String](
|
||||
new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"),
|
||||
Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider).toGS) {
|
||||
Fold("", (x: String, y: String) ⇒ x + y)) {
|
||||
|
||||
lastEvents() should be(Set.empty)
|
||||
|
||||
|
|
|
|||
|
|
@ -160,22 +160,6 @@ object Stages {
|
|||
override def create(attr: Attributes): Stage[In, Out] = fusing.Map(f, supervision(attr))
|
||||
}
|
||||
|
||||
final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n)
|
||||
}
|
||||
|
||||
final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
require(step > 0, "step must be greater than 0")
|
||||
|
||||
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step)
|
||||
}
|
||||
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] {
|
||||
override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr))
|
||||
}
|
||||
|
||||
final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] {
|
||||
require(size > 0, s"Buffer size must be larger than zero but was [$size]")
|
||||
override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import akka.event.{ LogSource, Logging, LoggingAdapter }
|
|||
import akka.stream.Attributes.{ InputBuffer, LogLevels }
|
||||
import akka.stream.OverflowStrategies._
|
||||
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||
import akka.stream.impl.{ Buffer ⇒ BufferImpl, ReactiveStreamsCompliance }
|
||||
import akka.stream.scaladsl.Source
|
||||
import akka.stream.impl.{ Buffer ⇒ BufferImpl, Stages, ReactiveStreamsCompliance }
|
||||
import akka.stream.scaladsl.{ SourceQueue, Source }
|
||||
import akka.stream.stage._
|
||||
import akka.stream.{ Supervision, _ }
|
||||
import scala.annotation.tailrec
|
||||
|
|
@ -347,23 +347,48 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
|
||||
private[this] var aggregator: Out = zero
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphStage[FlowShape[In, Out]] {
|
||||
|
||||
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
|
||||
aggregator = f(aggregator, elem)
|
||||
ctx.pull()
|
||||
}
|
||||
val in = Inlet[In]("Fold.in")
|
||||
val out = Outlet[Out]("Fold.out")
|
||||
override val shape: FlowShape[In, Out] = FlowShape(in, out)
|
||||
|
||||
override def onPull(ctx: Context[Out]): SyncDirective =
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(aggregator)
|
||||
else ctx.pull()
|
||||
override val initialAttributes = DefaultAttributes.fold
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination()
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
private var aggregator: Out = zero
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
override def onResume(t: Throwable): Unit = {
|
||||
aggregator = zero
|
||||
}
|
||||
|
||||
override def restart(): Fold[In, Out] = copy()
|
||||
override def onPush(): Unit = withSupervision(() ⇒ grab(in)) match {
|
||||
case Some(elem) ⇒ {
|
||||
aggregator = f(aggregator, elem)
|
||||
pull(in)
|
||||
}
|
||||
case None ⇒ pull(in)
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
if (isClosed(in)) {
|
||||
push(out, aggregator)
|
||||
completeStage()
|
||||
} else {
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (isAvailable(out)) {
|
||||
push(out, aggregator)
|
||||
completeStage()
|
||||
}
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -415,36 +440,55 @@ final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) ext
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Grouped[T](n: Int) extends PushPullStage[T, immutable.Seq[T]] {
|
||||
private val buf = {
|
||||
val b = Vector.newBuilder[T]
|
||||
b.sizeHint(n)
|
||||
b
|
||||
}
|
||||
private var left = n
|
||||
final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
|
||||
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
|
||||
buf += elem
|
||||
left -= 1
|
||||
if (left == 0) {
|
||||
val emit = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
ctx.push(emit)
|
||||
} else ctx.pull()
|
||||
val in = Inlet[T]("Grouped.in")
|
||||
val out = Outlet[immutable.Seq[T]]("Grouped.out")
|
||||
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
|
||||
|
||||
override protected val initialAttributes: Attributes = DefaultAttributes.grouped
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private val buf = {
|
||||
val b = Vector.newBuilder[T]
|
||||
b.sizeHint(n)
|
||||
b
|
||||
}
|
||||
var left = n
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buf += grab(in)
|
||||
left -= 1
|
||||
if (left == 0) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
push(out, elements)
|
||||
} else {
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
// This means the buf is filled with some elements but not enough (left < n) to group together.
|
||||
// Since the upstream has finished we have to push them to downstream though.
|
||||
if (left < n) {
|
||||
val elements = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
push(out, elements)
|
||||
}
|
||||
completeStage()
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
|
||||
if (ctx.isFinishing) {
|
||||
val elem = buf.result()
|
||||
buf.clear()
|
||||
left = n
|
||||
ctx.pushAndFinish(elem)
|
||||
} else ctx.pull()
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
|
||||
if (left == n) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -482,34 +526,59 @@ final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStag
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
final case class Sliding[T](n: Int, step: Int) extends PushPullStage[T, immutable.Seq[T]] {
|
||||
private var buf = Vector.empty[T]
|
||||
final case class Sliding[T](n: Int, step: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
require(n > 0, "n must be greater than 0")
|
||||
require(step > 0, "step must be greater than 0")
|
||||
|
||||
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
|
||||
buf :+= elem
|
||||
if (buf.size < n) {
|
||||
ctx.pull()
|
||||
} else if (buf.size == n) {
|
||||
ctx.push(buf)
|
||||
} else if (step > n) {
|
||||
if (buf.size == step)
|
||||
buf = Vector.empty
|
||||
ctx.pull()
|
||||
} else {
|
||||
buf = buf.drop(step)
|
||||
if (buf.size == n) ctx.push(buf)
|
||||
else ctx.pull()
|
||||
val in = Inlet[T]("Sliding.in")
|
||||
val out = Outlet[immutable.Seq[T]]("Sliding.out")
|
||||
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
|
||||
|
||||
override protected val initialAttributes: Attributes = DefaultAttributes.sliding
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var buf = Vector.empty[T]
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buf :+= grab(in)
|
||||
if (buf.size < n) {
|
||||
pull(in)
|
||||
} else if (buf.size == n) {
|
||||
push(out, buf)
|
||||
} else if (step <= n) {
|
||||
buf = buf.drop(step)
|
||||
if (buf.size == n) {
|
||||
push(out, buf)
|
||||
} else pull(in)
|
||||
} else if (step > n) {
|
||||
if (buf.size == step) {
|
||||
buf = buf.drop(step)
|
||||
}
|
||||
pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
|
||||
// We can finish current stage directly if:
|
||||
// 1. the buf is empty or
|
||||
// 2. when the step size is greater than the sliding size (step > n) and current stage is in between
|
||||
// two sliding (ie. buf.size >= n && buf.size < step).
|
||||
//
|
||||
// Otherwise it means there is still a not finished sliding so we have to push them before finish current stage.
|
||||
if (buf.size < n && buf.size > 0) {
|
||||
push(out, buf)
|
||||
}
|
||||
completeStage()
|
||||
}
|
||||
|
||||
this.setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
|
||||
if (!ctx.isFinishing) ctx.pull()
|
||||
else if (buf.size >= n) ctx.finish()
|
||||
else ctx.pushAndFinish(buf)
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
|
||||
if (buf.isEmpty) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -672,7 +672,7 @@ trait FlowOps[+Out, +Mat] {
|
|||
*
|
||||
* '''Cancels when''' downstream cancels
|
||||
*/
|
||||
def grouped(n: Int): Repr[immutable.Seq[Out]] = andThen(Grouped(n))
|
||||
def grouped(n: Int): Repr[immutable.Seq[Out]] = via(Grouped(n))
|
||||
|
||||
/**
|
||||
* Ensure stream boundedness by limiting the number of elements from upstream.
|
||||
|
|
@ -736,7 +736,7 @@ trait FlowOps[+Out, +Mat] {
|
|||
*
|
||||
* '''Cancels when''' downstream cancels
|
||||
*/
|
||||
def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = andThen(Sliding(n, step))
|
||||
def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = via(Sliding(n, step))
|
||||
|
||||
/**
|
||||
* Similar to `fold` but is not a terminal operation,
|
||||
|
|
@ -777,7 +777,7 @@ trait FlowOps[+Out, +Mat] {
|
|||
*
|
||||
* See also [[FlowOps.scan]]
|
||||
*/
|
||||
def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = andThen(Fold(zero, f))
|
||||
def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = via(Fold(zero, f))
|
||||
|
||||
/**
|
||||
* Similar to `fold` but uses first element as zero element.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue