From 61bd3d42b6e4e931419b8c2342e809368866b9f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Fri, 22 Nov 2019 13:38:06 +0100 Subject: [PATCH] Document and fix getAsyncCallback thread safety issues #27999 --- .../scala/akka/remote/artery/Handshake.scala | 20 +++++----- .../remote/artery/SystemMessageDelivery.scala | 20 +++++----- .../scala/akka/stream/impl/fusing/Ops.scala | 12 +++--- .../scala/akka/stream/stage/GraphStage.scala | 40 ++++++++++++++----- 4 files changed, 57 insertions(+), 35 deletions(-) diff --git a/akka-remote/src/main/scala/akka/remote/artery/Handshake.scala b/akka-remote/src/main/scala/akka/remote/artery/Handshake.scala index 06ee7ce7a4..ec84721aca 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/Handshake.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/Handshake.scala @@ -7,7 +7,6 @@ package akka.remote.artery import scala.concurrent.duration._ import scala.concurrent.Future import scala.util.control.NoStackTrace - import akka.actor.ActorSystem import akka.remote.UniqueAddress import akka.stream.Attributes @@ -18,6 +17,7 @@ import akka.stream.stage._ import akka.util.{ unused, OptionVal } import akka.Done import akka.actor.Address +import akka.dispatch.ExecutionContexts /** * INTERNAL API @@ -128,6 +128,7 @@ private[remote] class OutboundHandshake( // when it receives the HandshakeRsp reply implicit val ec = materializer.executionContext uniqueRemoteAddress.foreach { + getAsyncCallback[UniqueAddress] { _ => if (handshakeState != Completed) { handshakeCompleted() @@ -218,6 +219,10 @@ private[remote] class InboundHandshake(inboundContext: InboundContext, inControl new TimerGraphStageLogic(shape) with OutHandler with StageLogging { import OutboundHandshake._ + private val runInStage = getAsyncCallback[() => Unit] { thunk => + thunk() + } + // InHandler if (inControlStream) setHandler( @@ -233,7 +238,7 @@ private[remote] class InboundHandshake(inboundContext: InboundContext, inControl // that the other system is alive. inboundContext.association(from.address).associationState.lastUsedTimestamp.set(System.nanoTime()) - after(inboundContext.completeHandshake(from)) { + after(inboundContext.completeHandshake(from)) { () => pull(in) } case _ => @@ -255,7 +260,7 @@ private[remote] class InboundHandshake(inboundContext: InboundContext, inControl private def onHandshakeReq(from: UniqueAddress, to: Address): Unit = { if (to == inboundContext.localAddress.address) { - after(inboundContext.completeHandshake(from)) { + after(inboundContext.completeHandshake(from)) { () => inboundContext.sendControl(from.address, HandshakeRsp(inboundContext.localAddress)) pull(in) } @@ -274,18 +279,15 @@ private[remote] class InboundHandshake(inboundContext: InboundContext, inControl } } - private def after(first: Future[Done])(thenInside: => Unit): Unit = { + private def after(first: Future[Done])(thenInside: () => Unit): Unit = { first.value match { case Some(_) => // This in the normal case (all but the first). The future will be completed // because handshake was already completed. Note that we send those HandshakeReq // periodically. - thenInside + thenInside() case None => - implicit val ec = materializer.executionContext - first.onComplete { _ => - getAsyncCallback[Done](_ => thenInside).invoke(Done) - } + first.onComplete(_ => runInStage.invoke(thenInside))(ExecutionContexts.sameThreadExecutionContext) } } diff --git a/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala b/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala index e8e6be2828..a7af4534ab 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala @@ -12,7 +12,6 @@ import scala.concurrent.duration._ import scala.util.Failure import scala.util.Success import scala.util.Try - import akka.Done import akka.remote.UniqueAddress import akka.remote.artery.InboundControlJunction.ControlMessageObserver @@ -28,9 +27,10 @@ import akka.stream.stage.TimerGraphStageLogic import akka.remote.artery.OutboundHandshake.HandshakeReq import akka.actor.ActorRef import akka.dispatch.sysmsg.SystemMessage -import scala.util.control.NoStackTrace +import scala.util.control.NoStackTrace import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts import akka.event.Logging import akka.stream.stage.StageLogging import akka.util.OptionVal @@ -104,14 +104,14 @@ import akka.util.OptionVal override protected def logSource: Class[_] = classOf[SystemMessageDelivery] override def preStart(): Unit = { - implicit val ec = materializer.executionContext - outboundContext.controlSubject.attach(this).foreach { - getAsyncCallback[Done] { _ => - replyObserverAttached = true - if (isAvailable(out)) - pull(in) // onPull from downstream already called - }.invoke + val callback = getAsyncCallback[Done] { _ => + replyObserverAttached = true + if (isAvailable(out)) + pull(in) // onPull from downstream already called } + outboundContext.controlSubject + .attach(this) + .foreach(callback.invoke)(ExecutionContexts.sameThreadExecutionContext) } override def postStop(): Unit = { @@ -151,7 +151,7 @@ import akka.util.OptionVal } } - // ControlMessageObserver, external call + // ControlMessageObserver, external call but on graph logic machinery thread (getAsyncCallback safe) override def controlSubjectCompleted(signal: Try[Done]): Unit = { getAsyncCallback[Try[Done]] { case Success(_) => completeStage() diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 897c87135f..0e5aa0228f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -1407,14 +1407,12 @@ private[stream] object Collect { new GraphStageLogic(shape) with InHandler with OutHandler with StageLogging { override protected def logSource: Class[_] = classOf[Watch[_]] - private lazy val self = getStageActor { - case (_, Terminated(`targetRef`)) => - failStage(new WatchedActorTerminatedException("Watch", targetRef)) - case (_, _) => // keep the compiler happy (stage actor receive is total) - } - override def preStart(): Unit = { - // initialize self, and watch the target + val self = getStageActor { + case (_, Terminated(`targetRef`)) => + failStage(new WatchedActorTerminatedException("Watch", targetRef)) + case (_, _) => // keep the compiler happy (stage actor receive is total) + } self.watch(targetRef) } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 37e6086ca9..c44e2835a9 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -1266,6 +1266,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * the Actor will be terminated as well. The entity backing the [[StageActorRef]] is not a real Actor, * but the [[GraphStageLogic]] itself, therefore it does not react to [[PoisonPill]]. * + * To be thread safe this method must only be called from either the constructor of the graph operator during + * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. + * * @param receive callback that will be called upon receiving of a message by this special Actor * @return minimal actor with watch method */ @@ -1274,6 +1277,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * INTERNAL API + * + * To be thread safe this method must only be called from either the constructor of the graph operator during + * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. */ @InternalApi protected[akka] def getEagerStageActor( @@ -1370,6 +1376,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * connected to a Sink that is available for materialization (e.g. using * the `subFusingMaterializer`). Care needs to be taken to cancel this Inlet * when the operator shuts down lest the corresponding Sink be left hanging. + * + * To be thread safe this method must only be called from either the constructor of the graph operator during + * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. */ class SubSinkInlet[T](name: String) { import ActorSubscriberMessage._ @@ -1438,6 +1447,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * hanging. It is good practice to use the `timeout` method to cancel this * Outlet in case the corresponding Source is not materialized within a * given time limit, see e.g. ActorMaterializerSettings. + * + * To be thread safe this method must only be called from either the constructor of the graph operator during + * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. */ class SubSourceOutlet[T](name: String) { @@ -1560,6 +1572,12 @@ trait AsyncCallback[T] { def invokeWithFeedback(t: T): Future[Done] } +/** + * Provides timer related facilities to a [[GraphStageLogic]]. + * + * To be thread safe the methods of this class must only be called from either the constructor of the graph operator during + * materialization or one of the methods invoked by the graph operator machinery, such as `onPush` and `onPull`. + */ abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) { import TimerMessages._ @@ -1610,9 +1628,9 @@ abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shap final protected def scheduleOnce(timerKey: Any, delay: FiniteDuration): Unit = { cancelTimer(timerKey) val id = timerIdGen.next() - val task = interpreter.materializer.scheduleOnce(delay, new Runnable { - def run() = getTimerAsyncCallback.invoke(Scheduled(timerKey, id, repeating = false)) - }) + val callback = getTimerAsyncCallback + val task = + interpreter.materializer.scheduleOnce(delay, () => callback.invoke(Scheduled(timerKey, id, repeating = false))) keyToTimers(timerKey) = Timer(id, task) } @@ -1638,9 +1656,11 @@ abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shap delay: FiniteDuration): Unit = { cancelTimer(timerKey) val id = timerIdGen.next() - val task = interpreter.materializer.scheduleWithFixedDelay(initialDelay, delay, new Runnable { - def run() = getTimerAsyncCallback.invoke(Scheduled(timerKey, id, repeating = true)) - }) + val callback = getTimerAsyncCallback + val task = interpreter.materializer.scheduleWithFixedDelay( + initialDelay, + delay, + () => callback.invoke(Scheduled(timerKey, id, repeating = true))) keyToTimers(timerKey) = Timer(id, task) } @@ -1670,9 +1690,11 @@ abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shap interval: FiniteDuration): Unit = { cancelTimer(timerKey) val id = timerIdGen.next() - val task = interpreter.materializer.scheduleAtFixedRate(initialDelay, interval, new Runnable { - def run() = getTimerAsyncCallback.invoke(Scheduled(timerKey, id, repeating = true)) - }) + val callback = getTimerAsyncCallback + val task = interpreter.materializer.scheduleAtFixedRate( + initialDelay, + interval, + () => callback.invoke(Scheduled(timerKey, id, repeating = true))) keyToTimers(timerKey) = Timer(id, task) }