diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index a0299b90d0..da05c2e2c4 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -3,6 +3,7 @@ */ package akka.stream.impl.fusing +import akka.stream.impl.ConstantFun import akka.stream.stage._ import akka.stream.testkit.AkkaSpec import akka.testkit.EventFilter @@ -241,10 +242,11 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(3)), OnComplete)) } - "implement conflate" in new OneBoundedSetup[Int](Seq(Conflate( + "implement batch (conflate)" in new OneBoundedSetup[Int](Batch( + 1L, + ConstantFun.zeroLong, (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ agg + x, - stoppingDecider))) { + (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set(RequestOne)) @@ -299,15 +301,17 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "work with conflate-conflate" in new OneBoundedSetup[Int](Seq( - Conflate( + "work with batch-batch (conflate-conflate)" in new OneBoundedSetup[Int]( + Batch( + 1L, + ConstantFun.zeroLong, (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ agg + x, - stoppingDecider), - Conflate( + (agg: Int, x: Int) ⇒ agg + x), + Batch( + 1L, + ConstantFun.zeroLong, (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ agg + x, - stoppingDecider))) { + (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set(RequestOne)) @@ -370,11 +374,12 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(12))) } - "implement conflate-expand" in new OneBoundedSetup[Int]( - Conflate( + "implement batch-expand (conflate-expand)" in new OneBoundedSetup[Int]( + Batch( + 1L, + ConstantFun.zeroLong, (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ agg + x, - stoppingDecider).toGS, + (agg: Int, x: Int) ⇒ agg + x), new Expand(Iterator.continually(_: Int))) { lastEvents() should be(Set(RequestOne)) @@ -404,12 +409,13 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement doubler-conflate" in new OneBoundedSetup[Int](Seq( - Doubler(), - Conflate( + "implement doubler-conflate (doubler-batch)" in new OneBoundedSetup[Int]( + Doubler().toGS, + Batch( + 1L, + ConstantFun.zeroLong, (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ agg + x, - stoppingDecider))) { + (agg: Int, x: Int) ⇒ agg + x)) { lastEvents() should be(Set(RequestOne)) upstream.onNext(1) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala index f443c112de..13209fcb3e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala @@ -4,6 +4,7 @@ package akka.stream.impl.fusing import akka.NotUsed +import akka.stream.impl.ConstantFun import akka.stream.{ Attributes, Shape, Supervision } import akka.stream.stage.AbstractStage.PushPullGraphStage import akka.stream.stage.GraphStageWithMaterializedValue @@ -100,25 +101,26 @@ class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { } - "work with a massive chain of conflates by overflowing to the heap" in new OneBoundedSetup[Int](Vector.fill(chainLength / 10)(Conflate( - (in: Int) ⇒ in, - (agg: Int, in: Int) ⇒ agg + in, - Supervision.stoppingDecider))) { + "work with a massive chain of batches by overflowing to the heap" in { - lastEvents() should be(Set(RequestOne)) + val batch = Batch( + 0L, + ConstantFun.zeroLong, + (in: Int) ⇒ in, + (agg: Int, in: Int) ⇒ agg + in) + + new OneBoundedSetup[Int](Vector.fill(chainLength / 10)(batch): _*) { - var i = 0 - while (i < repetition) { - upstream.onNext(1) lastEvents() should be(Set(RequestOne)) - i += 1 + + var i = 0 + while (i < repetition) { + upstream.onNext(1) + lastEvents() should be(Set(RequestOne)) + i += 1 + } } - - downstream.requestOne() - lastEvents() should be(Set(OnNext(repetition))) - } - } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index e98dd9693d..90f08118e4 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -387,70 +387,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20 } - "restart when Conflate `seed` throws" in new OneBoundedSetup[Int](Seq(Conflate( - (in: Int) ⇒ if (in == 1) throw TE else in, - (agg: Int, x: Int) ⇒ agg + x, - restartingDecider))) { - - lastEvents() should be(Set(RequestOne)) - - downstream.requestOne() - lastEvents() should be(Set.empty) - - upstream.onNext(0) - lastEvents() should be(Set(OnNext(0), RequestOne)) - - upstream.onNext(1) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(2) - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(10) - lastEvents() should be(Set(RequestOne)) - - downstream.requestOne() - lastEvents() should be(Set(OnNext(12))) // note that 1 has been discarded - - downstream.requestOne() - lastEvents() should be(Set.empty) - } - - "restart when Conflate `aggregate` throws" in new OneBoundedSetup[Int](Seq(Conflate( - (in: Int) ⇒ in, - (agg: Int, x: Int) ⇒ if (x == 2) throw TE else agg + x, - restartingDecider))) { - - lastEvents() should be(Set(RequestOne)) - - downstream.requestOne() - lastEvents() should be(Set.empty) - - upstream.onNext(0) - lastEvents() should be(Set(OnNext(0), RequestOne)) - - upstream.onNext(1) - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(2) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(10) - lastEvents() should be(Set(RequestOne)) - - downstream.requestOne() - lastEvents() should be(Set(OnNext(10))) // note that 1 and 2 has been discarded - - downstream.requestOne() - lastEvents() should be(Set.empty) - - upstream.onNext(4) - lastEvents() should be(Set(OnNext(4), RequestOne)) - - downstream.cancel() - lastEvents() should be(Set(Cancel)) - } - "fail when Expand `seed` throws" in new OneBoundedSetup[Int]( new Expand((in: Int) ⇒ if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConflateSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConflateSpec.scala index 9d951c5fb6..152bd2d8c4 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConflateSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConflateSpec.scala @@ -3,10 +3,15 @@ */ package akka.stream.scaladsl +import akka.stream.ActorAttributes.supervisionStrategy +import akka.stream.Attributes.inputBuffer +import akka.stream.Supervision.{ resumingDecider, restartingDecider } +import akka.stream.testkit.Utils.TE + import scala.concurrent.Await import scala.concurrent.duration._ import scala.concurrent.forkjoin.ThreadLocalRandom -import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings } +import akka.stream._ import akka.stream.testkit._ class FlowConflateSpec extends AkkaSpec { @@ -135,6 +140,103 @@ class FlowConflateSpec extends AkkaSpec { Await.result(future, 3.seconds) should be((1 to 50).sum) } + "restart when `seed` throws and a restartingDecider is used" in { + val sourceProbe = TestPublisher.probe[Int]() + val sinkProbe = TestSubscriber.probe[Int]() + + val future = Source.fromPublisher(sourceProbe) + .conflateWithSeed(seed = i ⇒ + if (i % 2 == 0) throw TE("I hate even seed numbers") else i)(aggregate = (sum, i) ⇒ sum + i) + .withAttributes(supervisionStrategy(restartingDecider)) + .to(Sink.fromSubscriber(sinkProbe)) + .withAttributes(inputBuffer(initial = 1, max = 1)) + .run() + + val sub = sourceProbe.expectSubscription() + val sinkSub = sinkProbe.expectSubscription() + // push the first value + sub.expectRequest(1) + sub.sendNext(1) + + // and consume it, so that the next element + // will trigger seed + sinkSub.request(1) + sinkProbe.expectNext(1) + + sub.expectRequest(1) + sub.sendNext(2) + sub.expectRequest(1) + sub.sendNext(3) + + // now we should have lost the 2 and the accumulated state + sinkSub.request(1) + sinkProbe.expectNext(3) + } + + "restart when `aggregate` throws and a restartingDecider is used" in { + val conflate = Flow[String] + .conflateWithSeed(seed = i ⇒ i)((state, elem) ⇒ + if (elem == "two") throw TE("two is a three letter word") + else state + elem) + .withAttributes(supervisionStrategy(restartingDecider)) + + val sourceProbe = TestPublisher.probe[String]() + val sinkProbe = TestSubscriber.probe[String]() + + Source.fromPublisher(sourceProbe) + .via(conflate) + .to(Sink.fromSubscriber(sinkProbe)) + .withAttributes(inputBuffer(initial = 4, max = 4)) + .run() + + val sub = sourceProbe.expectSubscription() + sub.expectRequest(4) + sub.sendNext("one") + sub.sendNext("two") + sub.sendNext("three") + sub.sendComplete() + + // "one" should be lost + sinkProbe.requestNext() shouldEqual ("three") + + } + + "resume when `aggregate` throws and a resumingDecider is used" in { + + val sourceProbe = TestPublisher.probe[Int]() + val sinkProbe = TestSubscriber.probe[Vector[Int]]() + + val future = Source.fromPublisher(sourceProbe) + .conflateWithSeed(seed = i ⇒ Vector(i))((state, elem) ⇒ + if (elem == 2) throw TE("three is a four letter word") else state :+ elem) + .withAttributes(supervisionStrategy(resumingDecider)) + .to(Sink.fromSubscriber(sinkProbe)) + .withAttributes(inputBuffer(initial = 1, max = 1)) + .run() + + val sub = sourceProbe.expectSubscription() + val sinkSub = sinkProbe.expectSubscription() + // push the first three values, the third will trigger + // the exception + sub.expectRequest(1) + sub.sendNext(1) + + // causing the 1 to get thrown away + sub.expectRequest(1) + sub.sendNext(2) + + sub.expectRequest(1) + sub.sendNext(3) + + sub.expectRequest(1) + sub.sendNext(4) + + // and consume it, so that the next element + // will trigger seed + sinkSub.request(1) + sinkProbe.expectNext(Vector(1, 3, 4)) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index f9881314de..771664d742 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -205,10 +205,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy) } - final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out, attributes: Attributes = conflate) extends SymbolicStage[In, Out] { - override def create(attr: Attributes): Stage[In, Out] = fusing.Conflate(seed, aggregate, supervision(attr)) - } - final case class MapConcat[In, Out](f: In ⇒ immutable.Iterable[Out], attributes: Attributes = mapConcat) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.MapConcat(f, supervision(attr)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index a2706ce4d0..669bff29c6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -432,50 +432,6 @@ private[akka] final case class Completed[T]() extends PushPullStage[T, T] { override def onPull(ctx: Context[T]): SyncDirective = ctx.finish() } -/** - * INTERNAL API - */ -private[akka] final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out, - decider: Supervision.Decider) extends DetachedStage[In, Out] { - private var agg: Any = null - - override def onPush(elem: In, ctx: DetachedContext[Out]): UpstreamDirective = { - agg = - if (agg == null) seed(elem) - else aggregate(agg.asInstanceOf[Out], elem) - - if (!ctx.isHoldingDownstream) ctx.pull() - else { - val result = agg.asInstanceOf[Out] - agg = null - ctx.pushAndPull(result) - } - } - - override def onPull(ctx: DetachedContext[Out]): DownstreamDirective = { - if (ctx.isFinishing) { - if (agg == null) ctx.finish() - else { - val result = agg.asInstanceOf[Out] - agg = null - ctx.pushAndFinish(result) - } - } else if (agg == null) ctx.holdDownstream() - else { - val result = agg.asInstanceOf[Out] - if (result == null) throw new NullPointerException - agg = null - ctx.push(result) - } - } - - override def onUpstreamFinish(ctx: DetachedContext[Out]): TerminationDirective = ctx.absorbTermination() - - override def decide(t: Throwable): Supervision.Directive = decider(t) - - override def restart(): Conflate[In, Out] = copy() -} - /** * INTERNAL API */ @@ -488,6 +444,9 @@ private[akka] final case class Batch[In, Out](max: Long, costFn: In ⇒ Long, se override val shape: FlowShape[In, Out] = FlowShape.of(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + + val decider = inheritedAttributes.getAttribute(classOf[SupervisionStrategy]).map(_.decider).getOrElse(Supervision.stoppingDecider) + private var agg: Out = null.asInstanceOf[Out] private var left: Long = max private var pending: In = null.asInstanceOf[In] @@ -496,10 +455,18 @@ private[akka] final case class Batch[In, Out](max: Long, costFn: In ⇒ Long, se push(out, agg) left = max if (pending != null) { - val elem = pending - agg = seed(elem) - left -= costFn(elem) - pending = null.asInstanceOf[In] + try { + agg = seed(pending) + left -= costFn(pending) + pending = null.asInstanceOf[In] + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case Supervision.Restart ⇒ restartState() + case Supervision.Resume ⇒ + pending = null.asInstanceOf[In] + } + } } else { agg = null.asInstanceOf[Out] } @@ -512,14 +479,33 @@ private[akka] final case class Batch[In, Out](max: Long, costFn: In ⇒ Long, se override def onPush(): Unit = { val elem = grab(in) val cost = costFn(elem) + if (agg == null) { - left -= cost - agg = seed(elem) + try { + agg = seed(elem) + left -= cost + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case Supervision.Restart ⇒ + restartState() + case Supervision.Resume ⇒ + } + } } else if (left < cost) { pending = elem } else { - left -= cost - agg = aggregate(agg, elem) + try { + agg = aggregate(agg, elem) + left -= cost + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case Supervision.Restart ⇒ + restartState() + case Supervision.Resume ⇒ + } + } } if (isAvailable(out)) flush() @@ -541,15 +527,32 @@ private[akka] final case class Batch[In, Out](max: Long, costFn: In ⇒ Long, se push(out, agg) if (pending == null) completeStage() else { - agg = seed(pending) + try { + agg = seed(pending) + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case Supervision.Resume ⇒ + case Supervision.Restart ⇒ + restartState() + if (!hasBeenPulled(in)) pull(in) + } + } pending = null.asInstanceOf[In] } } else { flush() if (!hasBeenPulled(in)) pull(in) } + } }) + + private def restartState(): Unit = { + agg = null.asInstanceOf[Out] + left = max + pending = null.asInstanceOf[In] + } } } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index fa2956f215..ee3deacd7b 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -934,9 +934,8 @@ trait FlowOps[+Out, +Mat] { * * See also [[FlowOps.conflate]], [[FlowOps.limit]], [[FlowOps.limitWeighted]] [[FlowOps.batch]] [[FlowOps.batchWeighted]] */ - def conflateWithSeed[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = andThen(Conflate(seed, aggregate)) - //FIXME: conflate can be expressed as a batch - //via(Batch(1L, ConstantFun.zeroLong, seed, aggregate).withAttributes(DefaultAttributes.conflate)) + def conflateWithSeed[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = + via(Batch(1L, ConstantFun.zeroLong, seed, aggregate).withAttributes(DefaultAttributes.conflate))