migrate Fold, Sliding, Grouped to GraphStage (#20914)

This commit is contained in:
zhxiaog 2016-07-08 14:22:18 +02:00 committed by Konrad Malawski
parent 19f6c0c61c
commit 9683e4bc58
6 changed files with 154 additions and 100 deletions

View file

@ -169,7 +169,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(Cancel, OnComplete, OnNext(3)))
}
"implement fold" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) agg + x, stoppingDecider))) {
"implement fold" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) agg + x)) {
lastEvents() should be(Set.empty)
downstream.requestOne()
@ -188,7 +188,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(3), OnComplete))
}
"implement fold with proper cancel" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) agg + x, stoppingDecider))) {
"implement fold with proper cancel" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) agg + x)) {
lastEvents() should be(Set.empty)
@ -208,7 +208,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(Cancel))
}
"work if fold completes while not in a push position" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) agg + x, stoppingDecider))) {
"work if fold completes while not in a push position" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) agg + x)) {
lastEvents() should be(Set.empty)
@ -219,7 +219,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnComplete, OnNext(0)))
}
"implement grouped" in new OneBoundedSetup[Int](Seq(Grouped(3))) {
"implement grouped" in new OneBoundedSetup[Int](Grouped(3)) {
lastEvents() should be(Set.empty)
downstream.requestOne()
@ -490,9 +490,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(1), OnComplete))
}
"work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](Seq(
new PushFinishStage,
Fold(0, (x: Int, y: Int) x + y, stoppingDecider))) {
"work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](
(new PushFinishStage).toGS,
Fold(0, (x: Int, y: Int) x + y)) {
lastEvents() should be(Set.empty)
@ -618,6 +618,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
push(out, lastElem)
}
// note that the default value of lastElem will be always pushed if the upstream closed at the very begining without a pulling
override def onPull(): Unit = {
if (isClosed(in)) {
push(out, lastElem)

View file

@ -100,10 +100,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(114)))
}
"resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq(
Map((x: Int) x + 1, resumingDecider),
Map((x: Int) if (x <= 0) throw TE else x + 10, resumingDecider),
Grouped(3))) {
"resume when Map throws before Grouped" in new OneBoundedSetup[Int](
Map((x: Int) x + 1, resumingDecider).toGS,
Map((x: Int) if (x <= 0) throw TE else x + 10, resumingDecider).toGS,
Grouped(3)) {
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
@ -120,10 +120,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(Vector(13, 14, 15))))
}
"complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq(
Map((x: Int) x + 1, resumingDecider),
Map((x: Int) if (x <= 0) throw TE else x + 10, resumingDecider),
Grouped(1000))) {
"complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](
Map((x: Int) x + 1, resumingDecider).toGS,
Map((x: Int) if (x <= 0) throw TE else x + 10, resumingDecider).toGS,
Grouped(1000)) {
downstream.requestOne()
lastEvents() should be(Set(RequestOne))

View file

@ -128,7 +128,7 @@ class LifecycleInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
"postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[String](
new PushFinishStage(onPostStop = () testActor ! "stop"),
Fold("", (x: String, y: String) x + y, stoppingDecider).toGS) {
Fold("", (x: String, y: String) x + y)) {
lastEvents() should be(Set.empty)

View file

@ -160,22 +160,6 @@ object Stages {
override def create(attr: Attributes): Stage[In, Out] = fusing.Map(f, supervision(attr))
}
final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] {
require(n > 0, "n must be greater than 0")
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n)
}
final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] {
require(n > 0, "n must be greater than 0")
require(step > 0, "step must be greater than 0")
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step)
}
final case class Fold[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr))
}
final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] {
require(size > 0, s"Buffer size must be larger than zero but was [$size]")
override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy)

View file

@ -8,8 +8,8 @@ import akka.event.{ LogSource, Logging, LoggingAdapter }
import akka.stream.Attributes.{ InputBuffer, LogLevels }
import akka.stream.OverflowStrategies._
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.impl.{ Buffer BufferImpl, ReactiveStreamsCompliance }
import akka.stream.scaladsl.Source
import akka.stream.impl.{ Buffer BufferImpl, Stages, ReactiveStreamsCompliance }
import akka.stream.scaladsl.{ SourceQueue, Source }
import akka.stream.stage._
import akka.stream.{ Supervision, _ }
import scala.annotation.tailrec
@ -347,23 +347,48 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
/**
* INTERNAL API
*/
final case class Fold[In, Out](zero: Out, f: (Out, In) Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
private[this] var aggregator: Out = zero
final case class Fold[In, Out](zero: Out, f: (Out, In) Out) extends GraphStage[FlowShape[In, Out]] {
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
aggregator = f(aggregator, elem)
ctx.pull()
}
val in = Inlet[In]("Fold.in")
val out = Outlet[Out]("Fold.out")
override val shape: FlowShape[In, Out] = FlowShape(in, out)
override def onPull(ctx: Context[Out]): SyncDirective =
if (ctx.isFinishing) ctx.pushAndFinish(aggregator)
else ctx.pull()
override val initialAttributes = DefaultAttributes.fold
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination()
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
private var aggregator: Out = zero
override def decide(t: Throwable): Supervision.Directive = decider(t)
override def onResume(t: Throwable): Unit = {
aggregator = zero
}
override def restart(): Fold[In, Out] = copy()
override def onPush(): Unit = withSupervision(() grab(in)) match {
case Some(elem) {
aggregator = f(aggregator, elem)
pull(in)
}
case None pull(in)
}
override def onPull(): Unit = {
if (isClosed(in)) {
push(out, aggregator)
completeStage()
} else {
pull(in)
}
}
override def onUpstreamFinish(): Unit = {
if (isAvailable(out)) {
push(out, aggregator)
completeStage()
}
}
setHandlers(in, out, this)
}
}
/**
@ -415,36 +440,55 @@ final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) ext
/**
* INTERNAL API
*/
final case class Grouped[T](n: Int) extends PushPullStage[T, immutable.Seq[T]] {
private val buf = {
val b = Vector.newBuilder[T]
b.sizeHint(n)
b
}
private var left = n
final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
require(n > 0, "n must be greater than 0")
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
buf += elem
left -= 1
if (left == 0) {
val emit = buf.result()
buf.clear()
left = n
ctx.push(emit)
} else ctx.pull()
val in = Inlet[T]("Grouped.in")
val out = Outlet[immutable.Seq[T]]("Grouped.out")
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
override protected val initialAttributes: Attributes = DefaultAttributes.grouped
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
private val buf = {
val b = Vector.newBuilder[T]
b.sizeHint(n)
b
}
var left = n
override def onPush(): Unit = {
buf += grab(in)
left -= 1
if (left == 0) {
val elements = buf.result()
buf.clear()
left = n
push(out, elements)
} else {
pull(in)
}
}
override def onPull(): Unit = {
pull(in)
}
override def onUpstreamFinish(): Unit = {
// This means the buf is filled with some elements but not enough (left < n) to group together.
// Since the upstream has finished we have to push them to downstream though.
if (left < n) {
val elements = buf.result()
buf.clear()
left = n
push(out, elements)
}
completeStage()
}
setHandlers(in, out, this)
}
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
if (ctx.isFinishing) {
val elem = buf.result()
buf.clear()
left = n
ctx.pushAndFinish(elem)
} else ctx.pull()
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
if (left == n) ctx.finish()
else ctx.absorbTermination()
}
/**
@ -482,34 +526,59 @@ final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStag
/**
* INTERNAL API
*/
final case class Sliding[T](n: Int, step: Int) extends PushPullStage[T, immutable.Seq[T]] {
private var buf = Vector.empty[T]
final case class Sliding[T](n: Int, step: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
require(n > 0, "n must be greater than 0")
require(step > 0, "step must be greater than 0")
override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = {
buf :+= elem
if (buf.size < n) {
ctx.pull()
} else if (buf.size == n) {
ctx.push(buf)
} else if (step > n) {
if (buf.size == step)
buf = Vector.empty
ctx.pull()
} else {
buf = buf.drop(step)
if (buf.size == n) ctx.push(buf)
else ctx.pull()
val in = Inlet[T]("Sliding.in")
val out = Outlet[immutable.Seq[T]]("Sliding.out")
override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out)
override protected val initialAttributes: Attributes = DefaultAttributes.sliding
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
private var buf = Vector.empty[T]
override def onPush(): Unit = {
buf :+= grab(in)
if (buf.size < n) {
pull(in)
} else if (buf.size == n) {
push(out, buf)
} else if (step <= n) {
buf = buf.drop(step)
if (buf.size == n) {
push(out, buf)
} else pull(in)
} else if (step > n) {
if (buf.size == step) {
buf = buf.drop(step)
}
pull(in)
}
}
override def onPull(): Unit = {
pull(in)
}
override def onUpstreamFinish(): Unit = {
// We can finish current stage directly if:
// 1. the buf is empty or
// 2. when the step size is greater than the sliding size (step > n) and current stage is in between
// two sliding (ie. buf.size >= n && buf.size < step).
//
// Otherwise it means there is still a not finished sliding so we have to push them before finish current stage.
if (buf.size < n && buf.size > 0) {
push(out, buf)
}
completeStage()
}
this.setHandlers(in, out, this)
}
override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective =
if (!ctx.isFinishing) ctx.pull()
else if (buf.size >= n) ctx.finish()
else ctx.pushAndFinish(buf)
override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective =
if (buf.isEmpty) ctx.finish()
else ctx.absorbTermination()
}
/**

View file

@ -672,7 +672,7 @@ trait FlowOps[+Out, +Mat] {
*
* '''Cancels when''' downstream cancels
*/
def grouped(n: Int): Repr[immutable.Seq[Out]] = andThen(Grouped(n))
def grouped(n: Int): Repr[immutable.Seq[Out]] = via(Grouped(n))
/**
* Ensure stream boundedness by limiting the number of elements from upstream.
@ -736,7 +736,7 @@ trait FlowOps[+Out, +Mat] {
*
* '''Cancels when''' downstream cancels
*/
def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = andThen(Sliding(n, step))
def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = via(Sliding(n, step))
/**
* Similar to `fold` but is not a terminal operation,
@ -777,7 +777,7 @@ trait FlowOps[+Out, +Mat] {
*
* See also [[FlowOps.scan]]
*/
def fold[T](zero: T)(f: (T, Out) T): Repr[T] = andThen(Fold(zero, f))
def fold[T](zero: T)(f: (T, Out) T): Repr[T] = via(Fold(zero, f))
/**
* Similar to `fold` but uses first element as zero element.