From 3dda73c1eaeb7325cc32d2715dce5eb5e4b44894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Tue, 5 Dec 2017 14:07:10 +0100 Subject: [PATCH] Async callback memory leak fix #24046 --- .../stream/InvokeWithFeedbackBenchmark.scala | 62 +++++++++ .../impl/fusing/AsyncCallbackSpec.scala | 58 +++++++- .../scala/akka/stream/stage/GraphStage.scala | 130 +++++++++++------- 3 files changed, 200 insertions(+), 50 deletions(-) create mode 100644 akka-bench-jmh/src/main/scala/akka/stream/InvokeWithFeedbackBenchmark.scala diff --git a/akka-bench-jmh/src/main/scala/akka/stream/InvokeWithFeedbackBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/InvokeWithFeedbackBenchmark.scala new file mode 100644 index 0000000000..5a1ef1426d --- /dev/null +++ b/akka-bench-jmh/src/main/scala/akka/stream/InvokeWithFeedbackBenchmark.scala @@ -0,0 +1,62 @@ +/** + * Copyright (C) 2014-2017 Lightbend Inc. + */ + +package akka.stream + +import java.util.concurrent.TimeUnit + +import akka.actor.ActorSystem +import akka.stream.scaladsl._ +import org.openjdk.jmh.annotations._ + +import scala.concurrent._ +import scala.concurrent.duration._ + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class InvokeWithFeedbackBenchmark { + implicit val system = ActorSystem("InvokeWithFeedbackBenchmark") + val materializerSettings = ActorMaterializerSettings(system).withDispatcher("akka.test.stream-dispatcher") + + var sourceQueue: SourceQueueWithComplete[Int] = _ + var sinkQueue: SinkQueueWithCancel[Int] = _ + + val waitForResult = 100.millis + + @Setup + def setup(): Unit = { + val settings = ActorMaterializerSettings(system) + + implicit val materializer = ActorMaterializer(settings) + + // these are currently the only two built in stages using invokeWithFeedback + val (in, out) = + Source.queue[Int](bufferSize = 1, overflowStrategy = OverflowStrategies.Backpressure) + .toMat(Sink.queue[Int]())(Keep.both) + .run() + + sourceQueue = 
in + sinkQueue = out + + } + + @OperationsPerInvocation(100000) + @Benchmark + def pass_through_100k_elements(): Unit = { + (0 to 100000).foreach { n ⇒ + val f = sinkQueue.pull() + Await.result(sourceQueue.offer(n), waitForResult) + Await.result(f, waitForResult) + } + } + + @TearDown + def tearDown(): Unit = { + sourceQueue.complete() + // no way to observe sink completion from the outside + Await.result(system.terminate(), 5.seconds) + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala index a31a79efb5..55e3a188ec 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala @@ -5,14 +5,15 @@ package akka.stream.impl.fusing import akka.Done import akka.actor.ActorRef -import akka.stream.stage._ import akka.stream._ import akka.stream.scaladsl.{ Keep, Sink, Source } +import akka.stream.stage._ import akka.stream.testkit.Utils.TE import akka.stream.testkit.{ TestPublisher, TestSubscriber } import akka.testkit.{ AkkaSpec, TestProbe } -import scala.concurrent.{ Future, Promise } +import scala.concurrent.{ Await, Future, Promise } +import scala.language.reflectiveCalls class AsyncCallbackSpec extends AkkaSpec { @@ -235,7 +236,58 @@ class AsyncCallbackSpec extends AkkaSpec { val feedbakF = callback.invokeWithFeedback("fail-the-stage") val failure = feedbakF.failed.futureValue - failure shouldBe a[StreamDetachedException] // we can't capture the exception in this case + failure shouldBe a[StreamDetachedException] + } + + "behave with multiple async callbacks" in { + import system.dispatcher + + class ManyAsyncCallbacksStage(probe: ActorRef) extends GraphStageWithMaterializedValue[SourceShape[String], Set[AsyncCallback[AnyRef]]] { + val out = Outlet[String]("out") + val shape = SourceShape(out) + def 
createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val logic = new GraphStageLogic(shape) { + val callbacks = (0 to 10).map(_ ⇒ getAsyncCallback[AnyRef](probe ! _)).toSet + setHandler(out, new OutHandler { + def onPull(): Unit = () + }) + } + (logic, logic.callbacks) + } + } + + val acbProbe = TestProbe() + + val out = TestSubscriber.probe[String]() + + val acbs = Source.fromGraph(new ManyAsyncCallbacksStage(acbProbe.ref)) + .toMat(Sink.fromSubscriber(out))(Keep.left) + .run() + + val happyPathFeedbacks = + acbs.map(acb ⇒ + Future { acb.invokeWithFeedback("bö") }.flatMap(identity) + ) + Future.sequence(happyPathFeedbacks).futureValue // will throw on fail or timeout on not completed + + for (_ ← 0 to 10) acbProbe.expectMsg("bö") + + val (half, otherHalf) = acbs.splitAt(4) + val firstHalfFutures = half.map(_.invokeWithFeedback("ba")) + out.cancel() // cancel in the middle + val otherHalfFutures = otherHalf.map(_.invokeWithFeedback("ba")) + val unhappyPath = firstHalfFutures ++ otherHalfFutures + + // all futures should either be completed or failed with StreamDetachedException + unhappyPath.foreach { future ⇒ + try { + val done = Await.result(future, remainingOrDefault) + done should ===(Done) + } catch { + case _: StreamDetachedException ⇒ // this is fine + } + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 460875c7be..422b5e2543 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -3,29 +3,24 @@ */ package akka.stream.stage -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.locks.ReentrantLock -import akka.{ Done, NotUsed } import akka.actor._ import akka.annotation.{ ApiMayChange, InternalApi } -import akka.dispatch.ExecutionContexts.sameThreadExecutionContext +import 
akka.dispatch.ExecutionContexts import akka.japi.function.{ Effect, Procedure } import akka.stream._ import akka.stream.actor.ActorSubscriberMessage -import akka.stream.impl.{ NotInitialized, ReactiveStreamsCompliance, TraversalBuilder } import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource } +import akka.stream.impl.{ ReactiveStreamsCompliance, TraversalBuilder } import akka.stream.scaladsl.GenericGraphWithChangedAttributes import akka.util.OptionVal +import akka.{ Done, NotUsed } import scala.annotation.tailrec -import scala.collection.JavaConverters._ import scala.collection.{ immutable, mutable } -import scala.concurrent.{ Future, Promise } import scala.concurrent.duration.FiniteDuration -import scala.util.Try -import scala.util.control.TailCalls.{ TailRec, done, tailcall } +import scala.concurrent.{ Future, Promise } /** * Scala API: A GraphStage represents a reusable graph stream processing stage. @@ -1005,16 +1000,16 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed * due to stream completion. * - * Method must be called in thread-safe manner during materialization and in the same thread as materialization - * process. + * To be thread safe this method must only be called from either the constructor of the graph stage during + * materialization or one of the methods invoked by the graph stage machinery, such as `onPush` and `onPull`. * * This object can be cached and reused within the same [[GraphStageLogic]]. 
*/ final def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { - val result = new ConcurrentAsyncCallback[T](handler) - asyncCallbacksInProgress.add(result) - if (_interpreter != null) result.onStart() - result + val callback = new ConcurrentAsyncCallback[T](handler) + if (_interpreter != null) callback.onStart() + else callbacksWaitingForInterpreter = callback :: callbacksWaitingForInterpreter + callback } /** @@ -1039,7 +1034,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * [[onStop()]] puts class in `Completed` state * "Real world" calls of [[invokeWithFeedback()]] always return failed promises for `Completed` state */ - private class ConcurrentAsyncCallback[T](handler: T ⇒ Unit) extends AsyncCallback[T] { + private final class ConcurrentAsyncCallback[T](handler: T ⇒ Unit) extends AsyncCallback[T] { sealed trait State // waiting for materialization completion @@ -1048,38 +1043,39 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private case object Initializing extends State // stream is initialized and so no threads can just send events without any synchronization overhead private case object Initialized extends State - // stage has been shut down, either regularly or it failed - private case object Completed extends State // Event with feedback promise private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]]) - val waitingForProcessing = ConcurrentHashMap.newKeySet[Promise[_]]() - private[this] val currentState = new AtomicReference[State](Pending(Nil)) // is called from the owning [[GraphStage]] @tailrec - final private[stage] def onStart(): Unit = { + private[stage] def onStart(): Unit = { (currentState.getAndSet(Initializing): @unchecked) match { case Pending(l) ⇒ l.reverse.foreach(evt ⇒ { + evt.handlingPromise match { + case OptionVal.Some(p) ⇒ p.future.onComplete(_ ⇒ onFeedbackCompleted(p))(ExecutionContexts.sameThreadExecutionContext) + case OptionVal.None ⇒ 
// buffered invoke without promise + } onAsyncInput(evt.e, evt.handlingPromise) }) } if (!currentState.compareAndSet(Initializing, Initialized)) { (currentState.get: @unchecked) match { case Pending(_) ⇒ onStart() - case Completed ⇒ () //wonder if this is possible } } } // is called from the owning [[GraphStage]] - final private[stage] def onStop(): Unit = { - currentState.set(Completed) - val iterator = waitingForProcessing.iterator() - lazy val detachedException = new StreamDetachedException() - while (iterator.hasNext) { iterator.next().tryFailure(detachedException) } - waitingForProcessing.clear() + private[stage] def onStop(outstandingPromises: Set[Promise[Done]]): Unit = { + if (outstandingPromises.nonEmpty) { + val detachedException = streamDetatchedException + val iterator = outstandingPromises.iterator + while (iterator.hasNext) { + iterator.next().tryFailure(detachedException) + } + } } private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = { @@ -1088,29 +1084,59 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = { onAsyncInput(event, OptionVal.Some(promise)) - currentState.get() match { - case Completed ⇒ failPromiseOnComplete(promise) - case _ ⇒ promise - } + if (stopped) failPromiseOnComplete(promise) + else promise } // external call override def invokeWithFeedback(event: T): Future[Done] = { val promise: Promise[Done] = Promise[Done]() - promise.future.andThen { case _ ⇒ waitingForProcessing.remove(promise) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) - waitingForProcessing.add(promise) - invokeWithPromise(event, promise).future + promise.future.onComplete(_ ⇒ onFeedbackCompleted(promise))(ExecutionContexts.sameThreadExecutionContext) + + @tailrec + def addToWaiting(): Boolean = { + val previous = asyncCallbacksInProgress.get() + if (previous != null) { + val updated = previous.updated(this, previous(this) + 
promise) + if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting() + else true + } else { + false + } + } + + if (addToWaiting()) + invokeWithPromise(event, promise).future + else + Future.failed(streamDetatchedException) } - private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = { + // removes the completed promise from the callbacks in progress, called from onComplete + private def onFeedbackCompleted(promise: Promise[Done]): Unit = { + @tailrec + def removeFromWaiting(): Unit = { + val previous = asyncCallbacksInProgress.get() + if (previous != null) { + val newSet = previous(this) - promise + val updated = + if (newSet.isEmpty) previous - this // no outstanding promises, remove this callback from the map to avoid leak + else previous.updated(this, newSet) + if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) removeFromWaiting() + } + } + removeFromWaiting() + } + + @tailrec + private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = currentState.get() match { + // started - can just send message to stream + case Initialized ⇒ sendEvent(event, promise) // not started yet case list @ Pending(_) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents))) invokeWithPromise(event, promise) // atomicity is failed - try again else promise - // started - can just send message to stream - case Initialized ⇒ sendEvent(event, promise) // initializing is in progress in another thread (initializing thread is managed by Akka) case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) { (currentState.get(): @unchecked) match { @@ -1118,14 +1144,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: case Initialized ⇒ sendEvent(event, promise) } } else promise - // fail promise as stream is completed - case Completed ⇒ failPromiseOnComplete(promise) } - } private def 
failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { - waitingForProcessing.remove(promise) - promise.tryFailure(new StreamDetachedException("Stage stopped before async invocation was processed")) + onFeedbackCompleted(promise) + promise.tryFailure(streamDetatchedException) promise } @@ -1144,10 +1167,13 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: case Initialized ⇒ onAsyncInput(event, OptionVal.None) } } - case Completed ⇒ // do nothing here as stream is completed } internalInvoke(event) } + + private def streamDetatchedException = + new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed") + } /** @@ -1164,7 +1190,13 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] = getAsyncCallback(handler.apply) - private val asyncCallbacksInProgress = mutable.HashSet[ConcurrentAsyncCallback[_]]() + private var callbacksWaitingForInterpreter: List[ConcurrentAsyncCallback[_]] = Nil + // is used for two purposes: keep track of running callbacks and signal that the + // stage has stopped to fail incoming async callback invocations by being set to null + private val asyncCallbacksInProgress = + new AtomicReference(Map.empty[ConcurrentAsyncCallback[_], Set[Promise[Done]]].withDefaultValue(Set.empty)) + + private def stopped = asyncCallbacksInProgress.get() == null private var _stageActor: StageActor = _ final def stageActor: StageActor = _stageActor match { @@ -1206,7 +1238,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // Internal hooks to avoid reliance on user calling super in preStart /** INTERNAL API */ protected[stream] def beforePreStart(): Unit = { - asyncCallbacksInProgress.foreach(_.onStart()) + callbacksWaitingForInterpreter.foreach(_.onStart()) + callbacksWaitingForInterpreter = Nil } // Internal hooks to avoid 
reliance on user calling super in postStop @@ -1216,7 +1249,10 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: _stageActor.stop() _stageActor = null } - asyncCallbacksInProgress.foreach(_.onStop()) + // make sure any invokeWithFeedback after this fails fast + // and fail current outstanding invokeWithFeedback promises + val inProgress = asyncCallbacksInProgress.getAndSet(null) + inProgress.foreach { case (acb, promises) ⇒ acb.onStop(promises) } } /**