diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala index 08ef4a56a5..b5873fb63e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala @@ -6,7 +6,7 @@ package akka.stream.scaladsl import java.util.concurrent.{ LinkedBlockingQueue, ThreadLocalRandom } import java.util.concurrent.atomic.AtomicInteger -import akka.stream.ActorAttributes.supervisionStrategy +import akka.stream.ActorAttributes.{ SupervisionStrategy, supervisionStrategy } import akka.stream.{ ActorAttributes, ActorMaterializer, Supervision } import akka.stream.Supervision.resumingDecider import akka.stream.impl.ReactiveStreamsCompliance @@ -350,6 +350,48 @@ class FlowMapAsyncSpec extends StreamSpec { } } + "not invoke the decider twice for the same failed future" in { + import system.dispatcher + val failCount = new AtomicInteger(0) + val result = Source(List(true, false)) + .mapAsync(1)(elem ⇒ + Future { + if (elem) throw TE("this has gone too far") + else elem + } + ).addAttributes(supervisionStrategy { + case TE("this has gone too far") ⇒ + failCount.incrementAndGet() + Supervision.resume + case _ ⇒ Supervision.stop + }) + .runWith(Sink.seq) + + result.futureValue should ===(Seq(false)) + failCount.get() should ===(1) + } + + "not invoke the decider twice for the same failure to produce a future" in { + import system.dispatcher + val failCount = new AtomicInteger(0) + val result = Source(List(true, false)) + .mapAsync(1)(elem ⇒ + if (elem) throw TE("this has gone too far") + else Future { + elem + } + ).addAttributes(supervisionStrategy { + case TE("this has gone too far") ⇒ + failCount.incrementAndGet() + Supervision.resume + case _ ⇒ Supervision.stop + }) + .runWith(Sink.seq) + + result.futureValue should ===(Seq(false)) + failCount.get() should ===(1) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 98a6d28dde..072a2202e5 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -20,12 +20,13 @@ import scala.annotation.tailrec import scala.collection.immutable import scala.collection.immutable.VectorBuilder import scala.concurrent.Future -import scala.util.control.NonFatal +import scala.util.control.{ NoStackTrace, NonFatal } import scala.util.{ Failure, Success, Try } import akka.stream.ActorAttributes.SupervisionStrategy import scala.concurrent.duration.{ FiniteDuration, _ } import akka.stream.impl.Stages.DefaultAttributes +import akka.util.OptionVal /** * INTERNAL API @@ -1112,7 +1113,25 @@ private[stream] object Collect { */ @InternalApi private[akka] object MapAsync { - final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] ⇒ Unit) { + final class Holder[T]( + var elem: Try[T], + val cb: AsyncCallback[Holder[T]] + ) extends (Try[T] ⇒ Unit) { + + // To support both fail-fast when the supervision directive is Stop + // and not calling the decider multiple times (#23888) we need to cache the decider result and re-use that + private var cachedSupervisionDirective: OptionVal[Supervision.Directive] = OptionVal.None + + def supervisionDirectiveFor(decider: Supervision.Decider, ex: Throwable): Supervision.Directive = { + cachedSupervisionDirective match { + case OptionVal.Some(d) ⇒ d + case OptionVal.None ⇒ + val d = decider(ex) + cachedSupervisionDirective = OptionVal.Some(d) + d + } + } + def setElem(t: Try[T]): Unit = { elem = t match { case Success(null) ⇒ Failure[T](ReactiveStreamsCompliance.elementMustNotBeNullException) @@ -1126,7 +1145,7 @@ private[stream] object Collect { } } - val NotYetThere = Failure(new Exception) + val NotYetThere = Failure(new Exception with NoStackTrace) } /** @@ -1146,36 +1165,26 @@ private[stream] object Collect { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { - override def toString = s"MapAsync.Logic(buffer=$buffer)" - //FIXME Put Supervision.stoppingDecider as a SupervisionStrategy on DefaultAttributes.mapAsync? - lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) - var buffer: BufferImpl[Holder[Out]] = _ + lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider) + .getOrElse(Supervision.stoppingDecider) - private val handleSuccessElem: PartialFunction[Try[Out], Unit] = { - case Success(elem) ⇒ - push(out, elem) - if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) - } - private val handleFailureOrPushElem: PartialFunction[Try[Out], Unit] = { - case Failure(e) if decider(e) == Supervision.Stop ⇒ failStage(e) - case _ ⇒ if (isAvailable(out)) pushOne() // skip this element - } - private def holderCompleted(holder: Holder[Out]) = handleFailureOrPushElem.apply(holder.elem) + private val futureCB = getAsyncCallback[Holder[Out]](holder ⇒ + holder.elem match { + case Success(_) ⇒ pushNextIfPossible() + case Failure(NonFatal(ex)) ⇒ + holder.supervisionDirectiveFor(decider, ex) match { + // fail fast as if supervision says so + case Supervision.Stop ⇒ failStage(ex) + case _ ⇒ pushNextIfPossible() + } + }) - val futureCB = getAsyncCallback[Holder[Out]](holderCompleted) - - private[this] def todo = buffer.used + private var buffer: BufferImpl[Holder[Out]] = _ override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer) - private def pushOne(): Unit = - if (buffer.isEmpty) { - if (isClosed(in)) completeStage() - else if (!hasBeenPulled(in)) pull(in) - } else if (buffer.peek().elem == NotYetThere) { - if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) - } else handleSuccessElem.applyOrElse(buffer.dequeue().elem, handleFailureOrPushElem) + override def onPull(): Unit = pushNextIfPossible() override def onPush(): Unit = { try { @@ -1183,24 +1192,50 @@ private[stream] object Collect { val holder = new Holder[Out](NotYetThere, futureCB) buffer.enqueue(holder) - // #20217 We dispatch the future if it's ready to optimize away - // scheduling it to an execution context future.value match { case None ⇒ future.onComplete(holder)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) case Some(v) ⇒ + // #20217 the future is already here, avoid scheduling it on the dispatcher holder.setElem(v) - handleFailureOrPushElem(v) + pushNextIfPossible() } } catch { + // this logic must only be executed if f throws, not if the future is failed case NonFatal(ex) ⇒ if (decider(ex) == Supervision.Stop) failStage(ex) } - if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) + + pullIfNeeded() } - override def onUpstreamFinish(): Unit = if (todo == 0) completeStage() + override def onUpstreamFinish(): Unit = if (buffer.isEmpty) completeStage() - override def onPull(): Unit = pushOne() + @tailrec + private def pushNextIfPossible(): Unit = + if (buffer.isEmpty) { + if (isClosed(in)) completeStage() + else pullIfNeeded() + } else if (buffer.peek().elem eq NotYetThere) pullIfNeeded() // ahead of line blocking to keep order + else if (isAvailable(out)) { + val holder = buffer.dequeue() + holder.elem match { + case Success(elem) ⇒ + push(out, elem) + pullIfNeeded() + + case Failure(NonFatal(ex)) ⇒ + holder.supervisionDirectiveFor(decider, ex) match { + case Supervision.Stop ⇒ failStage(ex) + case _ ⇒ + // try next element + pushNextIfPossible() + } + } + } + + private def pullIfNeeded(): Unit = { + if (buffer.used < parallelism && !hasBeenPulled(in)) tryPull(in) + } setHandlers(in, out, this) }