diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala index 35a05ab847..f2b07dbf42 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala @@ -3,27 +3,23 @@ */ package akka.stream.scaladsl -import scala.concurrent.Await -import scala.concurrent.Future -import scala.concurrent.duration._ -import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.{ LinkedBlockingQueue, ThreadLocalRandom } +import java.util.concurrent.atomic.AtomicInteger -import scala.util.control.NoStackTrace -import akka.stream.ActorMaterializer -import akka.stream.testkit._ -import akka.stream.testkit.Utils._ -import akka.testkit.TestLatch -import akka.testkit.TestProbe import akka.stream.ActorAttributes.supervisionStrategy +import akka.stream.{ ActorAttributes, ActorMaterializer, Supervision } import akka.stream.Supervision.resumingDecider import akka.stream.impl.ReactiveStreamsCompliance +import akka.stream.testkit.Utils._ +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl.TestSink +import akka.testkit.{ TestLatch, TestProbe } +import org.scalatest.concurrent.PatienceConfiguration.Timeout import scala.annotation.tailrec -import scala.concurrent.Promise -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.concurrent.PatienceConfiguration.Timeout +import scala.concurrent.{ Await, Future, Promise } +import scala.concurrent.duration._ +import scala.util.control.NoStackTrace class FlowMapAsyncSpec extends StreamSpec { @@ -122,6 +118,82 @@ class FlowMapAsyncSpec extends StreamSpec { latch.countDown() } + "a failure mid-stream MUST cause a failure ASAP (stopping strategy)" in assertAllStagesStopped { + import system.dispatcher + val pa = Promise[String]() + val pb = Promise[String]() + val pc = Promise[String]() + val pd = Promise[String]() + val pe = Promise[String]() + val pf = Promise[String]() + + val input = pa :: pb :: pc :: pd :: pe :: pf :: Nil + + val probe = Source.fromIterator(() ⇒ input.iterator) + .mapAsync(5)(p ⇒ p.future.map(_.toUpperCase)) + .runWith(TestSink.probe) + + import TestSubscriber._ + var gotErrorAlready = false + val elementOrErrorOk: PartialFunction[SubscriberEvent, Unit] = { + case OnNext("A") ⇒ () // is fine + case OnNext("B") ⇒ () // is fine + case OnError(ex) if ex.getMessage == "Boom at C" && !gotErrorAlready ⇒ + gotErrorAlready = true // fine, error can over-take elements + } + probe.request(100) + + val boom = new Exception("Boom at C") + + // placing the future completion signals here is important + // the ordering is meant to expose a race between the failure at C and subsequent elements + pa.success("a") + pb.success("b") + pc.failure(boom) + pd.success("d") + pe.success("e") + pf.success("f") + + probe.expectNextOrError() match { + case Left(ex) ⇒ ex.getMessage should ===("Boom at C") // fine, error can over-take elements + case Right("A") ⇒ + probe.expectNextOrError() match { + case Left(ex) ⇒ ex.getMessage should ===("Boom at C") // fine, error can over-take elements + case Right("B") ⇒ + probe.expectNextOrError() match { + case Left(ex) ⇒ ex.getMessage should ===("Boom at C") // fine, error can over-take elements + case Right(element) ⇒ fail(s"Got [$element] yet it caused an exception, should not have happened!") + } + } + } + } + + "a failure mid-stream must skip element with resume strategy" in assertAllStagesStopped { + val pa = Promise[String]() + val pb = Promise[String]() + val pc = Promise[String]() + val pd = Promise[String]() + val pe = Promise[String]() + val pf = Promise[String]() + + val input = pa :: pb :: pc :: pd :: pe :: pf :: Nil + + val elements = Source.fromIterator(() ⇒ input.iterator) + .mapAsync(5)(p ⇒ p.future) + .withAttributes(ActorAttributes.supervisionStrategy(Supervision.resumingDecider)) + .runWith(Sink.seq) + + // the problematic ordering: + pa.success("a") + pb.success("b") + pd.success("d") + pe.success("e") + pf.success("f") + pc.failure(new Exception("Booom!")) + + elements.futureValue should ===(List("a", "b", /* no c */ "d", "e", "f")) + } + "signal error from mapAsync" in assertAllStagesStopped { val latch = TestLatch(1) val c = TestSubscriber.manualProbe[Int]() diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 938d66016a..9e1342d9a8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -11,7 +11,7 @@ import akka.event.{ LogSource, Logging, LoggingAdapter } import akka.stream.Attributes.{ InputBuffer, LogLevels } import akka.stream.OverflowStrategies._ import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage -import akka.stream.impl.{ ReactiveStreamsCompliance, Stages, Buffer ⇒ BufferImpl } +import akka.stream.impl.{ ConstantFun, ReactiveStreamsCompliance, Stages, Buffer ⇒ BufferImpl } import akka.stream.scaladsl.{ Source, SourceQueue } import akka.stream.stage._ import akka.stream.{ Supervision, _ } @@ -1103,11 +1103,12 @@ private[stream] object Collect { @InternalApi private[akka] object MapAsync { final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] ⇒ Unit) { - def setElem(t: Try[T]): Unit = + def setElem(t: Try[T]): Unit = { elem = t match { case Success(null) ⇒ Failure[T](ReactiveStreamsCompliance.elementMustNotBeNullException) case other ⇒ other } + } override def apply(t: Try[T]): Unit = { setElem(t) @@ -1141,12 +1142,16 @@ private[stream] object Collect { lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) var buffer: BufferImpl[Holder[Out]] = _ - def holderCompleted(h: Holder[Out]): Unit = { - h.elem match { - case Failure(e) if decider(e) == Supervision.Stop ⇒ failStage(e) - case _ ⇒ if (isAvailable(out)) pushOne() - } + private val handleSuccessElem: PartialFunction[Try[Out], Unit] = { + case Success(elem) ⇒ + push(out, elem) + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) } + private val handleFailureOrPushElem: PartialFunction[Try[Out], Unit] = { + case Failure(e) if decider(e) == Supervision.Stop ⇒ failStage(e) + case _ ⇒ if (isAvailable(out)) pushOne() // skip this element + } + private def holderCompleted(holder: Holder[Out]) = handleFailureOrPushElem.apply(holder.elem) val futureCB = getAsyncCallback[Holder[Out]](holderCompleted) @@ -1154,18 +1159,13 @@ private[stream] object Collect { override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer) - @tailrec private def pushOne(): Unit = + private def pushOne(): Unit = if (buffer.isEmpty) { if (isClosed(in)) completeStage() else if (!hasBeenPulled(in)) pull(in) } else if (buffer.peek().elem == NotYetThere) { if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) - } else buffer.dequeue().elem match { - case Success(elem) ⇒ - push(out, elem) - if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) - case Failure(ex) ⇒ pushOne() - } + } else handleSuccessElem.applyOrElse(buffer.dequeue().elem, handleFailureOrPushElem) override def onPush(): Unit = { try { @@ -1179,7 +1179,7 @@ private[stream] object Collect { case None ⇒ future.onComplete(holder)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) case Some(v) ⇒ holder.setElem(v) - holderCompleted(holder) + handleFailureOrPushElem(v) } } catch {