diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 9f235c3b1c..5a581597f2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -457,7 +457,7 @@ import scala.util.control.NonFatal shell: GraphInterpreterShell, logic: GraphStageLogic, evt: Any, - promise: OptionVal[Promise[Done]], + promise: Promise[Done], handler: (Any) ⇒ Unit) extends BoundaryEvent { override def execute(eventLimit: Int): Int = { if (!waitingForShutdown) { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index cd206e7bbe..8999564635 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -194,7 +194,7 @@ import scala.util.control.NonFatal val log: LoggingAdapter, val logics: Array[GraphStageLogic], // Array of stage logics val connections: Array[GraphInterpreter.Connection], - val onAsyncInput: (GraphStageLogic, Any, OptionVal[Promise[Done]], (Any) ⇒ Unit) ⇒ Unit, + val onAsyncInput: (GraphStageLogic, Any, Promise[Done], (Any) ⇒ Unit) ⇒ Unit, val fuzzingMode: Boolean, val context: ActorRef) { @@ -435,7 +435,7 @@ import scala.util.control.NonFatal eventsRemaining } - def runAsyncInput(logic: GraphStageLogic, evt: Any, promise: OptionVal[Promise[Done]], handler: (Any) ⇒ Unit): Unit = + def runAsyncInput(logic: GraphStageLogic, evt: Any, promise: Promise[Done], handler: (Any) ⇒ Unit): Unit = if (!isStageCompleted(logic)) { if (GraphInterpreter.Debug) println(s"$Name ASYNC $evt ($handler) [$logic]") val currentInterpreterHolder = _currentInterpreter.get() @@ -445,16 +445,14 @@ import scala.util.control.NonFatal activeStage = logic try { handler(evt) - if (promise.isDefined) { - val p = promise.get - p.success(Done) + if (promise ne GraphStageLogic.NoPromise) { + promise.success(Done) logic.onFeedbackDispatched() } } catch { case NonFatal(ex) ⇒ - if (promise.isDefined) { - val p = promise.get - promise.get.failure(ex) + if (promise ne GraphStageLogic.NoPromise) { + promise.failure(ex) logic.onFeedbackDispatched() } logic.failStage(ex) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 63f89595cc..420806a97b 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -239,6 +239,14 @@ object GraphStageLogic { object StageActorRef { type Receive = ((ActorRef, Any)) ⇒ Unit } + + /** + * Internal API + * + * Marker value to pass to onAsyncInput if no promise was supplied. + */ + @InternalApi + private[stream] val NoPromise: Promise[Done] = Promise.successful(Done) } /** @@ -1037,105 +1045,75 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private final class ConcurrentAsyncCallback[T](handler: T ⇒ Unit) extends AsyncCallback[T] { sealed trait State - // waiting for materialization completion - private case class Pending(pendingEvents: List[Event]) extends State - // GraphStage sending all events to stream - private case object Initializing extends State + // waiting for materialization completion or during dispatching of initially queued events + private final case class Pending(pendingEvents: List[Event]) extends State // stream is initialized and so no threads can just send events without any synchronization overhead private case object Initialized extends State // Event with feedback promise - private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]]) + private final case class Event(e: T, handlingPromise: Promise[Done]) - private[this] val currentState = new AtomicReference[State](Pending(Nil)) + private[this] val NoPendingEvents = Pending(Nil) + private[this] val currentState = new AtomicReference[State](NoPendingEvents) // is called from the owning [[GraphStage]] @tailrec private[stage] def onStart(): Unit = { - (currentState.getAndSet(Initializing): @unchecked) match { - case Pending(l) ⇒ l.reverse.foreach(evt ⇒ onAsyncInput(evt.e, evt.handlingPromise)) + // dispatch callbacks that have been queued before the interpreter was started + (currentState.getAndSet(NoPendingEvents): @unchecked) match { + case Pending(l) ⇒ if (l.nonEmpty) l.reverse.foreach(evt ⇒ onAsyncInput(evt.e, evt.handlingPromise)) + case s ⇒ throw new IllegalStateException(s"Unexpected callback state [$s]") } - if (!currentState.compareAndSet(Initializing, Initialized)) { - (currentState.get: @unchecked) match { - case Pending(_) ⇒ onStart() - } - } - } - private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = { - interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) - } - - private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = { - onAsyncInput(event, OptionVal.Some(promise)) - if (stopped) failPromiseOnComplete(promise) - else promise + // in the meantime more callbacks might have been queued (we keep queueing them to ensure order) + if (!currentState.compareAndSet(NoPendingEvents, Initialized)) + // state guaranteed to be still Pending + onStart() } // external call override def invokeWithFeedback(event: T): Future[Done] = { val promise: Promise[Done] = Promise[Done]() + /** + * Add this promise to the owning logic, so it can be completed afterPostStop if it was never handled otherwise. + * Returns whether the logic is still running. + */ @tailrec def addToWaiting(): Boolean = { val previous = asyncCallbacksInProgress.get() - if (previous != null) { + if (previous != null) { // not stopped val updated = promise :: previous if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting() else true - } else { + } else // logic was already stopped false - } } - if (addToWaiting()) - invokeWithPromise(event, promise).future - else + if (addToWaiting()) { + invokeWithPromise(event, promise) + promise.future + } else Future.failed(streamDetatchedException) } - @tailrec - private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = - currentState.get() match { - // started - can just send message to stream - case Initialized ⇒ sendEvent(event, promise) - // not started yet - case list @ Pending(_) ⇒ - if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents))) - invokeWithPromise(event, promise) // atomicity is failed - try again - else promise - // initializing is in progress in another thread (initializing thread is managed by Akka) - case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) { - (currentState.get(): @unchecked) match { - case Pending(_) ⇒ invokeWithPromise(event, promise) // atomicity is failed - try again - case Initialized ⇒ sendEvent(event, promise) - } - } else promise - } - - private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { - promise.tryFailure(streamDetatchedException) - promise - } - //external call - override def invoke(event: T): Unit = { - @tailrec - def internalInvoke(event: T): Unit = currentState.get() match { - // started - can just send message to stream - case Initialized ⇒ onAsyncInput(event, OptionVal.None) - // not started yet - case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) - // initializing is in progress in another thread (initializing thread is managed by akka) - case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal.None) :: Nil))) { - (currentState.get(): @unchecked) match { - case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) - case Initialized ⇒ onAsyncInput(event, OptionVal.None) - } - } - } - internalInvoke(event) - } + override def invoke(event: T): Unit = invokeWithPromise(event, NoPromise) + @tailrec + private def invokeWithPromise(event: T, promise: Promise[Done]): Unit = + currentState.get() match { + case Initialized ⇒ + // started - can just dispatch async message to interpreter + onAsyncInput(event, promise) + + case list @ Pending(l) ⇒ + // not started yet + if (!currentState.compareAndSet(list, Pending(Event(event, promise) :: l))) + invokeWithPromise(event, promise) + } + + private def onAsyncInput(event: T, promise: Promise[Done]): Unit = + interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) } /**