diff --git a/akka-bench-jmh/src/main/scala/akka/stream/FusedGraphsBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/FusedGraphsBenchmark.scala index ba8e12a2a0..f2dea5401f 100644 --- a/akka-bench-jmh/src/main/scala/akka/stream/FusedGraphsBenchmark.scala +++ b/akka-bench-jmh/src/main/scala/akka/stream/FusedGraphsBenchmark.scala @@ -160,6 +160,15 @@ class FusedGraphsBenchmark { chainOfMaps = fuse( testSource + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) + .map(addFunc) .map(addFunc) .toMat(testSink)(Keep.right) ) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index 9d525ec353..07327b37d3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -368,11 +368,8 @@ class GraphInterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { sink.requestOne(eventLimit = 0) source.onComplete(eventLimit = 3) - lastEvents() should ===(Set(OnNext(sink, "C"))) - - sink.requestOne() - lastEvents() should ===(Set(OnComplete(sink))) - + // OnComplete arrives early due to push chasing + lastEvents() should ===(Set(OnNext(sink, "C"), OnComplete(sink))) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index 919b3018cb..3cba149fb0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -47,6 +47,8 @@ object GraphInterpreter { final val KeepGoingFlag = 0x4000000 final val KeepGoingMask = 0x3ffffff + final val ChaseLimit = 16 + /** * Marker object that indicates that a port holds no element since it was already grabbed. The port is still pullable, * but there is no more element to grab. @@ -381,7 +383,7 @@ final class GraphInterpreter( shape.inlets.size + shape.outlets.size } - private var _subFusingMaterializer: Materializer = _ + private[this] var _subFusingMaterializer: Materializer = _ def subFusingMaterializer: Materializer = _subFusingMaterializer // An event queue implemented as a circular buffer @@ -391,6 +393,10 @@ final class GraphInterpreter( private[this] var queueHead: Int = 0 private[this] var queueTail: Int = 0 + private[this] var chaseCounter = 0 // the first events in preStart blocks should be not chased + private[this] var chasedPush: Int = NoEvent + private[this] var chasedPull: Int = NoEvent + private def queueStatus: String = { val contents = (queueHead until queueTail).map(idx ⇒ { val conn = eventQueue(idx & mask) @@ -539,20 +545,66 @@ final class GraphInterpreter( try { while (eventsRemaining > 0 && queueTail != queueHead) { val connection = dequeue() + eventsRemaining -= 1 + chaseCounter = math.min(ChaseLimit, eventsRemaining) + + def reportStageError(e: Throwable): Unit = { + if (activeStage == null) throw e + else { + val stage = assembly.stages(activeStage.stageId) + + log.error(e, "Error in stage [{}]: {}", stage, e.getMessage) + activeStage.failStage(e) + + // Abort chasing + chaseCounter = 0 + if (chasedPush != NoEvent) { + enqueue(chasedPush) + chasedPush = NoEvent + } + if (chasedPull != NoEvent) { + enqueue(chasedPull) + chasedPull = NoEvent + } + } + } + try processEvent(connection) catch { - case NonFatal(e) ⇒ - if (activeStage == null) throw e - else { - val stage = assembly.stages(activeStage.stageId) - - log.error(e, "Error in stage [{}]: {}", stage, e.getMessage) - activeStage.failStage(e) - } + case NonFatal(e) ⇒ reportStageError(e) } afterStageHasRun(activeStage) - eventsRemaining -= 1 + + // Chasing PUSH events + while (chasedPush != NoEvent) { + val connection = chasedPush + chasedPush = NoEvent + try processPush(connection) + catch { + case NonFatal(e) ⇒ reportStageError(e) + } + afterStageHasRun(activeStage) + } + + // Chasing PULL events + while (chasedPull != NoEvent) { + val connection = chasedPull + chasedPull = NoEvent + try processPull(connection) + catch { + case NonFatal(e) ⇒ reportStageError(e) + } + afterStageHasRun(activeStage) + } + + if (chasedPush != NoEvent) { + enqueue(chasedPush) + chasedPush = NoEvent + } + } + // Event *must* be enqueued while not in the execute loop (events enqueued from external, possibly async events) + chaseCounter = 0 } finally { currentInterpreterHolder(0) = previousInterpreter } @@ -577,18 +629,12 @@ final class GraphInterpreter( } finally currentInterpreterHolder(0) = previousInterpreter } + private def safeLogics(id: Int) = + if (id == Boundary) null + else logics(id) + // Decodes and processes a single event for the given connection private def processEvent(connection: Int): Unit = { - def safeLogics(id: Int) = - if (id == Boundary) null - else logics(id) - - def processElement(): Unit = { - if (Debug) println(s"$Name PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, ${connectionSlots(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") - activeStage = safeLogics(assembly.inOwners(connection)) - portStates(connection) ^= PushEndFlip - inHandlers(connection).onPush() - } // this must be the state after returning without delivering any signals, to avoid double-finalization of some unlucky stage // (this can happen if a stage completes voluntarily while connection close events are still queued) @@ -598,14 +644,11 @@ final class GraphInterpreter( // Manual fast decoding, fast paths are PUSH and PULL // PUSH if ((code & (Pushing | InClosed | OutClosed)) == Pushing) { - processElement() + processPush(connection) // PULL } else if ((code & (Pulling | OutClosed | InClosed)) == Pulling) { - if (Debug) println(s"$Name PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") - portStates(connection) ^= PullEndFlip - activeStage = safeLogics(assembly.outOwners(connection)) - outHandlers(connection).onPull() + processPull(connection) // CANCEL } else if ((code & (OutClosed | InClosed)) == InClosed) { @@ -629,13 +672,27 @@ final class GraphInterpreter( else inHandlers(connection).onUpstreamFailure(connectionSlots(connection).asInstanceOf[Failed].ex) } else { // Push is pending, first process push, then re-enqueue closing event - processElement() + processPush(connection) enqueue(connection) } } } + private def processPush(connection: Int): Unit = { + if (Debug) println(s"$Name PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, ${connectionSlots(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") + activeStage = safeLogics(assembly.inOwners(connection)) + portStates(connection) ^= PushEndFlip + inHandlers(connection).onPush() + } + + private def processPull(connection: Int): Unit = { + if (Debug) println(s"$Name PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") + activeStage = safeLogics(assembly.outOwners(connection)) + portStates(connection) ^= PullEndFlip + outHandlers(connection).onPull() + } + private def dequeue(): Int = { val idx = queueHead & mask if (fuzzingMode) { @@ -688,11 +745,31 @@ final class GraphInterpreter( } } + private[stream] def chasePush(connection: Int): Unit = { + if (chaseCounter > 0 && chasedPush == NoEvent) { + chaseCounter -= 1 + chasedPush = connection + } else enqueue(connection) + } + + private[stream] def chasePull(connection: Int): Unit = { + if (chaseCounter > 0 && chasedPull == NoEvent) { + chaseCounter -= 1 + chasedPull = connection + } else enqueue(connection) + } + private[stream] def complete(connection: Int): Unit = { val currentState = portStates(connection) if (Debug) println(s"$Name complete($connection) [$currentState]") portStates(connection) = currentState | OutClosed - if ((currentState & (InClosed | Pushing | Pulling | OutClosed)) == 0) enqueue(connection) + + // Push-Close needs special treatment, cannot be chased, convert back to ordinary event + if (chasedPush == connection) { + chasedPush = NoEvent + enqueue(connection) + } else if ((currentState & (InClosed | Pushing | Pulling | OutClosed)) == 0) enqueue(connection) + if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 4b46f87f11..1bf2d3a7ee 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -346,7 +346,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: if ((portState & (InReady | InClosed | OutClosed)) == InReady) { it.portStates(connection) = portState ^ PullStartFlip - it.enqueue(connection) + it.chasePull(connection) } else { // Detailed error information should not add overhead to the hot path require(!isClosed(in), s"Cannot pull closed port ($in)") @@ -446,7 +446,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: if ((portState & (OutReady | OutClosed | InClosed)) == OutReady && (elem != null)) { it.connectionSlots(connection) = elem - it.enqueue(connection) + it.chasePush(connection) } else { // Restore state for the error case it.portStates(connection) = portState