diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 3af76b9da1..0cf7a97685 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -9,9 +9,7 @@ import scala.annotation.tailrec import scala.collection.{ immutable, mutable } import scala.concurrent.{ Future, Promise } import scala.concurrent.duration.FiniteDuration - import scala.annotation.nowarn - import akka.{ Done, NotUsed } import akka.actor._ import akka.annotation.InternalApi @@ -22,6 +20,7 @@ import akka.stream.impl.{ ReactiveStreamsCompliance, TraversalBuilder } import akka.stream.impl.ActorSubscriberMessage import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource } import akka.stream.scaladsl.GenericGraphWithChangedAttributes +import akka.stream.stage.ConcurrentAsyncCallbackState.{ NoPendingEvents, State } import akka.util.OptionVal import akka.util.unused @@ -284,7 +283,7 @@ private[akka] object ConcurrentAsyncCallbackState { // Event with feedback promise final case class Event[+E](e: E, handlingPromise: Promise[Done]) - val NoPendingEvents = Pending[Nothing](Nil) + val NoPendingEvents: Pending[Nothing] = Pending[Nothing](Nil) } /** @@ -1186,21 +1185,22 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * [[onStop]] puts class in `Completed` state * "Real world" calls of [[invokeWithFeedback]] always return failed promises for `Completed` state */ - private final class ConcurrentAsyncCallback[T](handler: T => Unit) extends AsyncCallback[T] { + private final class ConcurrentAsyncCallback[T](handler: T => Unit) + extends AtomicReference[State[T]](NoPendingEvents) + with AsyncCallback[T] { import ConcurrentAsyncCallbackState._ - private[this] val currentState = new AtomicReference[State[T]](NoPendingEvents) // is called from the owning [[GraphStage]] @tailrec private[stage] def onStart(): Unit = { // dispatch callbacks that have been queued before the interpreter was started - (currentState.getAndSet(NoPendingEvents): @unchecked) match { + (getAndSet(NoPendingEvents): @unchecked) match { case Pending(l) => if (l.nonEmpty) l.reverse.foreach(evt => onAsyncInput(evt.e, evt.handlingPromise)) case s => throw new IllegalStateException(s"Unexpected callback state [$s]") } // in the meantime more callbacks might have been queued (we keep queueing them to ensure order) - if (!currentState.compareAndSet(NoPendingEvents, Initialized)) + if (!compareAndSet(NoPendingEvents, Initialized)) // state guaranteed to be still Pending onStart() } @@ -1236,14 +1236,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: @tailrec private def invokeWithPromise(event: T, promise: Promise[Done]): Unit = - currentState.get() match { + get() match { case Initialized => // started - can just dispatch async message to interpreter onAsyncInput(event, promise) case list @ Pending(l: List[Event[T]]) => // not started yet - if (!currentState.compareAndSet(list, Pending[T](Event[T](event, promise) :: l))) + if (!compareAndSet(list, Pending[T](Event[T](event, promise) :: l))) invokeWithPromise(event, promise) }