diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index 7785010014..c0e3535394 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -302,27 +302,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(3))) } - "restart when Collect throws" in { - // TODO can't get type inference to work with `pf` inlined - val pf: PartialFunction[Int, Int] = - { case x: Int ⇒ if (x == 0) throw TE else x } - new OneBoundedSetup[Int](Seq( - Collect(pf, restartingDecider))) { - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(2) - lastEvents() should be(Set(OnNext(2))) - - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(0) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(3) - lastEvents() should be(Set(OnNext(3))) - } - } - "resume when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { downstream.requestOne() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowCollectSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowCollectSpec.scala index d9198236d3..717d23391d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowCollectSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowCollectSpec.scala @@ -3,15 +3,24 @@ */ package akka.stream.scaladsl +import akka.stream.ActorAttributes._ +import akka.stream.Supervision._ +import akka.stream.impl.ConstantFun +import akka.stream.testkit.Utils.TE +import akka.stream.testkit.scaladsl.TestSink + import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -import akka.stream.ActorMaterializerSettings +import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } import akka.testkit.AkkaSpec -import akka.stream.testkit.ScriptedTest +import akka.stream.testkit.{ TestSubscriber, ScriptedTest } + +import scala.util.control.NoStackTrace class FlowCollectSpec extends AkkaSpec with ScriptedTest { val settings = ActorMaterializerSettings(system) + implicit val materializer = ActorMaterializer(settings) "A Collect" must { @@ -23,6 +32,19 @@ class FlowCollectSpec extends AkkaSpec with ScriptedTest { TestConfig.RandomTestRange foreach (_ ⇒ runScript(script, settings)(_.collect { case x if x % 2 == 0 ⇒ (x * x).toString })) } + "restart when Collect throws" in { + val pf: PartialFunction[Int, Int] = + { case x: Int ⇒ if (x == 2) throw TE("") else x } + Source(1 to 3).collect(pf).withAttributes(supervisionStrategy(restartingDecider)) + .runWith(TestSink.probe[Int]) + .request(1) + .expectNext(1) + .request(1) + .expectNext(3) + .request(1) + .expectComplete() + } + } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowDropWhileSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowDropWhileSpec.scala index 7415a53e5b..d6ffbe1486 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowDropWhileSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowDropWhileSpec.scala @@ -35,13 +35,24 @@ class FlowDropWhileSpec extends AkkaSpec { } "continue if error" in assertAllStagesStopped { - val testException = new Exception("test") with NoStackTrace - Source(1 to 4).dropWhile(a ⇒ if (a < 3) true else throw testException).withAttributes(supervisionStrategy(resumingDecider)) + Source(1 to 4).dropWhile(a ⇒ if (a < 3) true else throw TE("")).withAttributes(supervisionStrategy(resumingDecider)) .runWith(TestSink.probe[Int]) .request(1) .expectComplete() } + "restart with strategy" in assertAllStagesStopped { + Source(1 to 4).dropWhile { + case 1 | 3 ⇒ true + case 4 ⇒ false + case 2 ⇒ throw TE("") + }.withAttributes(supervisionStrategy(restartingDecider)) + .runWith(TestSink.probe[Int]) + .request(1) + .expectNext(4) + .expectComplete() + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index a171ad7161..8a916d0e6d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -155,10 +155,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr)) } - final case class Collect[In, Out](pf: PartialFunction[In, Out], attributes: Attributes = collect) extends SymbolicStage[In, Out] { - override def create(attr: Attributes): Stage[In, Out] = fusing.Collect(pf, supervision(attr)) - } - final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) } @@ -168,10 +164,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) } - final case class LimitWeighted[T](max: Long, weightFn: T ⇒ Long, attributes: Attributes = limitWeighted) extends SymbolicStage[T, T] { - override def create(attr: Attributes): Stage[T, T] = fusing.LimitWeighted(max, weightFn) - } - final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") require(step > 0, "step must be greater than 0") @@ -183,10 +175,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr)) } - final case class DropWhile[T](p: T ⇒ Boolean, attributes: Attributes = dropWhile) extends SymbolicStage[T, T] { - override def create(attr: Attributes): Stage[T, T] = fusing.DropWhile(p, supervision(attr)) - } - final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 6fb58721ed..fd8dcbca79 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -59,38 +59,89 @@ private[akka] final case class TakeWhile[T](p: T ⇒ Boolean, decider: Supervisi /** * INTERNAL API */ -private[akka] final case class DropWhile[T](p: T ⇒ Boolean, decider: Supervision.Decider) extends PushStage[T, T] { - var taking = false +private[stream] final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T, T]] { + val in = Inlet[T]("DropWhile.in") + val out = Outlet[T]("DropWhile.out") + override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.dropWhile - override def onPush(elem: T, ctx: Context[T]): SyncDirective = - if (taking || !p(elem)) { - taking = true - ctx.push(elem) - } else { - ctx.pull() + def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + override def onPush(): Unit = { + val elem = grab(in) + withSupervision(() ⇒ p(elem)) match { + case Some(flag) if flag ⇒ pull(in) + case Some(flag) if !flag ⇒ + push(out, elem) + setHandler(in, rest) + case None ⇒ // do nothing + } } - override def decide(t: Throwable): Supervision.Directive = decider(t) + def rest = new InHandler { + def onPush() = push(out, grab(in)) + } + + override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) + } + override def toString = "DropWhile" } -private[akka] object Collect { +/** + * INTERNAL API + */ +abstract private[stream] class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) { + private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) + def withSupervision[T](f: () ⇒ T): Option[T] = + try { Some(f()) } catch { + case NonFatal(ex) ⇒ + decider(ex) match { + case Supervision.Stop ⇒ onStop(ex) + case Supervision.Resume ⇒ onResume(ex) + case Supervision.Restart ⇒ onRestart(ex) + } + None + } + + def onResume(t: Throwable): Unit + def onStop(t: Throwable): Unit = failStage(t) + def onRestart(t: Throwable): Unit = onResume(t) +} + +private[stream] object Collect { // Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once, // and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not. // Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458 final val NotApplied: Any ⇒ Any = _ ⇒ Collect.NotApplied } -private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out], decider: Supervision.Decider) extends PushStage[In, Out] { +/** + * INTERNAL API + */ +private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphStage[FlowShape[In, Out]] { + val in = Inlet[In]("Collect.in") + val out = Outlet[Out]("Collect.out") + override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.collect - import Collect.NotApplied + def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + import Collect.NotApplied + val wrappedPf = () ⇒ pf.applyOrElse(grab(in), NotApplied) - override def onPush(elem: In, ctx: Context[Out]): SyncDirective = - pf.applyOrElse(elem, NotApplied) match { - case NotApplied ⇒ ctx.pull() - case result: Out @unchecked ⇒ ctx.push(result) + override def onPush(): Unit = withSupervision(wrappedPf) match { + case Some(result) ⇒ result match { + case NotApplied ⇒ pull(in) + case result: Out @unchecked ⇒ push(out, result) + } + case None ⇒ //do nothing } - override def decide(t: Throwable): Supervision.Directive = decider(t) + override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) + } + override def toString = "Collect" } /** @@ -312,15 +363,33 @@ private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immut /** * INTERNAL API */ +private[stream] final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStage[FlowShape[T, T]] { + val in = Inlet[T]("LimitWeighted.in") + val out = Outlet[T]("LimitWeighted.out") + override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.limitWeighted -private[akka] final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends PushStage[T, T] { - private var left = n + def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + private var left = n - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - left -= costFn(elem) - if (left >= 0) ctx.push(elem) - else ctx.fail(new StreamLimitReachedException(n)) + override def onPush(): Unit = { + val elem = grab(in) + withSupervision(() ⇒ costFn(elem)) match { + case Some(wight) ⇒ + left -= wight + if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n)) + case None ⇒ //do nothing + } + } + override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onRestart(t: Throwable): Unit = { + left = n + if (!hasBeenPulled(in)) pull(in) + } + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) } + override def toString = "LimitWeighted" } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index e041eaa743..12a30cb809 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -627,7 +627,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def dropWhile(p: Out ⇒ Boolean): Repr[Out] = andThen(DropWhile(p)) + def dropWhile(p: Out ⇒ Boolean): Repr[Out] = via(DropWhile(p)) /** * Transform this stream by applying the given partial function to each of the elements @@ -642,7 +642,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def collect[T](pf: PartialFunction[Out, T]): Repr[T] = andThen(Collect(pf)) + def collect[T](pf: PartialFunction[Out, T]): Repr[T] = via(Collect(pf)) /** * Chunk up this stream into groups of the given size, with the last group @@ -705,7 +705,7 @@ trait FlowOps[+Out, +Mat] { * * See also [[FlowOps.take]], [[FlowOps.takeWithin]], [[FlowOps.takeWhile]] */ - def limitWeighted[T](max: Long)(costFn: Out ⇒ Long): Repr[Out] = andThen(LimitWeighted(max, costFn)) + def limitWeighted[T](max: Long)(costFn: Out ⇒ Long): Repr[Out] = via(LimitWeighted(max, costFn)) /** * Apply a sliding window over the stream and return the windows as groups of elements, with the last group diff --git a/project/MiMa.scala b/project/MiMa.scala index 2af118b654..00cb48c540 100644 --- a/project/MiMa.scala +++ b/project/MiMa.scala @@ -754,7 +754,18 @@ object MiMa extends AutoPlugin { // #20028 Simplify TickSource cancellation ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$TickSourceCancellable"), - ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$") + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$"), + + // #19834 replacing PushStages usages with GraphStages + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect$"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted$"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile$"), + FilterAnyProblemStartingWith("akka.stream.impl.fusing.Collect"), + FilterAnyProblemStartingWith("akka.stream.impl.fusing.DropWhile"), + FilterAnyProblemStartingWith("akka.stream.impl.fusing.LimitWeighted") ) ) }