diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index 9aa9127b91..637fd3516b 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -79,9 +79,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement one-to-many many-to-one chain correctly" in new OneBoundedSetup[Int](Seq( - Doubler(), - Filter((x: Int) ⇒ x != 0, stoppingDecider))) { + "implement one-to-many many-to-one chain correctly" in new OneBoundedSetup[Int]( + Doubler().toGS, + Filter((x: Int) ⇒ x != 0)) { lastEvents() should be(Set.empty) @@ -104,9 +104,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement many-to-one one-to-many chain correctly" in new OneBoundedSetup[Int](Seq( - Filter((x: Int) ⇒ x != 0, stoppingDecider), - Doubler())) { + "implement many-to-one one-to-many chain correctly" in new OneBoundedSetup[Int]( + Filter((x: Int) ⇒ x != 0), + Doubler().toGS) { lastEvents() should be(Set.empty) @@ -147,7 +147,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { } "implement take inside a chain" in new OneBoundedSetup[Int]( - Filter((x: Int) ⇒ x != 0, stoppingDecider).toGS, + Filter((x: Int) ⇒ x != 0), takeTwo, Map((x: Int) ⇒ x + 1, stoppingDecider).toGS) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index c0e3535394..90e1fd0a96 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -4,7 +4,9 @@ package akka.stream.impl.fusing import scala.util.control.NoStackTrace +import akka.stream.ActorAttributes._ import akka.stream.Supervision +import akka.stream.Supervision._ import akka.stream.stage.Context import akka.stream.stage.PushPullStage import akka.stream.stage.Stage @@ -286,22 +288,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit { } } - "resume when Filter throws" in new OneBoundedSetup[Int](Seq( - Filter((x: Int) ⇒ if (x == 0) throw TE else true, resumingDecider))) { - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(2) - lastEvents() should be(Set(OnNext(2))) - - downstream.requestOne() - lastEvents() should be(Set(RequestOne)) - upstream.onNext(0) // boom - lastEvents() should be(Set(RequestOne)) - - upstream.onNext(3) - lastEvents() should be(Set(OnNext(3))) - } - "resume when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { downstream.requestOne() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala index a907a525af..486ed1b6b4 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala @@ -3,17 +3,26 @@ */ package akka.stream.scaladsl +import akka.stream.testkit.scaladsl.TestSink + import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } +import akka.stream.ActorAttributes._ +import akka.stream.Supervision._ +import akka.stream.testkit.Utils._ import akka.stream.ActorMaterializer import akka.stream.ActorMaterializerSettings import akka.stream.testkit._ import akka.testkit.AkkaSpec +import scala.util.control.NoStackTrace + class FlowFilterSpec extends AkkaSpec with ScriptedTest { val settings = ActorMaterializerSettings(system) .withInputBuffer(initialSize = 2, maxSize = 16) + implicit val materializer = ActorMaterializer(settings) + "A Filter" must { "filter" in { @@ -38,6 +47,18 @@ class FlowFilterSpec extends AkkaSpec with ScriptedTest { probe.expectComplete() } + "continue if error" in assertAllStagesStopped { + val TE = new Exception("TEST") with NoStackTrace { + override def toString = "TE" + } + + Source(1 to 3).filter((x: Int) ⇒ if (x == 2) throw TE else true).withAttributes(supervisionStrategy(resumingDecider)) + .runWith(TestSink.probe[Int]) + .request(3) + .expectNext(1, 3) + .expectComplete() + } + } "A FilterNot" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 2252d46faf..38689642c0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -153,10 +153,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, T] = fusing.Log(name, extract, loggingAdapter, supervision(attr)) } - final case class Filter[T](p: T ⇒ Boolean, attributes: Attributes = filter) extends SymbolicStage[T, T] { - override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr)) - } - final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index befbafc0be..858d2c4985 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -34,12 +34,37 @@ private[akka] final case class Map[In, Out](f: In ⇒ Out, decider: Supervision. /** * INTERNAL API */ -private[akka] final case class Filter[T](p: T ⇒ Boolean, decider: Supervision.Decider) extends PushStage[T, T] { - override def onPush(elem: T, ctx: Context[T]): SyncDirective = - if (p(elem)) ctx.push(elem) - else ctx.pull() +private[akka] final case class Filter[T](p: T ⇒ Boolean) extends SimpleLinearGraphStage[T] { + override def initialAttributes: Attributes = DefaultAttributes.filter - override def decide(t: Throwable): Supervision.Directive = decider(t) + override def toString: String = "Filter" + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler with InHandler { + override def toString = "FilterLogic" + + def decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) + + override def onPush(): Unit = { + try { + val elem = grab(in) + if (p(elem)) { + push(out, elem) + } else { + pull(in) + } + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case _ ⇒ pull(in) + } + } + } + + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 4917350b07..4538b813a2 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -581,7 +581,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def filter(p: Out ⇒ Boolean): Repr[Out] = andThen(Filter(p)) + def filter(p: Out ⇒ Boolean): Repr[Out] = via(Filter(p)) /** * Only pass on those elements that NOT satisfy the given predicate.