+str AsyncCallbacks lost on finished stage by error

This commit is contained in:
Alexander Golubev 2017-06-17 22:51:34 +03:00
parent c05a3e0e26
commit 9b43ce71ba
8 changed files with 287 additions and 157 deletions

View file

@ -73,7 +73,7 @@ private[remote] object InboundControlJunction {
def notify(inboundEnvelope: InboundEnvelope): Unit def notify(inboundEnvelope: InboundEnvelope): Unit
} }
// messages for the CallbackWrapper // messages for the stream callback
private[InboundControlJunction] sealed trait CallbackMessage private[InboundControlJunction] sealed trait CallbackMessage
private[InboundControlJunction] final case class Attach(observer: ControlMessageObserver, done: Promise[Done]) private[InboundControlJunction] final case class Attach(observer: ControlMessageObserver, done: Promise[Done])
extends CallbackMessage extends CallbackMessage
@ -93,8 +93,7 @@ private[remote] class InboundControlJunction
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val stoppedPromise = Promise[Done]() val stoppedPromise = Promise[Done]()
// FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way val logic = new GraphStageLogic(shape) with InHandler with OutHandler with ControlMessageSubject {
val logic = new GraphStageLogic(shape) with CallbackWrapper[CallbackMessage] with InHandler with OutHandler {
private var observers: Vector[ControlMessageObserver] = Vector.empty private var observers: Vector[ControlMessageObserver] = Vector.empty
@ -106,10 +105,6 @@ private[remote] class InboundControlJunction
observers = observers.filterNot(_ == observer) observers = observers.filterNot(_ == observer)
} }
override def preStart(): Unit = {
initCallback(callback.invoke)
}
override def postStop(): Unit = stoppedPromise.success(Done) override def postStop(): Unit = stoppedPromise.success(Done)
// InHandler // InHandler
@ -127,24 +122,22 @@ private[remote] class InboundControlJunction
override def onPull(): Unit = pull(in) override def onPull(): Unit = pull(in)
setHandlers(in, out, this) setHandlers(in, out, this)
}
// materialized value // ControlMessageSubject impl
val controlSubject: ControlMessageSubject = new ControlMessageSubject {
override def attach(observer: ControlMessageObserver): Future[Done] = { override def attach(observer: ControlMessageObserver): Future[Done] = {
val p = Promise[Done]() val p = Promise[Done]()
logic.invoke(Attach(observer, p)) callback.invoke(Attach(observer, p))
p.future p.future
} }
override def detach(observer: ControlMessageObserver): Unit = override def detach(observer: ControlMessageObserver): Unit =
logic.invoke(Dettach(observer)) callback.invoke(Dettach(observer))
override def stopped: Future[Done] = override def stopped: Future[Done] =
stoppedPromise.future stoppedPromise.future
} }
(logic, controlSubject) (logic, logic)
} }
} }
@ -169,18 +162,14 @@ private[remote] class OutboundControlJunction(
override val shape: FlowShape[OutboundEnvelope, OutboundEnvelope] = FlowShape(in, out) override val shape: FlowShape[OutboundEnvelope, OutboundEnvelope] = FlowShape(in, out)
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
// FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way
val logic = new GraphStageLogic(shape) with CallbackWrapper[ControlMessage] with InHandler with OutHandler with StageLogging { val logic = new GraphStageLogic(shape) with InHandler with OutHandler with StageLogging with OutboundControlIngress {
import OutboundControlJunction._ import OutboundControlJunction._
private val sendControlMessageCallback = getAsyncCallback[ControlMessage](internalSendControlMessage) val sendControlMessageCallback = getAsyncCallback[ControlMessage](internalSendControlMessage)
private val maxControlMessageBufferSize: Int = outboundContext.settings.Advanced.OutboundControlQueueSize private val maxControlMessageBufferSize: Int = outboundContext.settings.Advanced.OutboundControlQueueSize
private val buffer = new ArrayDeque[OutboundEnvelope] private val buffer = new ArrayDeque[OutboundEnvelope]
override def preStart(): Unit = {
initCallback(sendControlMessageCallback.invoke)
}
// InHandler // InHandler
override def onPush(): Unit = { override def onPush(): Unit = {
if (buffer.isEmpty && isAvailable(out)) if (buffer.isEmpty && isAvailable(out))
@ -212,16 +201,13 @@ private[remote] class OutboundControlJunction(
outboundEnvelopePool.acquire().init( outboundEnvelopePool.acquire().init(
recipient = OptionVal.None, message = message, sender = OptionVal.None) recipient = OptionVal.None, message = message, sender = OptionVal.None)
override def sendControlMessage(message: ControlMessage): Unit =
sendControlMessageCallback.invoke(message)
setHandlers(in, out, this) setHandlers(in, out, this)
} }
// materialized value (logic, logic)
val outboundControlIngress = new OutboundControlIngress {
override def sendControlMessage(message: ControlMessage): Unit =
logic.invoke(message)
}
(logic, outboundControlIngress)
} }
} }

View file

@ -12,9 +12,11 @@ import akka.stream.testkit.{ GraphStageMessages, StreamSpec, TestSourceStage, Te
import akka.stream.testkit.scaladsl.TestSink import akka.stream.testkit.scaladsl.TestSink
import akka.stream.testkit.Utils._ import akka.stream.testkit.Utils._
import akka.testkit.TestProbe import akka.testkit.TestProbe
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.concurrent._ import scala.concurrent._
import akka.Done
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink
import org.scalatest.time.Span import org.scalatest.time.Span
class QueueSourceSpec extends StreamSpec { class QueueSourceSpec extends StreamSpec {
@ -184,7 +186,7 @@ class QueueSourceSpec extends StreamSpec {
expectMsgClass(classOf[Status.Failure]) expectMsgClass(classOf[Status.Failure])
} }
"return false when elemen was not added to buffer" in assertAllStagesStopped { "return false when element was not added to buffer" in assertAllStagesStopped {
val s = TestSubscriber.manualProbe[Int]() val s = TestSubscriber.manualProbe[Int]()
val queue = Source.queue(1, OverflowStrategy.dropNew).to(Sink.fromSubscriber(s)).run() val queue = Source.queue(1, OverflowStrategy.dropNew).to(Sink.fromSubscriber(s)).run()
val sub = s.expectSubscription val sub = s.expectSubscription
@ -300,6 +302,15 @@ class QueueSourceSpec extends StreamSpec {
.expectComplete() .expectComplete()
source.watchCompletion().futureValue should ===(Done) source.watchCompletion().futureValue should ===(Done)
} }
"some elements not yet delivered to stage" in {
val (queue, probe) =
Source.queue[Unit](10, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run()
intercept[StreamDetachedException] {
Await.result(
(1 to 15).map(_ queue.offer(())).last, 3.seconds)
}
}
} }
"fail the stream" when { "fail the stream" when {

View file

@ -1,2 +1,20 @@
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.GzipCompressor.this") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.GzipCompressor.this")
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.DeflateCompressor.this") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.compression.DeflateCompressor.this")
# #23111 AsyncCallbacks to just-finishing stages can be lost
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Offer")
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Completion$")
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSink$Pull")
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSink$Cancel$")
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.QueueSink$Output")
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.QueueSource$Failure")
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.QueueSource$Input")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.stage.AsyncCallback.invokeWithFeedback")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Stopped")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$NotInitialized")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Stopped$")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Initialized")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$Initialized$")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$NotInitialized$")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper$CallbackState")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper")

View file

@ -3,26 +3,29 @@
*/ */
package akka.stream.impl package akka.stream.impl
import java.util.concurrent.CompletionStage
import akka.Done
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts.sameThreadExecutionContext
import akka.stream.OverflowStrategies._ import akka.stream.OverflowStrategies._
import akka.stream._ import akka.stream._
import akka.stream.stage._ import akka.stream.stage._
import akka.stream.scaladsl.SourceQueueWithComplete import akka.stream.scaladsl.SourceQueueWithComplete
import akka.Done
import java.util.concurrent.CompletionStage
import akka.annotation.InternalApi
import scala.concurrent.{ Future, Promise }
import scala.compat.java8.FutureConverters._ import scala.compat.java8.FutureConverters._
import scala.concurrent.{ Future, Promise }
import scala.util.control.NonFatal
/** /**
* INTERNAL API * INTERNAL API
*/ */
@InternalApi private[akka] object QueueSource { @InternalApi private[akka] object QueueSource {
sealed trait Input[+T] sealed trait Input[+T]
final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T] final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T]
case object Completion extends Input[Nothing] case object Completion extends Input[Nothing]
final case class Failure(ex: Throwable) extends Input[Nothing] final case class Failure(ex: Throwable) extends Input[Nothing]
} }
/** /**
@ -36,22 +39,18 @@ import scala.compat.java8.FutureConverters._
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val completion = Promise[Done] val completion = Promise[Done]
val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[Input[T]] with OutHandler {
val stageLogic = new GraphStageLogic(shape) with OutHandler with SourceQueueWithComplete[T] {
var buffer: Buffer[T] = _ var buffer: Buffer[T] = _
var pendingOffer: Option[Offer[T]] = None var pendingOffer: Option[Offer[T]] = None
var terminating = false var terminating = false
override def preStart(): Unit = { override def preStart(): Unit = {
if (maxBuffer > 0) buffer = Buffer(maxBuffer, materializer) if (maxBuffer > 0) buffer = Buffer(maxBuffer, materializer)
initCallback(callback.invoke)
} }
override def postStop(): Unit = { override def postStop(): Unit = {
val exception = new StreamDetachedException() val exception = new StreamDetachedException()
completion.tryFailure(exception) completion.tryFailure(exception)
stopCallback {
case Offer(elem, promise) promise.failure(exception)
case _ // ignore
}
} }
private def enqueueAndSuccess(offer: Offer[T]): Unit = { private def enqueueAndSuccess(offer: Offer[T]): Unit = {
@ -75,7 +74,7 @@ import scala.compat.java8.FutureConverters._
case DropNew case DropNew
offer.promise.success(QueueOfferResult.Dropped) offer.promise.success(QueueOfferResult.Dropped)
case Fail case Fail
val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") val bufferOverflowException = BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!")
offer.promise.success(QueueOfferResult.Failure(bufferOverflowException)) offer.promise.success(QueueOfferResult.Failure(bufferOverflowException))
completion.failure(bufferOverflowException) completion.failure(bufferOverflowException)
failStage(bufferOverflowException) failStage(bufferOverflowException)
@ -89,8 +88,7 @@ import scala.compat.java8.FutureConverters._
} }
} }
private val callback: AsyncCallback[Input[T]] = getAsyncCallback { private val callback = getAsyncCallback[Input[T]] {
case offer @ Offer(elem, promise) case offer @ Offer(elem, promise)
if (maxBuffer != 0) { if (maxBuffer != 0) {
bufferElem(offer) bufferElem(offer)
@ -107,7 +105,7 @@ import scala.compat.java8.FutureConverters._
case DropTail | DropNew case DropTail | DropNew
promise.success(QueueOfferResult.Dropped) promise.success(QueueOfferResult.Dropped)
case Fail case Fail
val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") val bufferOverflowException = BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!")
promise.success(QueueOfferResult.Failure(bufferOverflowException)) promise.success(QueueOfferResult.Failure(bufferOverflowException))
completion.failure(bufferOverflowException) completion.failure(bufferOverflowException)
failStage(bufferOverflowException) failStage(bufferOverflowException)
@ -131,7 +129,7 @@ import scala.compat.java8.FutureConverters._
override def onDownstreamFinish(): Unit = { override def onDownstreamFinish(): Unit = {
pendingOffer match { pendingOffer match {
case Some(Offer(elem, promise)) case Some(Offer(_, promise))
promise.success(QueueOfferResult.QueueClosed) promise.success(QueueOfferResult.QueueClosed)
pendingOffer = None pendingOffer = None
case None // do nothing case None // do nothing
@ -167,22 +165,22 @@ import scala.compat.java8.FutureConverters._
} }
} }
} }
}
(stageLogic, new SourceQueueWithComplete[T] { // SourceQueueWithComplete impl
override def watchCompletion() = completion.future override def watchCompletion() = completion.future
override def offer(element: T): Future[QueueOfferResult] = { override def offer(element: T): Future[QueueOfferResult] = {
val p = Promise[QueueOfferResult] val p = Promise[QueueOfferResult]
stageLogic.invoke(Offer(element, p)) callback.invokeWithFeedback(Offer(element, p))
.onFailure { case NonFatal(e) p.tryFailure(e) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
p.future p.future
} }
override def complete(): Unit = { override def complete(): Unit = callback.invoke(Completion)
stageLogic.invoke(Completion)
} override def fail(ex: Throwable): Unit = callback.invoke(Failure(ex))
override def fail(ex: Throwable): Unit = {
stageLogic.invoke(Failure(ex)) }
}
}) (stageLogic, stageLogic)
} }
} }

View file

@ -334,7 +334,9 @@ import akka.util.OptionVal
override def toString: String = "QueueSink" override def toString: String = "QueueSink"
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[Output[T]] with InHandler { var logicCallback: AsyncCallback[Output[T]] = null
val stageLogic = new GraphStageLogic(shape) with InHandler with SinkQueueWithCancel[T] {
type Received[E] = Try[Option[E]] type Received[E] = Try[Option[E]]
val maxBuffer = inheritedAttributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max val maxBuffer = inheritedAttributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max
@ -348,29 +350,22 @@ import akka.util.OptionVal
// closed/failure indicators // closed/failure indicators
buffer = Buffer(maxBuffer + 1, materializer) buffer = Buffer(maxBuffer + 1, materializer)
setKeepGoing(true) setKeepGoing(true)
initCallback(callback.invoke)
pull(in) pull(in)
} }
override def postStop(): Unit = stopCallback { private val callback = getAsyncCallback[Output[T]] {
case Pull(promise) promise.failure(new StreamDetachedException()) case QueueSink.Pull(pullPromise) currentRequest match {
case _ //do nothing case Some(_)
} pullPromise.failure(new IllegalStateException("You have to wait for previous future to be resolved to send another request"))
case None
private val callback: AsyncCallback[Output[T]] = if (buffer.isEmpty) currentRequest = Some(pullPromise)
getAsyncCallback { else {
case QueueSink.Pull(pullPromise) currentRequest match { if (buffer.used == maxBuffer) tryPull(in)
case Some(_) sendDownstream(pullPromise)
pullPromise.failure(new IllegalStateException("You have to wait for previous future to be resolved to send another request")) }
case None
if (buffer.isEmpty) currentRequest = Some(pullPromise)
else {
if (buffer.used == maxBuffer) tryPull(in)
sendDownstream(pullPromise)
}
}
case QueueSink.Cancel completeStage()
} }
case QueueSink.Cancel completeStage()
}
def sendDownstream(promise: Requested[T]): Unit = { def sendDownstream(promise: Requested[T]): Unit = {
val e = buffer.dequeue() val e = buffer.dequeue()
@ -400,19 +395,22 @@ import akka.util.OptionVal
override def onUpstreamFinish(): Unit = enqueueAndNotify(Success(None)) override def onUpstreamFinish(): Unit = enqueueAndNotify(Success(None))
override def onUpstreamFailure(ex: Throwable): Unit = enqueueAndNotify(Failure(ex)) override def onUpstreamFailure(ex: Throwable): Unit = enqueueAndNotify(Failure(ex))
logicCallback = callback
setHandler(in, this) setHandler(in, this)
}
(stageLogic, new SinkQueueWithCancel[T] { // SinkQueueWithCancel impl
override def pull(): Future[Option[T]] = { override def pull(): Future[Option[T]] = {
val p = Promise[Option[T]] val p = Promise[Option[T]]
stageLogic.invoke(Pull(p)) logicCallback.invokeWithFeedback(Pull(p))
.onFailure { case NonFatal(e) p.tryFailure(e) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
p.future p.future
} }
override def cancel(): Unit = { override def cancel(): Unit = {
stageLogic.invoke(QueueSink.Cancel) logicCallback.invoke(QueueSink.Cancel)
} }
}) }
(stageLogic, stageLogic)
} }
} }

View file

@ -49,8 +49,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration
val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer)
val downstreamStatus = new AtomicReference[DownstreamStatus](Ok) val downstreamStatus = new AtomicReference[DownstreamStatus](Ok)
final class OutputStreamSourceLogic extends GraphStageLogic(shape) final class OutputStreamSourceLogic extends GraphStageLogic(shape) {
with CallbackWrapper[(AdapterToStageMessage, Promise[Unit])] {
var flush: Option[Promise[Unit]] = None var flush: Option[Promise[Unit]] = None
var close: Option[Promise[Unit]] = None var close: Option[Promise[Unit]] = None
@ -69,7 +68,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration
def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { def wakeUp(msg: AdapterToStageMessage): Future[Unit] = {
val p = Promise[Unit]() val p = Promise[Unit]()
this.invoke((msg, p)) upstreamCallback.invoke((msg, p))
p.future p.future
} }
@ -112,7 +111,6 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration
override def preStart(): Unit = { override def preStart(): Unit = {
dispatcher = ActorMaterializerHelper.downcast(materializer).system.dispatchers.lookup(dispatcherId) dispatcher = ActorMaterializerHelper.downcast(materializer).system.dispatchers.lookup(dispatcherId)
super.preStart() super.preStart()
initCallback(upstreamCallback.invoke)
} }
setHandler(out, new OutHandler { setHandler(out, new OutHandler {

View file

@ -3,26 +3,29 @@
*/ */
package akka.stream.stage package akka.stream.stage
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.atomic.AtomicReference
import akka.NotUsed
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import akka.{ Done, NotUsed }
import akka.actor._ import akka.actor._
import akka.annotation.ApiMayChange import akka.annotation.{ ApiMayChange, InternalApi }
import akka.dispatch.ExecutionContexts.sameThreadExecutionContext
import akka.japi.function.{ Effect, Procedure } import akka.japi.function.{ Effect, Procedure }
import akka.stream._ import akka.stream._
import akka.stream.impl.StreamLayout.AtomicModule
import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource }
import akka.stream.impl.{ EmptyTraversal, LinearTraversalBuilder, ReactiveStreamsCompliance, TraversalBuilder }
import scala.collection.mutable.ArrayBuffer
import scala.collection.{ immutable, mutable }
import scala.concurrent.duration.FiniteDuration
import akka.stream.actor.ActorSubscriberMessage import akka.stream.actor.ActorSubscriberMessage
import akka.stream.scaladsl.{ GenericGraph, GenericGraphWithChangedAttributes } import akka.stream.impl.{ NotInitialized, ReactiveStreamsCompliance, TraversalBuilder }
import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource }
import akka.stream.scaladsl.GenericGraphWithChangedAttributes
import akka.util.OptionVal import akka.util.OptionVal
import akka.annotation.InternalApi
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.{ immutable, mutable }
import scala.concurrent.{ Future, Promise }
import scala.concurrent.duration.FiniteDuration
import scala.util.Try
import scala.util.control.TailCalls.{ TailRec, done, tailcall }
/** /**
* Scala API: A GraphStage represents a reusable graph stream processing stage. * Scala API: A GraphStage represents a reusable graph stream processing stage.
@ -262,10 +265,12 @@ object GraphStageLogic {
* cleanup should always be done in `postStop`. * cleanup should always be done in `postStop`.
*/ */
abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) { abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) {
import GraphInterpreter._ import GraphInterpreter._
import GraphStageLogic._ import GraphStageLogic._
def this(shape: Shape) = this(shape.inlets.size, shape.outlets.size) def this(shape: Shape) = this(shape.inlets.size, shape.outlets.size)
/** /**
* INTERNAL API * INTERNAL API
*/ */
@ -364,14 +369,17 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
* nor failure. * nor failure.
*/ */
final protected def totallyIgnorantInput: InHandler = TotallyIgnorantInput final protected def totallyIgnorantInput: InHandler = TotallyIgnorantInput
/** /**
* Output handler that terminates the stage upon cancellation. * Output handler that terminates the stage upon cancellation.
*/ */
final protected def eagerTerminateOutput: OutHandler = EagerTerminateOutput final protected def eagerTerminateOutput: OutHandler = EagerTerminateOutput
/** /**
* Output handler that does not terminate the stage upon cancellation. * Output handler that does not terminate the stage upon cancellation.
*/ */
final protected def ignoreTerminateOutput: OutHandler = IgnoreTerminateOutput final protected def ignoreTerminateOutput: OutHandler = IgnoreTerminateOutput
/** /**
* Output handler that terminates the state upon receiving completion if the * Output handler that terminates the state upon receiving completion if the
* given condition holds at that time. The stage fails upon receiving a failure. * given condition holds at that time. The stage fails upon receiving a failure.
@ -735,6 +743,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
onComplete() onComplete()
previous.onUpstreamFinish() previous.onUpstreamFinish()
} }
override def onUpstreamFailure(ex: Throwable): Unit = { override def onUpstreamFailure(ex: Throwable): Unit = {
setHandler(in, previous) setHandler(in, previous)
previous.onUpstreamFailure(ex) previous.onUpstreamFailure(ex)
@ -965,11 +974,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
doPull: Boolean = false): Unit = { doPull: Boolean = false): Unit = {
class PassAlongHandler extends InHandler with (() Unit) { class PassAlongHandler extends InHandler with (() Unit) {
override def apply(): Unit = tryPull(from) override def apply(): Unit = tryPull(from)
override def onPush(): Unit = { override def onPush(): Unit = {
val elem = grab(from) val elem = grab(from)
emit(to, elem, this) emit(to, elem, this)
} }
override def onUpstreamFinish(): Unit = if (doFinish) completeStage() override def onUpstreamFinish(): Unit = if (doFinish) completeStage()
override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) failStage(ex) override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) failStage(ex)
} }
val ph = new PassAlongHandler val ph = new PassAlongHandler
@ -984,31 +996,174 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
/** /**
* Obtain a callback object that can be used asynchronously to re-enter the * Obtain a callback object that can be used asynchronously to re-enter the
* current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned
* [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely * [[AsyncCallback]] is safe to be called from other threads. It will in the background thread-safely
* delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and * delegate to the passed callback function. I.e. [[invoke()]] will be called by other thread and
* the passed handler will be invoked eventually in a thread-safe way by the execution environment. * the passed handler will be invoked eventually in a thread-safe way by the execution environment.
* *
* In case stream is not yet materialized [[AsyncCallback]] will buffer events until stream is available.
*
* [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed
* due to stream completion.
*
* Method must be called in thread-safe manner during materialization and in the same thread as materialization
* process.
*
* This object can be cached and reused within the same [[GraphStageLogic]]. * This object can be cached and reused within the same [[GraphStageLogic]].
*/ */
final def getAsyncCallback[T](handler: T Unit): AsyncCallback[T] = { final def getAsyncCallback[T](handler: T Unit): AsyncCallback[T] = {
new AsyncCallback[T] { val result = new ConcurrentAsyncCallback[T](handler)
override def invoke(event: T): Unit = asyncCallbacksInProgress.add(result)
interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any Unit]) if (_interpreter != null) result.onStart()
result
}
/**
* ConcurrentAsyncCallback allows to call [[invoke()]] and [[invokeWithPromise()]] with event attribute.
* This event will be sent to the stream and the corresponding handler will be called with this attribute in thread-safe manner.
*
* State of this object can be changed both "internally" by the owning GraphStage or by the "external world" (e.g. other threads).
* Specifically, calls to this class can be made:
* * From the owning [[GraphStage]], to [[onStart]] - when materialization is finished and to [[onStop()]] -
* because the stage is about to stop or fail.
* * "Real world" calls [[invoke()]] and [[invokeWithFeedback()]]. These methods have synchronization
* with class state that reflects the stream state
*
* onStart sends all events that were buffered while stream was materializing.
* In case "Real world" added more events while initializing, onStart checks for more events in buffer when exiting and
* resend new events
*
* Once class is in `Initialized` state - all "Real world" calls of [[invoke()]] and [[invokeWithFeedback()]] are running
* as is - without blocking each other.
*
* [[GraphStage]] is called [[onStop()]] when stream is wrapping down. onStop fails all futures for events that have not yet processed
* [[onStop()]] puts class in `Completed` state
* "Real world" calls of [[invokeWithFeedback()]] always return failed promises for `Completed` state
*/
private class ConcurrentAsyncCallback[T](handler: T Unit) extends AsyncCallback[T] {
sealed trait State
// waiting for materialization completion
private case class Pending(pendingEvents: List[Event]) extends State
// GraphStage sending all events to stream
private case object Initializing extends State
// stream is initialized and so no threads can just send events without any synchronization overhead
private case object Initialized extends State
// stage has been shut down, either regularly or it failed
private case object Completed extends State
// Event with feedback promise
private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]])
val waitingForProcessing = ConcurrentHashMap.newKeySet[Promise[_]]()
private[this] val currentState = new AtomicReference[State](Pending(Nil))
// is called from the owning [[GraphStage]]
@tailrec
final private[stage] def onStart(): Unit = {
(currentState.getAndSet(Initializing): @unchecked) match {
case Pending(l) l.reverse.foreach(ack {
onAsyncInput(ack.e)
})
}
if (!currentState.compareAndSet(Initializing, Initialized)) {
(currentState.get: @unchecked) match {
case Pending(_) onStart()
case Completed () //wonder if this is possible
}
}
}
// is called from the owning [[GraphStage]]
final private[stage] def onStop(): Unit = {
currentState.set(Completed)
val iterator = waitingForProcessing.iterator()
lazy val detachedException = new StreamDetachedException()
while (iterator.hasNext) { iterator.next().tryFailure(detachedException) }
waitingForProcessing.clear()
}
private def onAsyncInput(event: T) = interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any Unit])
private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = {
onAsyncInput(event)
currentState.get() match {
case Completed failPromiseOnComplete(promise)
case _ promise
}
}
// external call
override def invokeWithFeedback(event: T): Future[Done] = {
val promise: Promise[Done] = Promise[Done]()
promise.future.andThen { case _ waitingForProcessing.remove(promise) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
waitingForProcessing.add(promise)
invokeWithPromise(event, promise).future
}
private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = {
currentState.get() match {
// not started yet
case list @ Pending(_)
if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents)))
invokeWithPromise(event, promise) // atomicity is failed - try again
else promise
// started - can just send message to stream
case Initialized sendEvent(event, promise)
// initializing is in progress in another thread (initializing thread is managed by Akka)
case Initializing if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) {
(currentState.get(): @unchecked) match {
case Pending(_) invokeWithPromise(event, promise) // atomicity is failed - try again
case Initialized sendEvent(event, promise)
}
} else promise
// fail promise as stream is completed
case Completed failPromiseOnComplete(promise)
}
}
private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = {
waitingForProcessing.remove(promise)
promise.tryFailure(new StreamDetachedException())
promise
}
//external call
override def invoke(event: T): Unit = {
@tailrec
def internalInvoke(event: T): Unit = currentState.get() match {
// started - can just send message to stream
case Initialized onAsyncInput(event)
// not started yet
case list @ Pending(l) if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event)
// initializing is in progress in another thread (initializing thread is managed by akka)
case Initializing if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal.None) :: Nil))) {
(currentState.get(): @unchecked) match {
case list @ Pending(l) if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event)
case Initialized onAsyncInput(event)
}
}
case Completed // do nothing here as stream is completed
}
internalInvoke(event)
} }
} }
/** /**
* Java API: Obtain a callback object that can be used asynchronously to re-enter the * Java API: Obtain a callback object that can be used asynchronously to re-enter the
* current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned
* [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely * [[AsyncCallback]] is safe to be called from other threads. It will in the background thread-safely
* delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and * delegate to the passed callback function. I.e. [[invoke()]] will be called by other thread and
* the passed handler will be invoked eventually in a thread-safe way by the execution environment. * the passed handler will be invoked eventually in a thread-safe way by the execution environment.
* *
* [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed due to stream completion.
*
* This object can be cached and reused within the same [[GraphStageLogic]]. * This object can be cached and reused within the same [[GraphStageLogic]].
*/ */
final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] = final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] =
getAsyncCallback(handler.apply) getAsyncCallback(handler.apply)
private val asyncCallbacksInProgress = mutable.HashSet[ConcurrentAsyncCallback[_]]()
private var _stageActor: StageActor = _ private var _stageActor: StageActor = _
final def stageActor: StageActor = _stageActor match { final def stageActor: StageActor = _stageActor match {
case null throw StageActorRefNotInitializedException() case null throw StageActorRefNotInitializedException()
@ -1048,7 +1203,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
// Internal hooks to avoid reliance on user calling super in preStart // Internal hooks to avoid reliance on user calling super in preStart
/** INTERNAL API */ /** INTERNAL API */
protected[stream] def beforePreStart(): Unit = () protected[stream] def beforePreStart(): Unit = {
asyncCallbacksInProgress.foreach(_.onStart())
}
// Internal hooks to avoid reliance on user calling super in postStop // Internal hooks to avoid reliance on user calling super in postStop
/** INTERNAL API */ /** INTERNAL API */
@ -1057,6 +1214,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
_stageActor.stop() _stageActor.stop()
_stageActor = null _stageActor = null
} }
asyncCallbacksInProgress.foreach(_.onStop())
} }
/** /**
@ -1230,8 +1388,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
/** /**
* An asynchronous callback holder that is attached to a [[GraphStageLogic]]. * An asynchronous callback holder that is attached to a [[GraphStageLogic]].
* Invoking [[AsyncCallback#invoke]] will eventually lead to the registered handler * Initializing [[AsyncCallback#invoke]] will eventually lead to the registered handler
* being called. * being called.
*
* This holder has the same lifecycle as a stream and cannot be used before
* materialization is done.
*
* Typical use cases are exchanging messages between stream and substreams or invoking from external world sending
* event to a stream
*/ */
trait AsyncCallback[T] { trait AsyncCallback[T] {
/** /**
@ -1239,6 +1403,13 @@ trait AsyncCallback[T] {
* may be invoked from external execution contexts. * may be invoked from external execution contexts.
*/ */
def invoke(t: T): Unit def invoke(t: T): Unit
/**
* Dispatch an asynchronous notification.
* This method is thread-safe and may be invoked from external execution contexts.
* Promise in `HasCallbackPromise` will fail if stream is already closed or closed before
* being able to process the event
*/
def invokeWithFeedback(t: T): Future[Done]
} }
abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) { abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) {
@ -1414,53 +1585,3 @@ abstract class AbstractOutHandler extends OutHandler
* (completing when upstream completes, failing when upstream fails, completing when downstream cancels). * (completing when upstream completes, failing when upstream fails, completing when downstream cancels).
*/ */
abstract class AbstractInOutHandler extends InHandler with OutHandler abstract class AbstractInOutHandler extends InHandler with OutHandler
/**
* INTERNAL API
* This trait wraps callback for `GraphStage` stage instances and handle gracefully cases when stage is
* not yet initialized or already finished.
*
* While `GraphStage` has not initialized it adds all requests to list.
* As soon as `GraphStage` is started it stops collecting requests (pointing to real callback
* function) and run all the callbacks from the list
*
* Supposed to be used by GraphStages that share call back to outer world
*/
private[akka] trait CallbackWrapper[T] extends AsyncCallback[T] {
private trait CallbackState
private case class NotInitialized(list: List[T]) extends CallbackState
private case class Initialized(f: T Unit) extends CallbackState
private case class Stopped(f: T Unit) extends CallbackState
/*
* To preserve message order when switching between not initialized / initialized states
* lock is used. Case is similar to RepointableActorRef
*/
private[this] final val lock = new ReentrantLock
private[this] val callbackState = new AtomicReference[CallbackState](NotInitialized(Nil))
def stopCallback(f: T Unit): Unit = locked {
callbackState.set(Stopped(f))
}
def initCallback(f: T Unit): Unit = locked {
val list = (callbackState.getAndSet(Initialized(f)): @unchecked) match {
case NotInitialized(l) l
}
list.reverse.foreach(f)
}
override def invoke(arg: T): Unit = locked {
callbackState.get() match {
case Initialized(cb) cb(arg)
case list @ NotInitialized(l) callbackState.compareAndSet(list, NotInitialized(arg :: l))
case Stopped(cb) cb(arg)
}
}
private[this] def locked(body: Unit): Unit = {
lock.lock()
try body finally lock.unlock()
}
}