Automatic closing of graph stage sub inlets/outlets #29790

This commit is contained in:
Johan Andrén 2020-11-23 15:12:39 +01:00 committed by GitHub
parent 510e7374d5
commit c9980216a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 263 additions and 8 deletions

View file

@ -0,0 +1,204 @@
/*
* Copyright (C) 2009-2020 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.stream.impl
import akka.Done
import akka.NotUsed
import akka.dispatch.ExecutionContexts
import akka.stream.Attributes
import akka.stream.FlowShape
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.SinkShape
import akka.stream.SubscriptionWithCancelException.NoMoreElementsNeeded
import akka.stream.scaladsl.Sink
import akka.stream.scaladsl.Source
import akka.stream.stage.GraphStage
import akka.stream.stage.GraphStageLogic
import akka.stream.stage.InHandler
import akka.stream.stage.OutHandler
import akka.stream.testkit.StreamSpec
import akka.stream.testkit.TestPublisher
import akka.stream.testkit.TestSubscriber
import akka.stream.testkit.Utils.TE
import scala.util.Failure
import scala.util.Success
class SubInletOutletSpec extends StreamSpec {
"SubSinkInlet" should {
// a contrived custom graph stage just to observe what happens to the SubSinkInlet,
// it consumes commands from upstream telling it to fail or complete etc. and forwards elements from a side channel
// downstream through a SubSinkInlet
class PassAlongSubInStage(sideChannel: Source[String, NotUsed]) extends GraphStage[FlowShape[String, String]] {
val in = Inlet[String]("in")
val out = Outlet[String]("out")
@volatile var subCompletion: AnyRef = _
override val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
val subIn = new SubSinkInlet[String]("subin")
subIn.setHandler(new InHandler {
override def onPush(): Unit =
push(out, subIn.grab())
})
override def preStart(): Unit = {
sideChannel
.watchTermination() { (_, done) =>
done.onComplete(c => subCompletion = c)(ExecutionContexts.parasitic)
NotUsed
}
.runWith(Sink.fromGraph(subIn.sink))
pull(in) // eager pull of commands from upstream as downstream demand goes to subIn
}
setHandler(
in,
new InHandler {
override def onPush(): Unit = {
val cmd = grab(in)
// we never push to out here
cmd match {
case "completeStage" => completeStage()
case "cancelStage" => cancelStage(NoMoreElementsNeeded)
case "failStage" => failStage(TE("boom"))
case "closeAll" =>
cancel(in)
complete(out)
case _ => // ignore
}
if (isAvailable(in))
pull(in)
}
})
setHandler(out, new OutHandler {
override def onPull(): Unit = {
if (!subIn.hasBeenPulled)
subIn.pull()
}
})
}
}
class TestSetup {
val upstream = TestPublisher.probe[String]()
val sidechannel = TestPublisher.probe[String]()
val downstream = TestSubscriber.probe[String]()
val passAlong = new PassAlongSubInStage(Source.fromPublisher(sidechannel))
Source.fromPublisher(upstream).via(passAlong).runWith(Sink.fromSubscriber(downstream))
}
"complete automatically when parent stage completes" in new TestSetup {
downstream.request(1L)
sidechannel.expectRequest()
upstream.expectRequest()
sidechannel.sendNext("a one")
downstream.expectNext("a one")
upstream.sendNext("completeStage")
awaitAssert(passAlong.subCompletion should equal(Success(Done)))
}
"complete automatically when parent stage cancels" in new TestSetup {
downstream.request(1L)
sidechannel.expectRequest()
upstream.expectRequest()
sidechannel.sendNext("a one")
downstream.expectNext("a one")
upstream.sendNext("cancelStage")
awaitAssert(passAlong.subCompletion should equal(Success(Done)))
}
"fail automatically when parent stage fails" in new TestSetup {
downstream.request(1L)
sidechannel.expectRequest()
upstream.expectRequest()
sidechannel.sendNext("a one")
downstream.expectNext("a one")
upstream.sendNext("failStage")
awaitAssert(passAlong.subCompletion should equal(Failure(TE("boom"))))
}
"complete automatically when all parent ins and outs are closed" in new TestSetup {
downstream.request(1L)
sidechannel.expectRequest()
upstream.expectRequest()
sidechannel.sendNext("a one")
downstream.expectNext("a one")
upstream.sendNext("closeAll")
awaitAssert(passAlong.subCompletion should equal(Success(Done)))
}
}
"SubSourceOutlet" should {
// a contrived custom sink graph stage just to observe what happens to the SubSourceOutlet when its parent
// fails/completes
class ContrivedSubSourceStage extends GraphStage[SinkShape[String]] {
val in = Inlet[String]("in")
override val shape = SinkShape(in)
@volatile var subCompletion: AnyRef = _
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
val subOut = new SubSourceOutlet[String]("subout")
override def preStart(): Unit = {
Source
.fromGraph(subOut.source)
.runWith(Sink.ignore)
.onComplete(t => subCompletion = t)(ExecutionContexts.parasitic)
subOut.setHandler(new OutHandler {
override def onPull(): Unit = pull(in)
})
}
setHandler(in, new InHandler {
override def onPush(): Unit = {
val elem = grab(in)
elem match {
case "completeStage" => completeStage()
case "cancelStage" => cancelStage(NoMoreElementsNeeded)
case "failStage" => failStage(TE("boom"))
case "completeAll" => cancel(in)
case other => subOut.push(other)
}
}
})
}
}
"complete automatically when parent stage completes" in {
val stage = new ContrivedSubSourceStage
Source("element" :: "completeStage" :: Nil).runWith(Sink.fromGraph(stage))
awaitAssert(stage.subCompletion should equal(Success(Done)))
}
"complete automatically when parent stage cancels" in {
val stage = new ContrivedSubSourceStage
Source("element" :: "cancelStage" :: Nil).runWith(Sink.fromGraph(stage))
awaitAssert(stage.subCompletion should equal(Success(Done)))
}
"fail automatically when parent stage fails" in {
val stage = new ContrivedSubSourceStage
Source("element" :: "failStage" :: Nil).runWith(Sink.fromGraph(stage))
awaitAssert(stage.subCompletion should equal(Failure(TE("boom"))))
}
"cancel automatically when all parent ins and outs are closed" in {
val stage = new ContrivedSubSourceStage
Source("element" :: "completeAll" :: Nil).runWith(Sink.fromGraph(stage))
awaitAssert(stage.subCompletion should equal(Success(Done)))
}
}
}

View file

@ -774,7 +774,8 @@ import akka.util.ccompat.JavaConverters._
case null => case null =>
if (!status.compareAndSet(null, ActorSubscriberMessage.OnComplete)) if (!status.compareAndSet(null, ActorSubscriberMessage.OnComplete))
status.get.asInstanceOf[AsyncCallback[Any]].invoke(ActorSubscriberMessage.OnComplete) status.get.asInstanceOf[AsyncCallback[Any]].invoke(ActorSubscriberMessage.OnComplete)
case OnError(_: SubscriptionTimeoutException) => // already timed out, keep the timeout as that happened first case OnError(_) => // already failed out, keep the exception as that happened first
case ActorSubscriberMessage.OnComplete => // it was already completed
} }
def failSubstream(ex: Throwable): Unit = status.get match { def failSubstream(ex: Throwable): Unit = status.get match {
@ -783,7 +784,8 @@ import akka.util.ccompat.JavaConverters._
val failure = ActorSubscriberMessage.OnError(ex) val failure = ActorSubscriberMessage.OnError(ex)
if (!status.compareAndSet(null, failure)) if (!status.compareAndSet(null, failure))
status.get.asInstanceOf[AsyncCallback[Any]].invoke(failure) status.get.asInstanceOf[AsyncCallback[Any]].invoke(failure)
case OnError(_: SubscriptionTimeoutException) => // already timed out, keep the timeout as that happened first case ActorSubscriberMessage.OnComplete => // it was already completed, ignore failure as completion happened first
case OnError(_) => // already failed out, keep the exception as that happened first
} }
def timeout(d: FiniteDuration): Boolean = def timeout(d: FiniteDuration): Boolean =

View file

@ -725,9 +725,31 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
} }
i += 1 i += 1
} }
cleanUpSubstreams(optionalFailureCause)
setKeepGoing(false) setKeepGoing(false)
} }
private def cleanUpSubstreams(optionalFailureCause: OptionVal[Throwable]): Unit = {
_subInletsAndOutlets.foreach {
case inlet: SubSinkInlet[_] =>
val subSink = inlet.sink.asInstanceOf[SubSink[_]]
optionalFailureCause match {
case OptionVal.Some(cause) => subSink.cancelSubstream(cause)
case _ => subSink.cancelSubstream()
}
case outlet: SubSourceOutlet[_] =>
val subSource = outlet.source.asInstanceOf[SubSource[_]]
optionalFailureCause match {
case OptionVal.Some(cause) => subSource.failSubstream(cause)
case _ => subSource.completeSubstream()
}
case wat =>
throw new IllegalStateException(
s"Stage _subInletsAndOutlets contained unexpected element of type ${wat.getClass.toString}")
}
_subInletsAndOutlets = Set.empty
}
/** /**
* Return true if the given output port is ready to be pushed. * Return true if the given output port is ready to be pushed.
*/ */
@ -1253,6 +1275,22 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
case ref => ref case ref => ref
} }
// keep track of created SubSinkInlets and SubSourceOutlets to make sure we do not leak them
// when this stage completes/fails, not threadsafe only accessed from stream machinery callbacks etc.
private var _subInletsAndOutlets: Set[AnyRef] = Set.empty
private def created(inlet: SubSinkInlet[_]): Unit =
_subInletsAndOutlets += inlet
private def completedOrFailed(inlet: SubSinkInlet[_]): Unit =
_subInletsAndOutlets -= inlet
private def created(outlet: SubSourceOutlet[_]): Unit =
_subInletsAndOutlets += outlet
private def completedOrFailed(outlet: SubSourceOutlet[_]): Unit =
_subInletsAndOutlets -= outlet
/** /**
* Initialize a [[StageActorRef]] which can be used to interact with from the outside world "as-if" an [[Actor]]. * Initialize a [[StageActorRef]] which can be used to interact with from the outside world "as-if" an [[Actor]].
* The messages are looped through the [[getAsyncCallback]] mechanism of [[GraphStage]] so they are safe to modify * The messages are looped through the [[getAsyncCallback]] mechanism of [[GraphStage]] so they are safe to modify
@ -1329,6 +1367,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
val exception = streamDetachedException val exception = streamDetachedException
inProgress.foreach(_.tryFailure(exception)) inProgress.foreach(_.tryFailure(exception))
} }
cleanUpSubstreams(OptionVal.None)
} }
private[this] var asyncCleanupCounter = 0L private[this] var asyncCleanupCounter = 0L
@ -1375,8 +1414,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
* *
* This allows the dynamic creation of an Inlet for a GraphStage which is * This allows the dynamic creation of an Inlet for a GraphStage which is
* connected to a Sink that is available for materialization (e.g. using * connected to a Sink that is available for materialization (e.g. using
* the `subFusingMaterializer`). Care needs to be taken to cancel this Inlet * the `subFusingMaterializer`). Completion, cancellation and failure of the
* when the operator shuts down lest the corresponding Sink be left hanging. * parent operator is automatically delegated to instances of `SubSinkInlet` to avoid resource leaks.
* *
* To be thread safe this method must only be called from either the constructor of the graph operator during * To be thread safe this method must only be called from either the constructor of the graph operator during
* materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`.
@ -1404,6 +1443,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
} }
}.invoke _) }.invoke _)
GraphStageLogic.this.created(this)
def sink: Graph[SinkShape[T], NotUsed] = _sink def sink: Graph[SinkShape[T], NotUsed] = _sink
def setHandler(handler: InHandler): Unit = this.handler = handler def setHandler(handler: InHandler): Unit = this.handler = handler
@ -1429,10 +1470,13 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
_sink.pullSubstream() _sink.pullSubstream()
} }
def cancel(): Unit = cancel(SubscriptionWithCancelException.NoMoreElementsNeeded) def cancel(): Unit = {
cancel(SubscriptionWithCancelException.NoMoreElementsNeeded)
}
def cancel(cause: Throwable): Unit = { def cancel(cause: Throwable): Unit = {
closed = true closed = true
_sink.cancelSubstream(cause) _sink.cancelSubstream(cause)
GraphStageLogic.this.completedOrFailed(this)
} }
override def toString = s"SubSinkInlet($name)" override def toString = s"SubSinkInlet($name)"
@ -1443,9 +1487,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
* *
* This allows the dynamic creation of an Outlet for a GraphStage which is * This allows the dynamic creation of an Outlet for a GraphStage which is
* connected to a Source that is available for materialization (e.g. using * connected to a Source that is available for materialization (e.g. using
* the `subFusingMaterializer`). Care needs to be taken to complete this * the `subFusingMaterializer`). Completion, cancellation and failure of the
* Outlet when the operator shuts down lest the corresponding Sink be left * parent operator is automatically delegated to instances of `SubSourceOutlet`
* hanging. It is good practice to use the `timeout` method to cancel this * to avoid resource leaks.
*
* Even so it is good practice to use the `timeout` method to cancel this
* Outlet in case the corresponding Source is not materialized within a * Outlet in case the corresponding Source is not materialized within a
* given time limit, see e.g. ActorMaterializerSettings. * given time limit, see e.g. ActorMaterializerSettings.
* *
@ -1473,6 +1519,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
} }
private val _source = new SubSource[T](name, callback) private val _source = new SubSource[T](name, callback)
GraphStageLogic.this.created(this)
/** /**
* Set the source into timed-out mode if it has not yet been materialized. * Set the source into timed-out mode if it has not yet been materialized.
@ -1520,6 +1567,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
available = false available = false
closed = true closed = true
_source.completeSubstream() _source.completeSubstream()
GraphStageLogic.this.completedOrFailed(this)
} }
/** /**
@ -1529,6 +1577,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
available = false available = false
closed = true closed = true
_source.failSubstream(ex) _source.failSubstream(ex)
GraphStageLogic.this.completedOrFailed(this)
} }
override def toString = s"SubSourceOutlet($name)" override def toString = s"SubSourceOutlet($name)"