diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index f0fa43f950..cd206e7bbe 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -445,10 +445,18 @@ import scala.util.control.NonFatal activeStage = logic try { handler(evt) - if (promise.isDefined) promise.get.success(Done) + if (promise.isDefined) { + val p = promise.get + p.success(Done) + logic.onFeedbackDispatched() + } } catch { case NonFatal(ex) ⇒ - if (promise.isDefined) promise.get.failure(ex) + if (promise.isDefined) { + val p = promise.get + promise.get.failure(ex) + logic.onFeedbackDispatched() + } logic.failStage(ex) } afterStageHasRun(logic) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 422b5e2543..63f89595cc 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -1052,13 +1052,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: @tailrec private[stage] def onStart(): Unit = { (currentState.getAndSet(Initializing): @unchecked) match { - case Pending(l) ⇒ l.reverse.foreach(evt ⇒ { - evt.handlingPromise match { - case OptionVal.Some(p) ⇒ p.future.onComplete(_ ⇒ onFeedbackCompleted(p))(ExecutionContexts.sameThreadExecutionContext) - case OptionVal.None ⇒ // buffered invoke without promise - } - onAsyncInput(evt.e, evt.handlingPromise) - }) + case Pending(l) ⇒ l.reverse.foreach(evt ⇒ onAsyncInput(evt.e, evt.handlingPromise)) } if (!currentState.compareAndSet(Initializing, Initialized)) { (currentState.get: @unchecked) match { @@ -1067,17 +1061,6 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } } - // is called from the owning [[GraphStage]] - private[stage] def onStop(outstandingPromises: Set[Promise[Done]]): Unit = { - if (outstandingPromises.nonEmpty) { - val detachedException = streamDetatchedException - val iterator = outstandingPromises.iterator - while (iterator.hasNext) { - iterator.next().tryFailure(detachedException) - } - } - } - private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = { interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) } @@ -1091,13 +1074,12 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // external call override def invokeWithFeedback(event: T): Future[Done] = { val promise: Promise[Done] = Promise[Done]() - promise.future.onComplete(_ ⇒ onFeedbackCompleted(promise))(ExecutionContexts.sameThreadExecutionContext) @tailrec def addToWaiting(): Boolean = { val previous = asyncCallbacksInProgress.get() if (previous != null) { - val updated = previous.updated(this, previous(this) + promise) + val updated = promise :: previous if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting() else true } else { @@ -1111,22 +1093,6 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Future.failed(streamDetatchedException) } - // removes the promise from the callbacks in promise on complete, called from onComplete - private def onFeedbackCompleted(promise: Promise[Done]): Unit = { - @tailrec - def removeFromWaiting(): Unit = { - val previous = asyncCallbacksInProgress.get() - if (previous != null) { - val newSet = previous(this) - promise - val updated = - if (newSet.isEmpty) previous - this // no outstanding promises, remove stage from map to avoid leak - else previous.updated(this, newSet) - if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) removeFromWaiting() - } - } - removeFromWaiting() - } - @tailrec private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = currentState.get() match { @@ -1147,7 +1113,6 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { - onFeedbackCompleted(promise) promise.tryFailure(streamDetatchedException) promise } @@ -1171,9 +1136,6 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: internalInvoke(event) } - private def streamDetatchedException = - new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed") - } /** @@ -1193,8 +1155,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private var callbacksWaitingForInterpreter: List[ConcurrentAsyncCallback[_]] = Nil // is used for two purposes: keep track of running callbacks and signal that the // stage has stopped to fail incoming async callback invocations by being set to null - private val asyncCallbacksInProgress = - new AtomicReference(Map.empty[ConcurrentAsyncCallback[_], Set[Promise[Done]]].withDefaultValue(Set.empty)) + private val asyncCallbacksInProgress = new AtomicReference[List[Promise[Done]]](Nil) private def stopped = asyncCallbacksInProgress.get() == null @@ -1252,9 +1213,39 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // make sure any invokeWithFeedback after this fails fast // and fail current outstanding invokeWithFeedback promises val inProgress = asyncCallbacksInProgress.getAndSet(null) - inProgress.foreach { case (acb, promises) ⇒ acb.onStop(promises) } + if (inProgress.nonEmpty) { + val exception = streamDetatchedException + inProgress.foreach(_.tryFailure(exception)) + } } + private[this] var asyncCleanupCounter = 0L + + /** Called from interpreter thread by GraphInterpreter.runAsyncInput */ + private[stream] def onFeedbackDispatched(): Unit = { + asyncCleanupCounter += 1 + + // 256 seemed to be a sweet spot in SendQueueBenchmark.queue benchmarks + // It means that at most 255 completed promises are retained per logic that + // uses invokeWithFeedback callbacks. + // + // TODO: add periodical cleanup to get rid of those 255 promises as well + if (asyncCleanupCounter % 256 == 0) { + @tailrec def cleanup(): Unit = { + val previous = asyncCallbacksInProgress.get() + if (previous != null) { + val updated = previous.filterNot(_.isCompleted) + if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) cleanup() + } + } + + cleanup() + } + } + + private def streamDetatchedException = + new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed") + /** * Invoked before any external events are processed, at the startup of the stage. */