diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 9f235c3b1c..5a581597f2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -457,7 +457,7 @@ import scala.util.control.NonFatal shell: GraphInterpreterShell, logic: GraphStageLogic, evt: Any, - promise: OptionVal[Promise[Done]], + promise: Promise[Done], handler: (Any) ⇒ Unit) extends BoundaryEvent { override def execute(eventLimit: Int): Int = { if (!waitingForShutdown) { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index f0fa43f950..8999564635 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -194,7 +194,7 @@ import scala.util.control.NonFatal val log: LoggingAdapter, val logics: Array[GraphStageLogic], // Array of stage logics val connections: Array[GraphInterpreter.Connection], - val onAsyncInput: (GraphStageLogic, Any, OptionVal[Promise[Done]], (Any) ⇒ Unit) ⇒ Unit, + val onAsyncInput: (GraphStageLogic, Any, Promise[Done], (Any) ⇒ Unit) ⇒ Unit, val fuzzingMode: Boolean, val context: ActorRef) { @@ -435,7 +435,7 @@ import scala.util.control.NonFatal eventsRemaining } - def runAsyncInput(logic: GraphStageLogic, evt: Any, promise: OptionVal[Promise[Done]], handler: (Any) ⇒ Unit): Unit = + def runAsyncInput(logic: GraphStageLogic, evt: Any, promise: Promise[Done], handler: (Any) ⇒ Unit): Unit = if (!isStageCompleted(logic)) { if (GraphInterpreter.Debug) println(s"$Name ASYNC $evt ($handler) [$logic]") val currentInterpreterHolder = _currentInterpreter.get() @@ -445,10 +445,16 @@ import scala.util.control.NonFatal activeStage = logic try { handler(evt) - if (promise.isDefined) promise.get.success(Done) + if (promise ne GraphStageLogic.NoPromise) { + promise.success(Done) + logic.onFeedbackDispatched() + } } catch { case NonFatal(ex) ⇒ - if (promise.isDefined) promise.get.failure(ex) + if (promise ne GraphStageLogic.NoPromise) { + promise.failure(ex) + logic.onFeedbackDispatched() + } logic.failStage(ex) } afterStageHasRun(logic) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 422b5e2543..420806a97b 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -239,6 +239,14 @@ object GraphStageLogic { object StageActorRef { type Receive = ((ActorRef, Any)) ⇒ Unit } + + /** + * Internal API + * + * Marker value to pass to onAsyncInput if no promise was supplied. + */ + @InternalApi + private[stream] val NoPromise: Promise[Done] = Promise.successful(Done) } /** @@ -1037,143 +1045,75 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private final class ConcurrentAsyncCallback[T](handler: T ⇒ Unit) extends AsyncCallback[T] { sealed trait State - // waiting for materialization completion - private case class Pending(pendingEvents: List[Event]) extends State - // GraphStage sending all events to stream - private case object Initializing extends State + // waiting for materialization completion or during dispatching of initially queued events + private final case class Pending(pendingEvents: List[Event]) extends State // stream is initialized and so no threads can just send events without any synchronization overhead private case object Initialized extends State // Event with feedback promise - private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]]) + private final case class Event(e: T, handlingPromise: Promise[Done]) - private[this] val currentState = new AtomicReference[State](Pending(Nil)) + private[this] val NoPendingEvents = Pending(Nil) + private[this] val currentState = new AtomicReference[State](NoPendingEvents) // is called from the owning [[GraphStage]] @tailrec private[stage] def onStart(): Unit = { - (currentState.getAndSet(Initializing): @unchecked) match { - case Pending(l) ⇒ l.reverse.foreach(evt ⇒ { - evt.handlingPromise match { - case OptionVal.Some(p) ⇒ p.future.onComplete(_ ⇒ onFeedbackCompleted(p))(ExecutionContexts.sameThreadExecutionContext) - case OptionVal.None ⇒ // buffered invoke without promise - } - onAsyncInput(evt.e, evt.handlingPromise) - }) + // dispatch callbacks that have been queued before the interpreter was started + (currentState.getAndSet(NoPendingEvents): @unchecked) match { + case Pending(l) ⇒ if (l.nonEmpty) l.reverse.foreach(evt ⇒ onAsyncInput(evt.e, evt.handlingPromise)) + case s ⇒ throw new IllegalStateException(s"Unexpected callback state [$s]") } - if (!currentState.compareAndSet(Initializing, Initialized)) { - (currentState.get: @unchecked) match { - case Pending(_) ⇒ onStart() - } - } - } - // is called from the owning [[GraphStage]] - private[stage] def onStop(outstandingPromises: Set[Promise[Done]]): Unit = { - if (outstandingPromises.nonEmpty) { - val detachedException = streamDetatchedException - val iterator = outstandingPromises.iterator - while (iterator.hasNext) { - iterator.next().tryFailure(detachedException) - } - } - } - - private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = { - interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) - } - - private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = { - onAsyncInput(event, OptionVal.Some(promise)) - if (stopped) failPromiseOnComplete(promise) - else promise + // in the meantime more callbacks might have been queued (we keep queueing them to ensure order) + if (!currentState.compareAndSet(NoPendingEvents, Initialized)) + // state guaranteed to be still Pending + onStart() } // external call override def invokeWithFeedback(event: T): Future[Done] = { val promise: Promise[Done] = Promise[Done]() - promise.future.onComplete(_ ⇒ onFeedbackCompleted(promise))(ExecutionContexts.sameThreadExecutionContext) + /** + * Add this promise to the owning logic, so it can be completed afterPostStop if it was never handled otherwise. + * Returns whether the logic is still running. + */ @tailrec def addToWaiting(): Boolean = { val previous = asyncCallbacksInProgress.get() - if (previous != null) { - val updated = previous.updated(this, previous(this) + promise) + if (previous != null) { // not stopped + val updated = promise :: previous if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting() else true - } else { + } else // logic was already stopped false - } } - if (addToWaiting()) - invokeWithPromise(event, promise).future - else + if (addToWaiting()) { + invokeWithPromise(event, promise) + promise.future + } else Future.failed(streamDetatchedException) } - // removes the promise from the callbacks in promise on complete, called from onComplete - private def onFeedbackCompleted(promise: Promise[Done]): Unit = { - @tailrec - def removeFromWaiting(): Unit = { - val previous = asyncCallbacksInProgress.get() - if (previous != null) { - val newSet = previous(this) - promise - val updated = - if (newSet.isEmpty) previous - this // no outstanding promises, remove stage from map to avoid leak - else previous.updated(this, newSet) - if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) removeFromWaiting() - } - } - removeFromWaiting() - } + //external call + override def invoke(event: T): Unit = invokeWithPromise(event, NoPromise) @tailrec - private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = + private def invokeWithPromise(event: T, promise: Promise[Done]): Unit = currentState.get() match { - // started - can just send message to stream - case Initialized ⇒ sendEvent(event, promise) - // not started yet - case list @ Pending(_) ⇒ - if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents))) - invokeWithPromise(event, promise) // atomicity is failed - try again - else promise - // initializing is in progress in another thread (initializing thread is managed by Akka) - case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) { - (currentState.get(): @unchecked) match { - case Pending(_) ⇒ invokeWithPromise(event, promise) // atomicity is failed - try again - case Initialized ⇒ sendEvent(event, promise) - } - } else promise + case Initialized ⇒ + // started - can just dispatch async message to interpreter + onAsyncInput(event, promise) + + case list @ Pending(l) ⇒ + // not started yet + if (!currentState.compareAndSet(list, Pending(Event(event, promise) :: l))) + invokeWithPromise(event, promise) } - private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { - onFeedbackCompleted(promise) - promise.tryFailure(streamDetatchedException) - promise - } - - //external call - override def invoke(event: T): Unit = { - @tailrec - def internalInvoke(event: T): Unit = currentState.get() match { - // started - can just send message to stream - case Initialized ⇒ onAsyncInput(event, OptionVal.None) - // not started yet - case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) - // initializing is in progress in another thread (initializing thread is managed by akka) - case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal.None) :: Nil))) { - (currentState.get(): @unchecked) match { - case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) - case Initialized ⇒ onAsyncInput(event, OptionVal.None) - } - } - } - internalInvoke(event) - } - - private def streamDetatchedException = - new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed") - + private def onAsyncInput(event: T, promise: Promise[Done]): Unit = + interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) } /** @@ -1193,8 +1133,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private var callbacksWaitingForInterpreter: List[ConcurrentAsyncCallback[_]] = Nil // is used for two purposes: keep track of running callbacks and signal that the // stage has stopped to fail incoming async callback invocations by being set to null - private val asyncCallbacksInProgress = - new AtomicReference(Map.empty[ConcurrentAsyncCallback[_], Set[Promise[Done]]].withDefaultValue(Set.empty)) + private val asyncCallbacksInProgress = new AtomicReference[List[Promise[Done]]](Nil) private def stopped = asyncCallbacksInProgress.get() == null @@ -1252,9 +1191,39 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // make sure any invokeWithFeedback after this fails fast // and fail current outstanding invokeWithFeedback promises val inProgress = asyncCallbacksInProgress.getAndSet(null) - inProgress.foreach { case (acb, promises) ⇒ acb.onStop(promises) } + if (inProgress.nonEmpty) { + val exception = streamDetatchedException + inProgress.foreach(_.tryFailure(exception)) + } } + private[this] var asyncCleanupCounter = 0L + + /** Called from interpreter thread by GraphInterpreter.runAsyncInput */ + private[stream] def onFeedbackDispatched(): Unit = { + asyncCleanupCounter += 1 + + // 256 seemed to be a sweet spot in SendQueueBenchmark.queue benchmarks + // It means that at most 255 completed promises are retained per logic that + // uses invokeWithFeedback callbacks. + // + // TODO: add periodical cleanup to get rid of those 255 promises as well + if (asyncCleanupCounter % 256 == 0) { + @tailrec def cleanup(): Unit = { + val previous = asyncCallbacksInProgress.get() + if (previous != null) { + val updated = previous.filterNot(_.isCompleted) + if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) cleanup() + } + } + + cleanup() + } + } + + private def streamDetatchedException = + new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed") + /** * Invoked before any external events are processed, at the startup of the stage. */