diff --git a/akka-stream/src/main/scala/akka/stream/impl/ContextPropagation.scala b/akka-stream/src/main/scala/akka/stream/impl/ContextPropagation.scala index 4cdca71742..9b459bc41d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ContextPropagation.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ContextPropagation.scala @@ -26,11 +26,6 @@ import akka.annotation.InternalApi } private[akka] final class ContextPropagationImpl extends ContextPropagation { - private val buffer = Buffer[Unit](1, 1) - def suspendContext(): Unit = { - buffer.enqueue(()) - } - def resumeContext(): Unit = { - buffer.dequeue() - } + def suspendContext(): Unit = () + def resumeContext(): Unit = () } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 2096d0d0e6..d93b060332 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -2212,15 +2212,20 @@ private[akka] final class StatefulMapConcat[In, Out](val f: () => In => Iterable lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider var currentIterator: Iterator[Out] = _ var plainFun = f() + val contextPropagation = ContextPropagation() def hasNext = if (currentIterator != null) currentIterator.hasNext else false setHandlers(in, out, this) - def pushPull(): Unit = + def pushPull(shouldResumeContext: Boolean): Unit = if (hasNext) { + if (shouldResumeContext) contextPropagation.resumeContext() push(out, currentIterator.next()) - if (!hasNext && isClosed(in)) completeStage() + if (hasNext) { + // suspend context for the next element + contextPropagation.suspendContext() + } else if (isClosed(in)) completeStage() } else if (!isClosed(in)) pull(in) else completeStage() @@ -2230,13 +2235,13 @@ private[akka] final class StatefulMapConcat[In, Out](val f: () => In => Iterable override def onPush(): Unit = try { currentIterator = plainFun(grab(in)).iterator - pushPull() + pushPull(shouldResumeContext = false) } catch handleException override def onUpstreamFinish(): Unit = onFinish() override def onPull(): Unit = - try pushPull() + try pushPull(shouldResumeContext = true) catch handleException private def handleException: Catcher[Unit] = {