Async callback memory leak fix #24046

This commit is contained in:
Johan Andrén 2017-12-05 14:07:10 +01:00 committed by GitHub
parent fa953e60f7
commit 3dda73c1ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 50 deletions

View file

@ -0,0 +1,62 @@
/**
* Copyright (C) 2014-2017 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.stream
import java.util.concurrent.TimeUnit
import akka.actor.ActorSystem
import akka.stream.scaladsl._
import org.openjdk.jmh.annotations._
import scala.concurrent._
import scala.concurrent.duration._
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
class InvokeWithFeedbackBenchmark {
implicit val system = ActorSystem("InvokeWithFeedbackBenchmark")
val materializerSettings = ActorMaterializerSettings(system).withDispatcher("akka.test.stream-dispatcher")
var sourceQueue: SourceQueueWithComplete[Int] = _
var sinkQueue: SinkQueueWithCancel[Int] = _
val waitForResult = 100.millis
@Setup
def setup(): Unit = {
val settings = ActorMaterializerSettings(system)
implicit val materializer = ActorMaterializer(settings)
// these are currently the only two built in stages using invokeWithFeedback
val (in, out) =
Source.queue[Int](bufferSize = 1, overflowStrategy = OverflowStrategies.Backpressure)
.toMat(Sink.queue[Int]())(Keep.both)
.run()
sourceQueue = in
sinkQueue = out
}
@OperationsPerInvocation(100000)
@Benchmark
def pass_through_100k_elements(): Unit = {
(0 to 100000).foreach { n
val f = sinkQueue.pull()
Await.result(sourceQueue.offer(n), waitForResult)
Await.result(f, waitForResult)
}
}
@TearDown
def tearDown(): Unit = {
sourceQueue.complete()
// no way to observe sink completion from the outside
Await.result(system.terminate(), 5.seconds)
}
}

View file

@ -5,14 +5,15 @@ package akka.stream.impl.fusing
import akka.Done
import akka.actor.ActorRef
import akka.stream.stage._
import akka.stream._
import akka.stream.scaladsl.{ Keep, Sink, Source }
import akka.stream.stage._
import akka.stream.testkit.Utils.TE
import akka.stream.testkit.{ TestPublisher, TestSubscriber }
import akka.testkit.{ AkkaSpec, TestProbe }
import scala.concurrent.{ Future, Promise }
import scala.concurrent.{ Await, Future, Promise }
import scala.language.reflectiveCalls
class AsyncCallbackSpec extends AkkaSpec {
@ -235,7 +236,58 @@ class AsyncCallbackSpec extends AkkaSpec {
val feedbakF = callback.invokeWithFeedback("fail-the-stage")
val failure = feedbakF.failed.futureValue
failure shouldBe a[StreamDetachedException] // we can't capture the exception in this case
failure shouldBe a[StreamDetachedException]
}
"behave with multiple async callbacks" in {
import system.dispatcher
class ManyAsyncCallbacksStage(probe: ActorRef) extends GraphStageWithMaterializedValue[SourceShape[String], Set[AsyncCallback[AnyRef]]] {
val out = Outlet[String]("out")
val shape = SourceShape(out)
def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val logic = new GraphStageLogic(shape) {
val callbacks = (0 to 10).map(_ getAsyncCallback[AnyRef](probe ! _)).toSet
setHandler(out, new OutHandler {
def onPull(): Unit = ()
})
}
(logic, logic.callbacks)
}
}
val acbProbe = TestProbe()
val out = TestSubscriber.probe[String]()
val acbs = Source.fromGraph(new ManyAsyncCallbacksStage(acbProbe.ref))
.toMat(Sink.fromSubscriber(out))(Keep.left)
.run()
val happyPathFeedbacks =
acbs.map(acb
Future { acb.invokeWithFeedback("bö") }.flatMap(identity)
)
Future.sequence(happyPathFeedbacks).futureValue // will throw on fail or timeout on not completed
for (_ 0 to 10) acbProbe.expectMsg("bö")
val (half, otherHalf) = acbs.splitAt(4)
val firstHalfFutures = half.map(_.invokeWithFeedback("ba"))
out.cancel() // cancel in the middle
val otherHalfFutures = otherHalf.map(_.invokeWithFeedback("ba"))
val unhappyPath = firstHalfFutures ++ otherHalfFutures
// all futures should either be completed or failed with StreamDetachedException
unhappyPath.foreach { future
try {
val done = Await.result(future, remainingOrDefault)
done should ===(Done)
} catch {
case _: StreamDetachedException // this is fine
}
}
}
}

View file

@ -3,29 +3,24 @@
*/
package akka.stream.stage
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.locks.ReentrantLock
import akka.{ Done, NotUsed }
import akka.actor._
import akka.annotation.{ ApiMayChange, InternalApi }
import akka.dispatch.ExecutionContexts.sameThreadExecutionContext
import akka.dispatch.ExecutionContexts
import akka.japi.function.{ Effect, Procedure }
import akka.stream._
import akka.stream.actor.ActorSubscriberMessage
import akka.stream.impl.{ NotInitialized, ReactiveStreamsCompliance, TraversalBuilder }
import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule, SubSink, SubSource }
import akka.stream.impl.{ ReactiveStreamsCompliance, TraversalBuilder }
import akka.stream.scaladsl.GenericGraphWithChangedAttributes
import akka.util.OptionVal
import akka.{ Done, NotUsed }
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.{ immutable, mutable }
import scala.concurrent.{ Future, Promise }
import scala.concurrent.duration.FiniteDuration
import scala.util.Try
import scala.util.control.TailCalls.{ TailRec, done, tailcall }
import scala.concurrent.{ Future, Promise }
/**
* Scala API: A GraphStage represents a reusable graph stream processing stage.
@ -1005,16 +1000,16 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
* [[AsyncCallback.invokeWithFeedback()]] has an internal promise that will be failed if event cannot be processed
* due to stream completion.
*
* Method must be called in thread-safe manner during materialization and in the same thread as materialization
* process.
* To be thread safe this method must only be called from either the constructor of the graph stage during
* materialization or one of the methods invoked by the graph stage machinery, such as `onPush` and `onPull`.
*
* This object can be cached and reused within the same [[GraphStageLogic]].
*/
final def getAsyncCallback[T](handler: T Unit): AsyncCallback[T] = {
val result = new ConcurrentAsyncCallback[T](handler)
asyncCallbacksInProgress.add(result)
if (_interpreter != null) result.onStart()
result
val callback = new ConcurrentAsyncCallback[T](handler)
if (_interpreter != null) callback.onStart()
else callbacksWaitingForInterpreter = callback :: callbacksWaitingForInterpreter
callback
}
/**
@ -1039,7 +1034,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
* [[onStop()]] puts class in `Completed` state
* "Real world" calls of [[invokeWithFeedback()]] always return failed promises for `Completed` state
*/
private class ConcurrentAsyncCallback[T](handler: T Unit) extends AsyncCallback[T] {
private final class ConcurrentAsyncCallback[T](handler: T Unit) extends AsyncCallback[T] {
sealed trait State
// waiting for materialization completion
@ -1048,38 +1043,39 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
private case object Initializing extends State
// stream is initialized and so no threads can just send events without any synchronization overhead
private case object Initialized extends State
// stage has been shut down, either regularly or it failed
private case object Completed extends State
// Event with feedback promise
private case class Event(e: T, handlingPromise: OptionVal[Promise[Done]])
val waitingForProcessing = ConcurrentHashMap.newKeySet[Promise[_]]()
private[this] val currentState = new AtomicReference[State](Pending(Nil))
// is called from the owning [[GraphStage]]
@tailrec
final private[stage] def onStart(): Unit = {
private[stage] def onStart(): Unit = {
(currentState.getAndSet(Initializing): @unchecked) match {
case Pending(l) l.reverse.foreach(evt {
evt.handlingPromise match {
case OptionVal.Some(p) p.future.onComplete(_ onFeedbackCompleted(p))(ExecutionContexts.sameThreadExecutionContext)
case OptionVal.None // buffered invoke without promise
}
onAsyncInput(evt.e, evt.handlingPromise)
})
}
if (!currentState.compareAndSet(Initializing, Initialized)) {
(currentState.get: @unchecked) match {
case Pending(_) onStart()
case Completed () //wonder if this is possible
}
}
}
// is called from the owning [[GraphStage]]
final private[stage] def onStop(): Unit = {
currentState.set(Completed)
val iterator = waitingForProcessing.iterator()
lazy val detachedException = new StreamDetachedException()
while (iterator.hasNext) { iterator.next().tryFailure(detachedException) }
waitingForProcessing.clear()
private[stage] def onStop(outstandingPromises: Set[Promise[Done]]): Unit = {
if (outstandingPromises.nonEmpty) {
val detachedException = streamDetatchedException
val iterator = outstandingPromises.iterator
while (iterator.hasNext) {
iterator.next().tryFailure(detachedException)
}
}
}
private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = {
@ -1088,29 +1084,59 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = {
onAsyncInput(event, OptionVal.Some(promise))
currentState.get() match {
case Completed failPromiseOnComplete(promise)
case _ promise
}
if (stopped) failPromiseOnComplete(promise)
else promise
}
// external call
override def invokeWithFeedback(event: T): Future[Done] = {
val promise: Promise[Done] = Promise[Done]()
promise.future.andThen { case _ waitingForProcessing.remove(promise) }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
waitingForProcessing.add(promise)
invokeWithPromise(event, promise).future
promise.future.onComplete(_ onFeedbackCompleted(promise))(ExecutionContexts.sameThreadExecutionContext)
@tailrec
def addToWaiting(): Boolean = {
val previous = asyncCallbacksInProgress.get()
if (previous != null) {
val updated = previous.updated(this, previous(this) + promise)
if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) addToWaiting()
else true
} else {
false
}
}
if (addToWaiting())
invokeWithPromise(event, promise).future
else
Future.failed(streamDetatchedException)
}
private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] = {
// removes the promise from the callbacks in promise on complete, called from onComplete
private def onFeedbackCompleted(promise: Promise[Done]): Unit = {
@tailrec
def removeFromWaiting(): Unit = {
val previous = asyncCallbacksInProgress.get()
if (previous != null) {
val newSet = previous(this) - promise
val updated =
if (newSet.isEmpty) previous - this // no outstanding promises, remove stage from map to avoid leak
else previous.updated(this, newSet)
if (!asyncCallbacksInProgress.compareAndSet(previous, updated)) removeFromWaiting()
}
}
removeFromWaiting()
}
@tailrec
private def invokeWithPromise(event: T, promise: Promise[Done]): Promise[Done] =
currentState.get() match {
// started - can just send message to stream
case Initialized sendEvent(event, promise)
// not started yet
case list @ Pending(_)
if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal(promise)) :: list.pendingEvents)))
invokeWithPromise(event, promise) // atomicity is failed - try again
else promise
// started - can just send message to stream
case Initialized sendEvent(event, promise)
// initializing is in progress in another thread (initializing thread is managed by Akka)
case Initializing if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal(promise)) :: Nil))) {
(currentState.get(): @unchecked) match {
@ -1118,14 +1144,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
case Initialized sendEvent(event, promise)
}
} else promise
// fail promise as stream is completed
case Completed failPromiseOnComplete(promise)
}
}
private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = {
waitingForProcessing.remove(promise)
promise.tryFailure(new StreamDetachedException("Stage stopped before async invocation was processed"))
onFeedbackCompleted(promise)
promise.tryFailure(streamDetatchedException)
promise
}
@ -1144,10 +1167,13 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
case Initialized onAsyncInput(event, OptionVal.None)
}
}
case Completed // do nothing here as stream is completed
}
internalInvoke(event)
}
private def streamDetatchedException =
new StreamDetachedException(s"Stage with GraphStageLogic ${this} stopped before async invocation was processed")
}
/**
@ -1164,7 +1190,13 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] =
getAsyncCallback(handler.apply)
private val asyncCallbacksInProgress = mutable.HashSet[ConcurrentAsyncCallback[_]]()
private var callbacksWaitingForInterpreter: List[ConcurrentAsyncCallback[_]] = Nil
// is used for two purposes: keep track of running callbacks and signal that the
// stage has stopped to fail incoming async callback invocations by being set to null
private val asyncCallbacksInProgress =
new AtomicReference(Map.empty[ConcurrentAsyncCallback[_], Set[Promise[Done]]].withDefaultValue(Set.empty))
private def stopped = asyncCallbacksInProgress.get() == null
private var _stageActor: StageActor = _
final def stageActor: StageActor = _stageActor match {
@ -1206,7 +1238,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
// Internal hooks to avoid reliance on user calling super in preStart
/** INTERNAL API */
protected[stream] def beforePreStart(): Unit = {
asyncCallbacksInProgress.foreach(_.onStart())
callbacksWaitingForInterpreter.foreach(_.onStart())
callbacksWaitingForInterpreter = Nil
}
// Internal hooks to avoid reliance on user calling super in postStop
@ -1216,7 +1249,10 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
_stageActor.stop()
_stageActor = null
}
asyncCallbacksInProgress.foreach(_.onStop())
// make sure any invokeWithFeedback after this fails fast
// and fail current outstanding invokeWithFeedback promises
val inProgress = asyncCallbacksInProgress.getAndSet(null)
inProgress.foreach { case (acb, promises) acb.onStop(promises) }
}
/**