diff --git a/akka-remote/src/main/scala/akka/remote/artery/Control.scala b/akka-remote/src/main/scala/akka/remote/artery/Control.scala index d710951ae9..61b1280568 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/Control.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/Control.scala @@ -73,7 +73,7 @@ private[remote] object InboundControlJunction { def notify(inboundEnvelope: InboundEnvelope): Unit } - // messages for the CallbackWrapper + // messages for the stream callback private[InboundControlJunction] sealed trait CallbackMessage private[InboundControlJunction] final case class Attach(observer: ControlMessageObserver, done: Promise[Done]) extends CallbackMessage @@ -93,8 +93,7 @@ private[remote] class InboundControlJunction override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { val stoppedPromise = Promise[Done]() - // FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way - val logic = new GraphStageLogic(shape) with CallbackWrapper[CallbackMessage] with InHandler with OutHandler { + val logic = new GraphStageLogic(shape) with InHandler with OutHandler with ControlMessageSubject { private var observers: Vector[ControlMessageObserver] = Vector.empty @@ -106,10 +105,6 @@ private[remote] class InboundControlJunction observers = observers.filterNot(_ == observer) } - override def preStart(): Unit = { - initCallback(callback.invoke) - } - override def postStop(): Unit = stoppedPromise.success(Done) // InHandler @@ -127,24 +122,22 @@ private[remote] class InboundControlJunction override def onPull(): Unit = pull(in) setHandlers(in, out, this) - } - // materialized value - val controlSubject: ControlMessageSubject = new ControlMessageSubject { + // ControlMessageSubject impl override def attach(observer: ControlMessageObserver): Future[Done] = { val p = Promise[Done]() - logic.invoke(Attach(observer, p)) + callback.invoke(Attach(observer, p)) p.future } override def detach(observer: ControlMessageObserver): Unit = - logic.invoke(Dettach(observer)) + callback.invoke(Dettach(observer)) override def stopped: Future[Done] = stoppedPromise.future } - (logic, controlSubject) + (logic, logic) } } @@ -169,18 +162,14 @@ private[remote] class OutboundControlJunction( override val shape: FlowShape[OutboundEnvelope, OutboundEnvelope] = FlowShape(in, out) override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { - // FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way - val logic = new GraphStageLogic(shape) with CallbackWrapper[ControlMessage] with InHandler with OutHandler with StageLogging { + + val logic = new GraphStageLogic(shape) with InHandler with OutHandler with StageLogging with OutboundControlIngress { import OutboundControlJunction._ - private val sendControlMessageCallback = getAsyncCallback[ControlMessage](internalSendControlMessage) + val sendControlMessageCallback = getAsyncCallback[ControlMessage](internalSendControlMessage) private val maxControlMessageBufferSize: Int = outboundContext.settings.Advanced.OutboundControlQueueSize private val buffer = new ArrayDeque[OutboundEnvelope] - override def preStart(): Unit = { - initCallback(sendControlMessageCallback.invoke) - } - // InHandler override def onPush(): Unit = { if (buffer.isEmpty && isAvailable(out)) @@ -212,16 +201,13 @@ private[remote] class OutboundControlJunction( outboundEnvelopePool.acquire().init( recipient = OptionVal.None, message = message, sender = OptionVal.None) + override def sendControlMessage(message: ControlMessage): Unit = + sendControlMessageCallback.invoke(message) + setHandlers(in, out, this) } - // materialized value - val outboundControlIngress = new OutboundControlIngress { - override def sendControlMessage(message: ControlMessage): Unit = - logic.invoke(message) - } - - (logic, outboundControlIngress) + (logic, logic) } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala index 057a8e3d98..8e104596d3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala @@ -12,9 +12,11 @@ import akka.stream.testkit.{ GraphStageMessages, StreamSpec, TestSourceStage, Te import akka.stream.testkit.scaladsl.TestSink import akka.stream.testkit.Utils._ import akka.testkit.TestProbe - import scala.concurrent.duration._ import scala.concurrent._ +import akka.Done +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl.TestSink import org.scalatest.time.Span class QueueSourceSpec extends StreamSpec { @@ -184,7 +186,7 @@ class QueueSourceSpec extends StreamSpec { expectMsgClass(classOf[Status.Failure]) } - "return false when elemen was not added to buffer" in assertAllStagesStopped { + "return false when element was not added to buffer" in assertAllStagesStopped { val s = TestSubscriber.manualProbe[Int]() val queue = Source.queue(1, OverflowStrategy.dropNew).to(Sink.fromSubscriber(s)).run() val sub = s.expectSubscription @@ -300,6 +302,15 @@ class QueueSourceSpec extends StreamSpec { .expectComplete() source.watchCompletion().futureValue should ===(Done) } + + "some elements not yet delivered to stage" in { + val (queue, probe) = + Source.queue[Unit](10, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + intercept[StreamDetachedException] { + Await.result( + (1 to 15).map(_ ⇒ queue.offer(())).last, 3.seconds) + } + } } "fail the stream" when { diff --git a/akka-stream/src/main/mima-filters/2.5.3.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.3.backwards.excludes index 9e77d2e963..4ccf532985 100644 --- a/akka-stream/src/main/mima-filters/2.5.3.backwards.excludes +++ b/akka-stream/src/main/mima-filters/2.5.3.backwards.excludes @@ -7,4 +7,4 @@ ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.MaybeSource ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.MaybeSource.attributes") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.MaybeSource.create") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.MaybeSource.this") -ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.MaybePublisher$") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.MaybePublisher$") \ No newline at end of file diff --git a/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes index 0077a82945..1049a13681 100644 --- a/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes +++ b/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes @@ -1,2 +1,20 @@ ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.GzipCompressor.this") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.DeflateCompressor.this") + +# #23111 AsyncCallbacks to just-finishing stages can be lost +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Offer") +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Completion$") +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSink$Pull") +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSink$Cancel$") +ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.QueueSink$Output") +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Failure") +ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.QueueSource$Input") +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.stage.AsyncCallback.invokeWithFeedback") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Stopped") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$NotInitialized") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Stopped$") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Initialized") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Initialized$") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$NotInitialized$") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$CallbackState") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper") \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/impl/QueueSource.scala b/akka-stream/src/main/scala/akka/stream/impl/QueueSource.scala index 4c9fefa331..f02d4f14af 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/QueueSource.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/QueueSource.scala @@ -3,26 +3,29 @@ */ package akka.stream.impl +import java.util.concurrent.CompletionStage + +import akka.Done +import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts.sameThreadExecutionContext import akka.stream.OverflowStrategies._ import akka.stream._ import akka.stream.stage._ import akka.stream.scaladsl.SourceQueueWithComplete -import akka.Done -import java.util.concurrent.CompletionStage - -import akka.annotation.InternalApi - -import scala.concurrent.{ Future, Promise } import scala.compat.java8.FutureConverters._ +import scala.concurrent.{ Future, Promise } +import scala.util.control.NonFatal /** * INTERNAL API */ @InternalApi private[akka] object QueueSource { + sealed trait Input[+T] final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T] case object Completion extends Input[Nothing] final case class Failure(ex: Throwable) extends Input[Nothing] + } /** @@ -36,22 +39,18 @@ import scala.compat.java8.FutureConverters._ override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { val completion = Promise[Done] - val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[Input[T]] with OutHandler { + + val stageLogic = new GraphStageLogic(shape) with OutHandler with SourceQueueWithComplete[T] { var buffer: Buffer[T] = _ var pendingOffer: Option[Offer[T]] = None var terminating = false override def preStart(): Unit = { if (maxBuffer > 0) buffer = Buffer(maxBuffer, materializer) - initCallback(callback.invoke) } override def postStop(): Unit = { val exception = new StreamDetachedException() completion.tryFailure(exception) - stopCallback { - case Offer(elem, promise) ⇒ promise.failure(exception) - case _ ⇒ // ignore - } } private def enqueueAndSuccess(offer: Offer[T]): Unit = { @@ -75,7 +74,7 @@ import scala.compat.java8.FutureConverters._ case DropNew ⇒ offer.promise.success(QueueOfferResult.Dropped) case Fail ⇒ - val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") + val bufferOverflowException = BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") offer.promise.success(QueueOfferResult.Failure(bufferOverflowException)) completion.failure(bufferOverflowException) failStage(bufferOverflowException) @@ -89,8 +88,7 @@ import scala.compat.java8.FutureConverters._ } } - private val callback: AsyncCallback[Input[T]] = getAsyncCallback { - + private val callback = getAsyncCallback[Input[T]] { case offer @ Offer(elem, promise) ⇒ if (maxBuffer != 0) { bufferElem(offer) @@ -107,7 +105,7 @@ import scala.compat.java8.FutureConverters._ case DropTail | DropNew ⇒ promise.success(QueueOfferResult.Dropped) case Fail ⇒ - val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") + val bufferOverflowException = BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") promise.success(QueueOfferResult.Failure(bufferOverflowException)) completion.failure(bufferOverflowException) failStage(bufferOverflowException) @@ -131,7 +129,7 @@ import scala.compat.java8.FutureConverters._ override def onDownstreamFinish(): Unit = { pendingOffer match { - case Some(Offer(elem, promise)) ⇒ + case Some(Offer(_, promise)) ⇒ promise.success(QueueOfferResult.QueueClosed) pendingOffer = None case None ⇒ // do nothing @@ -167,22 +165,22 @@ import scala.compat.java8.FutureConverters._ } } } - } - (stageLogic, new SourceQueueWithComplete[T] { + // SourceQueueWithComplete impl override def watchCompletion() = completion.future override def offer(element: T): Future[QueueOfferResult] = { val p = Promise[QueueOfferResult] - stageLogic.invoke(Offer(element, p)) + callback.invokeWithFeedback(Offer(element, p)) + .onFailure { case NonFatal(e) ⇒ p.tryFailure(e) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) p.future } - override def complete(): Unit = { - stageLogic.invoke(Completion) - } - override def fail(ex: Throwable): Unit = { - stageLogic.invoke(Failure(ex)) - } - }) + override def complete(): Unit = callback.invoke(Completion) + + override def fail(ex: Throwable): Unit = callback.invoke(Failure(ex)) + + } + + (stageLogic, stageLogic) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index d40fdde007..f231f2638b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -334,7 +334,9 @@ import akka.util.OptionVal override def toString: String = "QueueSink" override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { - val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[Output[T]] with InHandler { + var logicCallback: AsyncCallback[Output[T]] = null + + val stageLogic = new GraphStageLogic(shape) with InHandler with SinkQueueWithCancel[T] { type Received[E] = Try[Option[E]] val maxBuffer = inheritedAttributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max @@ -348,29 +350,22 @@ import akka.util.OptionVal // closed/failure indicators buffer = Buffer(maxBuffer + 1, materializer) setKeepGoing(true) - initCallback(callback.invoke) pull(in) } - override def postStop(): Unit = stopCallback { - case Pull(promise) ⇒ promise.failure(new StreamDetachedException()) - case _ ⇒ //do nothing - } - - private val callback: AsyncCallback[Output[T]] = - getAsyncCallback { - case QueueSink.Pull(pullPromise) ⇒ currentRequest match { - case Some(_) ⇒ - pullPromise.failure(new IllegalStateException("You have to wait for previous future to be resolved to send another request")) - case None ⇒ - if (buffer.isEmpty) currentRequest = Some(pullPromise) - else { - if (buffer.used == maxBuffer) tryPull(in) - sendDownstream(pullPromise) - } - } - case QueueSink.Cancel ⇒ completeStage() + private val callback = getAsyncCallback[Output[T]] { + case QueueSink.Pull(pullPromise) ⇒ currentRequest match { + case Some(_) ⇒ + pullPromise.failure(new IllegalStateException("You have to wait for previous future to be resolved to send another request")) + case None ⇒ + if (buffer.isEmpty) currentRequest = Some(pullPromise) + else { + if (buffer.used == maxBuffer) tryPull(in) + sendDownstream(pullPromise) + } } + case QueueSink.Cancel ⇒ completeStage() + } def sendDownstream(promise: Requested[T]): Unit = { val e = buffer.dequeue() @@ -400,19 +395,22 @@ import akka.util.OptionVal override def onUpstreamFinish(): Unit = enqueueAndNotify(Success(None)) override def onUpstreamFailure(ex: Throwable): Unit = enqueueAndNotify(Failure(ex)) + logicCallback = callback setHandler(in, this) - } - (stageLogic, new SinkQueueWithCancel[T] { + // SinkQueueWithCancel impl override def pull(): Future[Option[T]] = { val p = Promise[Option[T]] - stageLogic.invoke(Pull(p)) + logicCallback.invokeWithFeedback(Pull(p)) + .onFailure { case NonFatal(e) ⇒ p.tryFailure(e) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) p.future } override def cancel(): Unit = { - stageLogic.invoke(QueueSink.Cancel) + logicCallback.invoke(QueueSink.Cancel) } - }) + } + + (stageLogic, stageLogic) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 42d653d4c9..d60e5719ca 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -49,8 +49,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) val downstreamStatus = new AtomicReference[DownstreamStatus](Ok) - final class OutputStreamSourceLogic extends GraphStageLogic(shape) - with CallbackWrapper[(AdapterToStageMessage, Promise[Unit])] { + final class OutputStreamSourceLogic extends GraphStageLogic(shape) { var flush: Option[Promise[Unit]] = None var close: Option[Promise[Unit]] = None @@ -69,7 +68,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { val p = Promise[Unit]() - this.invoke((msg, p)) + upstreamCallback.invoke((msg, p)) p.future } @@ -112,7 +111,6 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration override def preStart(): Unit = { dispatcher = ActorMaterializerHelper.downcast(materializer).system.dispatchers.lookup(dispatcherId) super.preStart() - initCallback(upstreamCallback.invoke) } setHandler(out, new OutHandler { diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index eeb6f1aeee..17ebebfd2f 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -3,26 +3,29 @@ */ package akka.stream.stage +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicReference - -import akka.NotUsed import java.util.concurrent.locks.ReentrantLock +import akka.{ Done, NotUsed } import akka.actor._ -import akka.annotation.ApiMayChange +import akka.annotation.{ ApiMayChange, InternalApi } +import akka.dispatch.ExecutionContexts.sameThreadExecutionContext import akka.japi.function.{ Effect, Procedure } import akka.stream._ -import akka.stream.impl.StreamLayout.AtomicModule -import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource } -import akka.stream.impl.{ EmptyTraversal, LinearTraversalBuilder, ReactiveStreamsCompliance, TraversalBuilder } - -import scala.collection.mutable.ArrayBuffer -import scala.collection.{ immutable, mutable } -import scala.concurrent.duration.FiniteDuration import akka.stream.actor.ActorSubscriberMessage -import akka.stream.scaladsl.{ GenericGraph, GenericGraphWithChangedAttributes } +import akka.stream.impl.{ NotInitialized, ReactiveStreamsCompliance, TraversalBuilder } +import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource } +import akka.stream.scaladsl.GenericGraphWithChangedAttributes import akka.util.OptionVal -import akka.annotation.InternalApi + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.collection.{ immutable, mutable } +import scala.concurrent.{ Future, Promise } +import scala.concurrent.duration.FiniteDuration +import scala.util.Try +import scala.util.control.TailCalls.{ TailRec, done, tailcall } /** * Scala API: A GraphStage represents a reusable graph stream processing stage. @@ -262,10 +265,12 @@ object GraphStageLogic { * cleanup should always be done in `postStop`. */ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) { + import GraphInterpreter._ import GraphStageLogic._ def this(shape: Shape) = this(shape.inlets.size, shape.outlets.size) + /** * INTERNAL API */ @@ -364,14 +369,17 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * nor failure. */ final protected def totallyIgnorantInput: InHandler = TotallyIgnorantInput + /** * Output handler that terminates the stage upon cancellation. */ final protected def eagerTerminateOutput: OutHandler = EagerTerminateOutput + /** * Output handler that does not terminate the stage upon cancellation. */ final protected def ignoreTerminateOutput: OutHandler = IgnoreTerminateOutput + /** * Output handler that terminates the state upon receiving completion if the * given condition holds at that time. The stage fails upon receiving a failure. @@ -735,6 +743,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: onComplete() previous.onUpstreamFinish() } + override def onUpstreamFailure(ex: Throwable): Unit = { setHandler(in, previous) previous.onUpstreamFailure(ex) @@ -965,11 +974,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: doPull: Boolean = false): Unit = { class PassAlongHandler extends InHandler with (() ⇒ Unit) { override def apply(): Unit = tryPull(from) + override def onPush(): Unit = { val elem = grab(from) emit(to, elem, this) } + override def onUpstreamFinish(): Unit = if (doFinish) completeStage() + override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) failStage(ex) } val ph = new PassAlongHandler @@ -984,31 +996,174 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Obtain a callback object that can be used asynchronously to re-enter the * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned - * [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely - * delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and + * [[AsyncCallback]] is safe to be called from other threads. It will in the background thread-safely + * delegate to the passed callback function. I.e. [[invoke()]] will be called by other thread and * the passed handler will be invoked eventually in a thread-safe way by the execution environment. * + * In case stream is not yet materialized [[AsyncCallback]] will buffer events until stream is available. + * + * [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed + * due to stream completion. + * + * Method must be called in thread-safe manner during materialization and in the same thread as materialization + * process. + * * This object can be cached and reused within the same [[GraphStageLogic]]. */ final def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { - new AsyncCallback[T] { - override def invoke(event: T): Unit = - interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit]) + val result = new ConcurrentAsyncCallback[T](handler) + asyncCallbacksInProgress.add(result) + if (_interpreter != null) result.onStart() + result + } + + /** + * ConcurrentAsyncCallback allows to call [[invoke()]] and [[invokeWithPromise()]] with event attribute. + * This event will be sent to the stream and the corresponding handler will be called with this attribute in thread-safe manner. + * + * State of this object can be changed both "internally" by the owning GraphStage or by the "external world" (e.g. other threads). + * Specifically, calls to this class can be made: + * * From the owning [[GraphStage]], to [[onStart]] - when materialization is finished and to [[onStop()]] - + * because the stage is about to stop or fail. + * * "Real world" calls [[invoke()]] and [[invokeWithFeedback()]]. These methods have synchronization + * with class state that reflects the stream state + * + * onStart sends all events that were buffered while stream was materializing. + * In case "Real world" added more events while initializing, onStart checks for more events in buffer when exiting and + * resend new events + * + * Once class is in `Initialized` state - all "Real world" calls of [[invoke()]] and [[invokeWithFeedback()]] are running + * as is - without blocking each other. + * + * [[GraphStage]] is called [[onStop()]] when stream is wrapping down. onStop fails all futures for events that have not yet processed + * [[onStop()]] puts class in `Completed` state + * "Real world" calls of [[invokeWithFeedback()]] always return failed promises for `Completed` state + */ + private class ConcurrentAsyncCallback[T](handler: T ⇒ Unit) extends AsyncCallback[T] { + + sealed trait State + // waiting for materialization completion + private case class Pending(pendingEvents: List[Event]) extends State + // GraphStage sending all events to stream + private case object Initializing extends State + // stream is initialized and so no threads can just send events without any synchronization overhead + private case object Initialized extends State + // stage has been shut down, either regularly or it failed + private case object Completed extends State + // Event with feedback promise + private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]]) + + val waitingForProcessing = ConcurrentHashMap.newKeySet[Promise[_]]() + + private[this] val currentState = new AtomicReference[State](Pending(Nil)) + + // is called from the owning [[GraphStage]] + @tailrec + final private[stage] def onStart(): Unit = { + (currentState.getAndSet(Initializing): @unchecked) match { + case Pending(l) ⇒ l.reverse.foreach(ack ⇒ { + onAsyncInput(ack.e) + }) + } + if (!currentState.compareAndSet(Initializing, Initialized)) { + (currentState.get: @unchecked) match { + case Pending(_) ⇒ onStart() + case Completed ⇒ () //wonder if this is possible + } + } + } + + // is called from the owning [[GraphStage]] + final private[stage] def onStop(): Unit = { + currentState.set(Completed) + val iterator = waitingForProcessing.iterator() + lazy val detachedException = new StreamDetachedException() + while (iterator.hasNext) { iterator.next().tryFailure(detachedException) } + waitingForProcessing.clear() + } + + private def onAsyncInput(event: T) = interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit]) + + private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = { + onAsyncInput(event) + currentState.get() match { + case Completed ⇒ failPromiseOnComplete(promise) + case _ ⇒ promise + } + } + + // external call + override def invokeWithFeedback(event: T): Future[Done] = { + val promise: Promise[Done] = Promise[Done]() + promise.future.andThen { case _ ⇒ waitingForProcessing.remove(promise) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) + waitingForProcessing.add(promise) + invokeWithPromise(event, promise).future + } + + private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = { + currentState.get() match { + // not started yet + case list @ Pending(_) ⇒ + if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents))) + invokeWithPromise(event, promise) // atomicity is failed - try again + else promise + // started - can just send message to stream + case Initialized ⇒ sendEvent(event, promise) + // initializing is in progress in another thread (initializing thread is managed by Akka) + case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) { + (currentState.get(): @unchecked) match { + case Pending(_) ⇒ invokeWithPromise(event, promise) // atomicity is failed - try again + case Initialized ⇒ sendEvent(event, promise) + } + } else promise + // fail promise as stream is completed + case Completed ⇒ failPromiseOnComplete(promise) + } + } + + private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { + waitingForProcessing.remove(promise) + promise.tryFailure(new StreamDetachedException()) + promise + } + + //external call + override def invoke(event: T): Unit = { + @tailrec + def internalInvoke(event: T): Unit = currentState.get() match { + // started - can just send message to stream + case Initialized ⇒ onAsyncInput(event) + // not started yet + case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) + // initializing is in progress in another thread (initializing thread is managed by akka) + case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal.None) :: Nil))) { + (currentState.get(): @unchecked) match { + case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) + case Initialized ⇒ onAsyncInput(event) + } + } + case Completed ⇒ // do nothing here as stream is completed + } + internalInvoke(event) } } /** * Java API: Obtain a callback object that can be used asynchronously to re-enter the * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned - * [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely - * delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and + * [[AsyncCallback]] is safe to be called from other threads. It will in the background thread-safely + * delegate to the passed callback function. I.e. [[invoke()]] will be called by other thread and * the passed handler will be invoked eventually in a thread-safe way by the execution environment. * + * [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed due to stream completion. + * * This object can be cached and reused within the same [[GraphStageLogic]]. */ final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] = getAsyncCallback(handler.apply) + private val asyncCallbacksInProgress = mutable.HashSet[ConcurrentAsyncCallback[_]]() + private var _stageActor: StageActor = _ final def stageActor: StageActor = _stageActor match { case null ⇒ throw StageActorRefNotInitializedException() @@ -1048,7 +1203,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // Internal hooks to avoid reliance on user calling super in preStart /** INTERNAL API */ - protected[stream] def beforePreStart(): Unit = () + protected[stream] def beforePreStart(): Unit = { + asyncCallbacksInProgress.foreach(_.onStart()) + } // Internal hooks to avoid reliance on user calling super in postStop /** INTERNAL API */ @@ -1057,6 +1214,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: _stageActor.stop() _stageActor = null } + asyncCallbacksInProgress.foreach(_.onStop()) } /** @@ -1230,8 +1388,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * An asynchronous callback holder that is attached to a [[GraphStageLogic]]. - * Invoking [[AsyncCallback#invoke]] will eventually lead to the registered handler + * Initializing [[AsyncCallback#invoke]] will eventually lead to the registered handler * being called. + * + * This holder has the same lifecycle as a stream and cannot be used before + * materialization is done. + * + * Typical use cases are exchanging messages between stream and substreams or invoking from external world sending + * event to a stream */ trait AsyncCallback[T] { /** @@ -1239,6 +1403,13 @@ trait AsyncCallback[T] { * may be invoked from external execution contexts. */ def invoke(t: T): Unit + /** + * Dispatch an asynchronous notification. + * This method is thread-safe and may be invoked from external execution contexts. + * Promise in `HasCallbackPromise` will fail if stream is already closed or closed before + * being able to process the event + */ + def invokeWithFeedback(t: T): Future[Done] } abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) { @@ -1414,53 +1585,3 @@ abstract class AbstractOutHandler extends OutHandler * (completing when upstream completes, failing when upstream fails, completing when downstream cancels). */ abstract class AbstractInOutHandler extends InHandler with OutHandler - -/** - * INTERNAL API - * This trait wraps callback for `GraphStage` stage instances and handle gracefully cases when stage is - * not yet initialized or already finished. - * - * While `GraphStage` has not initialized it adds all requests to list. - * As soon as `GraphStage` is started it stops collecting requests (pointing to real callback - * function) and run all the callbacks from the list - * - * Supposed to be used by GraphStages that share call back to outer world - */ -private[akka] trait CallbackWrapper[T] extends AsyncCallback[T] { - private trait CallbackState - private case class NotInitialized(list: List[T]) extends CallbackState - private case class Initialized(f: T ⇒ Unit) extends CallbackState - private case class Stopped(f: T ⇒ Unit) extends CallbackState - - /* - * To preserve message order when switching between not initialized / initialized states - * lock is used. Case is similar to RepointableActorRef - */ - private[this] final val lock = new ReentrantLock - - private[this] val callbackState = new AtomicReference[CallbackState](NotInitialized(Nil)) - - def stopCallback(f: T ⇒ Unit): Unit = locked { - callbackState.set(Stopped(f)) - } - - def initCallback(f: T ⇒ Unit): Unit = locked { - val list = (callbackState.getAndSet(Initialized(f)): @unchecked) match { - case NotInitialized(l) ⇒ l - } - list.reverse.foreach(f) - } - - override def invoke(arg: T): Unit = locked { - callbackState.get() match { - case Initialized(cb) ⇒ cb(arg) - case list @ NotInitialized(l) ⇒ callbackState.compareAndSet(list, NotInitialized(arg :: l)) - case Stopped(cb) ⇒ cb(arg) - } - } - - private[this] def locked(body: ⇒ Unit): Unit = { - lock.lock() - try body finally lock.unlock() - } -}