From 9683e4bc58e92d378a983998678126b85981064f Mon Sep 17 00:00:00 2001 From: zhxiaog Date: Fri, 8 Jul 2016 14:22:18 +0200 Subject: [PATCH] migrate Fold, Sliding, Grouped to GraphStage (#20914) --- .../stream/impl/fusing/InterpreterSpec.scala | 15 +- .../fusing/InterpreterSupervisionSpec.scala | 16 +- .../fusing/LifecycleInterpreterSpec.scala | 2 +- .../main/scala/akka/stream/impl/Stages.scala | 16 -- .../scala/akka/stream/impl/fusing/Ops.scala | 199 ++++++++++++------ .../scala/akka/stream/scaladsl/Flow.scala | 6 +- 6 files changed, 154 insertions(+), 100 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index 8855f0f43b..d4f8afec9e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -169,7 +169,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(Cancel, OnComplete, OnNext(3))) } - "implement fold" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -188,7 +188,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(3), OnComplete)) } - "implement fold with proper cancel" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold with proper cancel" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set.empty) @@ -208,7 +208,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "work if fold completes while not in a push position" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "work if fold completes while not in a push position" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set.empty) @@ -219,7 +219,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(0))) } - "implement grouped" in new OneBoundedSetup[Int](Seq(Grouped(3))) { + "implement grouped" in new OneBoundedSetup[Int](Grouped(3)) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -490,9 +490,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(1), OnComplete)) } - "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](Seq( - new PushFinishStage, - Fold(0, (x: Int, y: Int) ⇒ x + y, stoppingDecider))) { + "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int]( + (new PushFinishStage).toGS, + Fold(0, (x: Int, y: Int) ⇒ x + y)) { lastEvents() should be(Set.empty) @@ -618,6 +618,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { push(out, lastElem) } + // note that the default value of lastElem will be always pushed if the upstream closed at the very begining without a pulling override def onPull(): Unit = { if (isClosed(in)) { push(out, lastElem) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index ec86f8e6bc..c4233ca47d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -100,10 +100,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(114))) } - "resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( - Map((x: Int) ⇒ x + 1, resumingDecider), - Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), - Grouped(3))) { + "resume when Map throws before Grouped" in new OneBoundedSetup[Int]( + Map((x: Int) ⇒ x + 1, resumingDecider).toGS, + Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider).toGS, + Grouped(3)) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -120,10 +120,10 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(13, 14, 15)))) } - "complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( - Map((x: Int) ⇒ x + 1, resumingDecider), - Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), - Grouped(1000))) { + "complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int]( + Map((x: Int) ⇒ x + 1, resumingDecider).toGS, + Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider).toGS, + Grouped(1000)) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala index 8fed034313..be7225a3da 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala @@ -128,7 +128,7 @@ class LifecycleInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { "postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[String]( new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"), - Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider).toGS) { + Fold("", (x: String, y: String) ⇒ x + y)) { lastEvents() should be(Set.empty) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index a3d9ed72da..c5c8371cd6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -160,22 +160,6 @@ object Stages { override def create(attr: Attributes): Stage[In, Out] = fusing.Map(f, supervision(attr)) } - final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] { - require(n > 0, "n must be greater than 0") - override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) - } - - final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] { - require(n > 0, "n must be greater than 0") - require(step > 0, "step must be greater than 0") - - override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) - } - - final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { - override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) - } - final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] { require(size > 0, s"Buffer size must be larger than zero but was [$size]") override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 413915914d..6fcbbb7c00 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -8,8 +8,8 @@ import akka.event.{ LogSource, Logging, LoggingAdapter } import akka.stream.Attributes.{ InputBuffer, LogLevels } import akka.stream.OverflowStrategies._ import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage -import akka.stream.impl.{ Buffer ⇒ BufferImpl, ReactiveStreamsCompliance } -import akka.stream.scaladsl.Source +import akka.stream.impl.{ Buffer ⇒ BufferImpl, Stages, ReactiveStreamsCompliance } +import akka.stream.scaladsl.{ SourceQueue, Source } import akka.stream.stage._ import akka.stream.{ Supervision, _ } import scala.annotation.tailrec @@ -347,23 +347,48 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta /** * INTERNAL API */ -final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] { - private[this] var aggregator: Out = zero +final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphStage[FlowShape[In, Out]] { - override def onPush(elem: In, ctx: Context[Out]): SyncDirective = { - aggregator = f(aggregator, elem) - ctx.pull() - } + val in = Inlet[In]("Fold.in") + val out = Outlet[Out]("Fold.out") + override val shape: FlowShape[In, Out] = FlowShape(in, out) - override def onPull(ctx: Context[Out]): SyncDirective = - if (ctx.isFinishing) ctx.pushAndFinish(aggregator) - else ctx.pull() + override val initialAttributes = DefaultAttributes.fold - override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination() + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + private var aggregator: Out = zero - override def decide(t: Throwable): Supervision.Directive = decider(t) + override def onResume(t: Throwable): Unit = { + aggregator = zero + } - override def restart(): Fold[In, Out] = copy() + override def onPush(): Unit = withSupervision(() ⇒ grab(in)) match { + case Some(elem) ⇒ { + aggregator = f(aggregator, elem) + pull(in) + } + case None ⇒ pull(in) + } + + override def onPull(): Unit = { + if (isClosed(in)) { + push(out, aggregator) + completeStage() + } else { + pull(in) + } + } + + override def onUpstreamFinish(): Unit = { + if (isAvailable(out)) { + push(out, aggregator) + completeStage() + } + } + + setHandlers(in, out, this) + } } /** @@ -415,36 +440,55 @@ final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) ext /** * INTERNAL API */ -final case class Grouped[T](n: Int) extends PushPullStage[T, immutable.Seq[T]] { - private val buf = { - val b = Vector.newBuilder[T] - b.sizeHint(n) - b - } - private var left = n +final case class Grouped[T](n: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] { + require(n > 0, "n must be greater than 0") - override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = { - buf += elem - left -= 1 - if (left == 0) { - val emit = buf.result() - buf.clear() - left = n - ctx.push(emit) - } else ctx.pull() + val in = Inlet[T]("Grouped.in") + val out = Outlet[immutable.Seq[T]]("Grouped.out") + override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out) + + override protected val initialAttributes: Attributes = DefaultAttributes.grouped + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + private val buf = { + val b = Vector.newBuilder[T] + b.sizeHint(n) + b + } + var left = n + + override def onPush(): Unit = { + buf += grab(in) + left -= 1 + if (left == 0) { + val elements = buf.result() + buf.clear() + left = n + push(out, elements) + } else { + pull(in) + } + } + + override def onPull(): Unit = { + pull(in) + } + + override def onUpstreamFinish(): Unit = { + // This means the buf is filled with some elements but not enough (left < n) to group together. + // Since the upstream has finished we have to push them to downstream though. + if (left < n) { + val elements = buf.result() + buf.clear() + left = n + push(out, elements) + } + completeStage() + } + + setHandlers(in, out, this) } - override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective = - if (ctx.isFinishing) { - val elem = buf.result() - buf.clear() - left = n - ctx.pushAndFinish(elem) - } else ctx.pull() - - override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective = - if (left == n) ctx.finish() - else ctx.absorbTermination() } /** @@ -482,34 +526,59 @@ final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStag /** * INTERNAL API */ -final case class Sliding[T](n: Int, step: Int) extends PushPullStage[T, immutable.Seq[T]] { - private var buf = Vector.empty[T] +final case class Sliding[T](n: Int, step: Int) extends GraphStage[FlowShape[T, immutable.Seq[T]]] { + require(n > 0, "n must be greater than 0") + require(step > 0, "step must be greater than 0") - override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): SyncDirective = { - buf :+= elem - if (buf.size < n) { - ctx.pull() - } else if (buf.size == n) { - ctx.push(buf) - } else if (step > n) { - if (buf.size == step) - buf = Vector.empty - ctx.pull() - } else { - buf = buf.drop(step) - if (buf.size == n) ctx.push(buf) - else ctx.pull() + val in = Inlet[T]("Sliding.in") + val out = Outlet[immutable.Seq[T]]("Sliding.out") + override val shape: FlowShape[T, immutable.Seq[T]] = FlowShape(in, out) + + override protected val initialAttributes: Attributes = DefaultAttributes.sliding + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + private var buf = Vector.empty[T] + + override def onPush(): Unit = { + buf :+= grab(in) + if (buf.size < n) { + pull(in) + } else if (buf.size == n) { + push(out, buf) + } else if (step <= n) { + buf = buf.drop(step) + if (buf.size == n) { + push(out, buf) + } else pull(in) + } else if (step > n) { + if (buf.size == step) { + buf = buf.drop(step) + } + pull(in) + } } + + override def onPull(): Unit = { + pull(in) + } + + override def onUpstreamFinish(): Unit = { + + // We can finish current stage directly if: + // 1. the buf is empty or + // 2. when the step size is greater than the sliding size (step > n) and current stage is in between + // two sliding (ie. buf.size >= n && buf.size < step). + // + // Otherwise it means there is still a not finished sliding so we have to push them before finish current stage. + if (buf.size < n && buf.size > 0) { + push(out, buf) + } + completeStage() + } + + this.setHandlers(in, out, this) } - override def onPull(ctx: Context[immutable.Seq[T]]): SyncDirective = - if (!ctx.isFinishing) ctx.pull() - else if (buf.size >= n) ctx.finish() - else ctx.pushAndFinish(buf) - - override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective = - if (buf.isEmpty) ctx.finish() - else ctx.absorbTermination() } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 5e0aad72f8..c0ee5ef4e0 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -672,7 +672,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def grouped(n: Int): Repr[immutable.Seq[Out]] = andThen(Grouped(n)) + def grouped(n: Int): Repr[immutable.Seq[Out]] = via(Grouped(n)) /** * Ensure stream boundedness by limiting the number of elements from upstream. @@ -736,7 +736,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = andThen(Sliding(n, step)) + def sliding(n: Int, step: Int = 1): Repr[immutable.Seq[Out]] = via(Sliding(n, step)) /** * Similar to `fold` but is not a terminal operation, @@ -777,7 +777,7 @@ trait FlowOps[+Out, +Mat] { * * See also [[FlowOps.scan]] */ - def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = andThen(Fold(zero, f)) + def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = via(Fold(zero, f)) /** * Similar to `fold` but uses first element as zero element.