diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala index 6dafb8e1a8..09118ebe4c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala @@ -310,9 +310,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { abstract class OneBoundedSetup[T](_ops: GraphStageWithMaterializedValue[Shape, Any]*) extends Builder { val ops = _ops.toArray - def this(op: Seq[Stage[_, _]], dummy: Int = 42) = { - this(op.map(_.toGS): _*) - } + def this(op: Seq[Stage[_, _]], dummy: Int = 42) = this(op.map(_.toGS): _*) val upstream = new UpstreamOneBoundedProbe[T] val downstream = new DownstreamOneBoundedPortProbe[T] diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index 90e1fd0a96..cea87abe7d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -288,49 +288,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { } } - "resume when Scan throws" in new OneBoundedSetup[Int](Seq( - Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { - downstream.requestOne() - lastEvents() should be(Set(OnNext(1))) - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(2) - lastEvents() should be(Set(OnNext(3))) - - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(10) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(4) - lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4 - } - - "restart when Scan throws" in new OneBoundedSetup[Int](Seq( - Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) { - downstream.requestOne() - lastEvents() should be(Set(OnNext(1))) - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(2) - lastEvents() should be(Set(OnNext(3))) - - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(10) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(4) - lastEvents() should be(Set(OnNext(1))) // starts over again - - downstream.requestOne() - lastEvents() should be(Set(OnNext(5))) - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(20) - lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20 - } - "fail when Expand `seed` throws" in new OneBoundedSetup[Int]( new Expand((in: Int) ⇒ if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala index aed8ce7e73..69ed333e1f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala @@ -37,8 +37,7 @@ class FlowScanSpec extends AkkaSpec { } "Scan empty" in assertAllStagesStopped { - val v = Vector.empty[Int] - scan(Source(v)) should be(v.scan(0)(_ + _)) + scan(Source.empty[Int]) should be(Vector.empty[Int].scan(0)(_ + _)) } "emit values promptly" in { @@ -46,7 +45,7 @@ class FlowScanSpec extends AkkaSpec { Await.result(f, 1.second) should ===(Seq(0, 1)) } - "fail properly" in { + "restart properly" in { import ActorAttributes._ val scan = Flow[Int].scan(0) { (old, current) ⇒ require(current > 0) @@ -55,5 +54,15 @@ class FlowScanSpec extends AkkaSpec { Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe) .toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12)) } + + "resume properly" in { + import ActorAttributes._ + val scan = Flow[Int].scan(0) { (old, current) ⇒ + require(current > 0) + old + current + }.withAttributes(supervisionStrategy(Supervision.resumingDecider)) + Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe) + .toStrict(1.second) should ===(Seq(0, 1, 4, 9, 16)) + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala index 2ff48f2883..a930e68536 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala @@ -30,7 +30,7 @@ class FlowThrottleSpec extends AkkaSpec { } "accept very high rates" in Utils.assertAllStagesStopped { - Source(1 to 5).throttle(1, 1 nanos, 0, ThrottleMode.Shaping) + Source(1 to 5).throttle(1, 1.nanos, 0, ThrottleMode.Shaping) .runWith(TestSink.probe[Int]) .request(5) .expectNext(1, 2, 3, 4, 5) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index f801242014..472bd094c1 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -165,10 +165,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) } - final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] { - override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) - } - final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 3dcf9e0e2f..4539af3260 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -294,33 +294,47 @@ private[akka] final case class Drop[T](count: Long) extends SimpleLinearGraphSta /** * INTERNAL API */ -private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] { - private var aggregator = zero - private var pushedZero = false +private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphStage[FlowShape[In, Out]] { + override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out")) - override def onPush(elem: In, ctx: Context[Out]): SyncDirective = { - if (pushedZero) { - aggregator = f(aggregator, elem) - ctx.push(aggregator) - } else { - aggregator = f(zero, elem) - ctx.push(zero) + override def initialAttributes: Attributes = DefaultAttributes.scan + override def toString: String = "Scan" + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + self ⇒ + + private var aggregator = zero + private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) + + import Supervision.{ Stop, Resume, Restart } + import shape.{ in, out } + + // Initial behavior makes sure that the zero gets flushed if upstream is empty + setHandler(out, new OutHandler { + override def onPull(): Unit = { + push(out, aggregator) + setHandlers(in, out, self) + } + }) + setHandler(in, totallyIgnorantInput) + + override def onPull(): Unit = pull(in) + override def onPush(): Unit = { + try { + aggregator = f(aggregator, grab(in)) + push(out, aggregator) + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Resume ⇒ if (!hasBeenPulled(in)) pull(in) + case Stop ⇒ failStage(ex) + case Restart ⇒ + aggregator = zero + push(out, aggregator) + } + } + } } - } - - override def onPull(ctx: Context[Out]): SyncDirective = - if (!pushedZero) { - pushedZero = true - if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator) - } else ctx.pull() - - override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = - if (pushedZero) ctx.finish() - else ctx.absorbTermination() - - override def decide(t: Throwable): Supervision.Directive = decider(t) - - override def restart(): Scan[In, Out] = copy() } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index e6761c8710..bb9941d036 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -768,7 +768,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = andThen(Scan(zero, f)) + def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = via(Scan(zero, f)) /** * Similar to `scan` but only emits its result when the upstream completes,