From 63ceb52bbd5e23a77ebd2bad051d167b67745296 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Wed, 19 Sep 2018 11:06:24 +0200 Subject: [PATCH] Use supervision in all places of Source.fromIterator, #25574 (#25601) * it was noticed in Source.fromIterator depending on where the iterator throwed exception * fromIterator, as many other things, is implemented with statefulMapConcat * supervision was only used for exceptions in onPush and not in onPull, added it there also --- .../stream/scaladsl/FlowIteratorSpec.scala | 4 +-- .../akka/stream/scaladsl/SourceSpec.scala | 31 ++++++++++++++++- .../scala/akka/stream/impl/fusing/Ops.scala | 34 ++++++++++++------- 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIteratorSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIteratorSpec.scala index 744a022e38..b8f56aa7cd 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIteratorSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIteratorSpec.scala @@ -39,9 +39,7 @@ class FlowIterableSpec extends AbstractFlowIteratorSpec { sub.request(1) c.expectNext(1) c.expectNoMsg(100.millis) - EventFilter[IllegalStateException](message = "not two", occurrences = 1).intercept { - sub.request(2) - } + sub.request(2) c.expectError().getMessage should be("not two") sub.request(2) c.expectNoMsg(100.millis) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index a46a06a214..b6d8e07040 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -6,9 +6,10 @@ package akka.stream.scaladsl import akka.testkit.DefaultTimeout import org.scalatest.time.{ Millis, Span } - import scala.concurrent.{ Await, Future } import scala.concurrent.duration._ + +import akka.stream.testkit.Utils.TE //#imports import akka.stream._ @@ -266,6 +267,34 @@ class SourceSpec extends StreamSpec with DefaultTimeout { .runWith(Sink.head) .futureValue should ===(immutable.Seq(false, true, false, true, false, true, false, true, false, true)) } + + "fail stream when iterator throws" in { + Source + .fromIterator(() ⇒ (1 to 1000).toIterator.map(k ⇒ if (k < 10) k else throw TE("a"))) + .runWith(Sink.ignore) + .failed.futureValue.getClass should ===(classOf[TE]) + + Source + .fromIterator(() ⇒ (1 to 1000).toIterator.map(_ ⇒ throw TE("b"))) + .runWith(Sink.ignore) + .failed.futureValue.getClass should ===(classOf[TE]) + } + + "use decider when iterator throws" in { + Source + .fromIterator(() ⇒ (1 to 5).toIterator.map(k ⇒ if (k != 3) k else throw TE("a"))) + .withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .grouped(10) + .runWith(Sink.head) + .futureValue should ===(List(1, 2)) + + Source + .fromIterator(() ⇒ (1 to 5).toIterator.map(_ ⇒ throw TE("b"))) + .withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .grouped(10) + .runWith(Sink.headOption) + .futureValue should ===(None) + } } "ZipN Source" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 43462d52e8..aab1d48785 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -18,16 +18,17 @@ import akka.stream.impl.{ ReactiveStreamsCompliance, Buffer ⇒ BufferImpl } import akka.stream.scaladsl.{ Flow, Keep, Source } import akka.stream.stage._ import akka.stream.{ Supervision, _ } - import scala.annotation.tailrec import scala.collection.immutable import scala.collection.immutable.VectorBuilder import scala.concurrent.{ Future, Promise } import scala.util.control.{ NoStackTrace, NonFatal } import scala.util.{ Failure, Success, Try } -import akka.stream.ActorAttributes.SupervisionStrategy +import akka.stream.ActorAttributes.SupervisionStrategy import scala.concurrent.duration.{ FiniteDuration, _ } +import scala.util.control.Exception.Catcher + import akka.stream.impl.Stages.DefaultAttributes import akka.util.OptionVal @@ -1938,19 +1939,28 @@ private[stream] object Collect { try { currentIterator = plainFun(grab(in)).iterator pushPull() - } catch { - case NonFatal(ex) ⇒ decider(ex) match { - case Supervision.Stop ⇒ failStage(ex) - case Supervision.Resume ⇒ if (!hasBeenPulled(in)) pull(in) - case Supervision.Restart ⇒ - restartState() - if (!hasBeenPulled(in)) pull(in) - } - } + } catch handleException override def onUpstreamFinish(): Unit = onFinish() - override def onPull(): Unit = pushPull() + override def onPull(): Unit = + try pushPull() + catch handleException + + private def handleException: Catcher[Unit] = { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case Supervision.Resume ⇒ + if (isClosed(in)) completeStage() + else if (!hasBeenPulled(in)) pull(in) + case Supervision.Restart ⇒ + if (isClosed(in)) completeStage() + else { + restartState() + if (!hasBeenPulled(in)) pull(in) + } + } + } private def restartState(): Unit = { plainFun = f()