diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala index 9d4aab4a66..1d5df9dc82 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala @@ -477,13 +477,13 @@ import pekko.stream.stage._ handler(evt) if (promise ne GraphStageLogic.NoPromise) { promise.success(Done) - logic.onFeedbackDispatched() + logic.onFeedbackDispatched(promise) } } catch { case NonFatal(ex) => if (promise ne GraphStageLogic.NoPromise) { promise.failure(ex) - logic.onFeedbackDispatched() + logic.onFeedbackDispatched(promise) } logic.failStage(ex) } diff --git a/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala b/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala index 9f01667e9a..dc0f51187a 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala @@ -13,19 +13,22 @@ package org.apache.pekko.stream.stage +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicReference + +import scala.annotation.nowarn import scala.annotation.tailrec import scala.collection.{ immutable, mutable } import scala.concurrent.{ Future, Promise } import scala.concurrent.duration.FiniteDuration -import scala.annotation.nowarn + import org.apache.pekko import pekko.{ Done, NotUsed } import pekko.actor._ import pekko.annotation.InternalApi import pekko.japi.function.{ Effect, Procedure } -import pekko.stream.Attributes.SourceLocation import pekko.stream._ +import pekko.stream.Attributes.SourceLocation import pekko.stream.impl.{ ReactiveStreamsCompliance, TraversalBuilder } import pekko.stream.impl.ActorSubscriberMessage import pekko.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource } @@ -1227,13 +1230,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Add this promise to the owning logic, so it can be completed afterPostStop if it was never handled otherwise. * Returns whether the logic is still running. */ - @tailrec def addToWaiting(): Boolean = { - val previous = asyncCallbacksInProgress.get() - if (previous != null) { // not stopped - val updated = promise :: previous - if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting() - else true + val callbacks = asyncCallbacksInProgress.get() + if (callbacks ne null) { // not stopped + callbacks.add(promise) + asyncCallbacksInProgress.get ne null // logic may already stopped } else // logic was already stopped false } @@ -1282,7 +1283,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private var callbacksWaitingForInterpreter: List[ConcurrentAsyncCallback[_]] = Nil // is used for two purposes: keep track of running callbacks and signal that the // stage has stopped to fail incoming async callback invocations by being set to null - private val asyncCallbacksInProgress = new AtomicReference[List[Promise[Done]]](Nil) + // Using ConcurrentHashMap's KeySetView as Set to track the inProgress async callbacks. + private val asyncCallbacksInProgress: AtomicReference[java.util.Set[Promise[Done]]] = + new AtomicReference(ConcurrentHashMap.newKeySet()) private var _stageActor: StageActor = _ final def stageActor: StageActor = _stageActor match { @@ -1374,35 +1377,19 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } // make sure any invokeWithFeedback after this fails fast // and fail current outstanding invokeWithFeedback promises - val inProgress = asyncCallbacksInProgress.getAndSet(null) - if (inProgress.nonEmpty) { + val callbacks = asyncCallbacksInProgress.getAndSet(null) + if ((callbacks ne null) && !callbacks.isEmpty) { val exception = streamDetachedException - inProgress.foreach(_.tryFailure(exception)) + callbacks.forEach((t: Promise[Done]) => t.tryFailure(exception)) } cleanUpSubstreams(OptionVal.None) } - private[this] var asyncCleanupCounter = 0L - /** Called from interpreter thread by GraphInterpreter.runAsyncInput */ - private[stream] def onFeedbackDispatched(): Unit = { - asyncCleanupCounter += 1 - - // 256 seemed to be a sweet spot in SendQueueBenchmark.queue benchmarks - // It means that at most 255 completed promises are retained per logic that - // uses invokeWithFeedback callbacks. - // - // TODO: add periodical cleanup to get rid of those 255 promises as well - if (asyncCleanupCounter % 256 == 0) { - @tailrec def cleanup(): Unit = { - val previous = asyncCallbacksInProgress.get() - if (previous != null) { - val updated = previous.filterNot(_.isCompleted) - if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) cleanup() - } - } - - cleanup() + private[stream] def onFeedbackDispatched(promise: Promise[Done]): Unit = { + val callbacks = asyncCallbacksInProgress.get() + if (callbacks ne null) { + callbacks.remove(promise) } }