diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala new file mode 100644 index 0000000000..a31a79efb5 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/AsyncCallbackSpec.scala @@ -0,0 +1,242 @@ +/** + * Copyright (C) 2009-2017 Lightbend Inc. + */ +package akka.stream.impl.fusing + +import akka.Done +import akka.actor.ActorRef +import akka.stream.stage._ +import akka.stream._ +import akka.stream.scaladsl.{ Keep, Sink, Source } +import akka.stream.testkit.Utils.TE +import akka.stream.testkit.{ TestPublisher, TestSubscriber } +import akka.testkit.{ AkkaSpec, TestProbe } + +import scala.concurrent.{ Future, Promise } + +class AsyncCallbackSpec extends AkkaSpec { + + implicit val materializer = ActorMaterializer(ActorMaterializerSettings(system).withFuzzing(false)) + + case object Started + case class Elem(n: Int) + case object Stopped + + class AsyncCallbackGraphStage(probe: ActorRef, early: Option[AsyncCallback[AnyRef] ⇒ Unit] = None) + extends GraphStageWithMaterializedValue[FlowShape[Int, Int], AsyncCallback[AnyRef]] { + + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + val shape = FlowShape(in, out) + + def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, AsyncCallback[AnyRef]) = { + val logic = new GraphStageLogic(shape) { + val callback = getAsyncCallback((whatever: AnyRef) ⇒ { + whatever match { + case t: Throwable ⇒ throw t + case "fail-the-stage" ⇒ failStage(new RuntimeException("failing the stage")) + case anythingElse ⇒ probe ! anythingElse + } + }) + early.foreach(cb ⇒ cb(callback)) + + override def preStart(): Unit = { + probe ! Started + } + + override def postStop(): Unit = { + probe ! Stopped + } + + setHandlers(in, out, new InHandler with OutHandler { + def onPush(): Unit = { + val n = grab(in) + probe ! Elem(n) + push(out, n) + } + + def onPull(): Unit = { + pull(in) + } + }) + } + + (logic, logic.callback) + } + } + + "The support for async callbacks" must { + + "invoke without feedback, happy path" in { + val probe = TestProbe() + val in = TestPublisher.probe[Int]() + val out = TestSubscriber.probe[Int]() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.fromSubscriber(out)) + .run() + + probe.expectMsg(Started) + out.request(1) + in.expectRequest() + + (0 to 10).foreach { n ⇒ + val msg = "whatever" + n + callback.invoke(msg) + probe.expectMsg(msg) + } + + in.sendComplete() + out.expectComplete() + + probe.expectMsg(Stopped) + } + + "invoke with feedback, happy path" in { + val probe = TestProbe() + val in = TestPublisher.probe[Int]() + val out = TestSubscriber.probe[Int]() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.fromSubscriber(out)) + .run() + + probe.expectMsg(Started) + out.request(1) + in.expectRequest() + + (0 to 10).foreach { n ⇒ + val msg = "whatever" + n + val feedbackF = callback.invokeWithFeedback(msg) + probe.expectMsg(msg) + feedbackF.futureValue should ===(Done) + } + in.sendComplete() + out.expectComplete() + + probe.expectMsg(Stopped) + } + + "fail the feedback future if stage is stopped" in { + val probe = TestProbe() + val callback = Source.empty + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.ignore) + .run() + + probe.expectMsg(Started) + probe.expectMsg(Stopped) + + val feedbakF = callback.invokeWithFeedback("whatever") + feedbakF.failed.futureValue shouldBe a[StreamDetachedException] + } + + "invoke early" in { + val probe = TestProbe() + val in = TestPublisher.probe[Int]() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage( + probe.ref, + Some(asyncCb ⇒ asyncCb.invoke("early")) + ))(Keep.right) + .to(Sink.ignore) + .run() + + // and deliver in order + callback.invoke("later") + + probe.expectMsg(Started) + probe.expectMsg("early") + probe.expectMsg("later") + + in.sendComplete() + probe.expectMsg(Stopped) + + } + + "invoke with feedback early" in { + val probe = TestProbe() + val earlyFeedback = Promise[Done]() + val in = TestPublisher.probe[Int]() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage( + probe.ref, + Some(asyncCb ⇒ earlyFeedback.completeWith(asyncCb.invokeWithFeedback("early"))) + ))(Keep.right) + .to(Sink.ignore) + .run() + + // and deliver in order + val laterFeedbackF = callback.invokeWithFeedback("later") + + probe.expectMsg(Started) + probe.expectMsg("early") + earlyFeedback.future.futureValue should ===(Done) + + probe.expectMsg("later") + laterFeedbackF.futureValue should ===(Done) + + in.sendComplete() + probe.expectMsg(Stopped) + } + + "accept concurrent input" in { + val probe = TestProbe() + val in = TestPublisher.probe[Int]() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.ignore) + .run() + + import system.dispatcher + val feedbacks = (1 to 100).map { n ⇒ + Future { + callback.invokeWithFeedback(n.toString) + }.flatMap(d ⇒ d) + } + + probe.expectMsg(Started) + Future.sequence(feedbacks).futureValue should have size (100) + (1 to 100).map(_ ⇒ probe.expectMsgType[String]).toSet should have size (100) + + in.sendComplete() + probe.expectMsg(Stopped) + } + + "fail the feedback if the handler throws" in { + val probe = TestProbe() + val in = TestPublisher.probe() + val callback = Source.fromPublisher(in) + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.ignore) + .run() + + probe.expectMsg(Started) + callback.invokeWithFeedback("happy-case").futureValue should ===(Done) + probe.expectMsg("happy-case") + + val feedbackF = callback.invokeWithFeedback(TE("oh my gosh, whale of a wash!")) + val failure = feedbackF.failed.futureValue + failure shouldBe a[TE] + failure.getMessage should ===("oh my gosh, whale of a wash!") + + in.expectCancellation() + } + + "fail the feedback if the handler fails the stage" in { + val probe = TestProbe() + val callback = Source.empty + .viaMat(new AsyncCallbackGraphStage(probe.ref))(Keep.right) + .to(Sink.ignore) + .run() + + probe.expectMsg(Started) + probe.expectMsg(Stopped) + + val feedbakF = callback.invokeWithFeedback("fail-the-stage") + val failure = feedbakF.failed.futureValue + failure shouldBe a[StreamDetachedException] // we can't capture the exception in this case + } + + } +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala index c83fdb725f..14de8732d1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala @@ -11,6 +11,7 @@ import akka.stream.impl.fusing.GraphInterpreter.{ Connection, DownstreamBoundary import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler, _ } import akka.stream.testkit.StreamSpec import akka.stream.testkit.Utils.TE +import akka.util.OptionVal import scala.collection.{ Map ⇒ SMap } import scala.language.existentials @@ -252,7 +253,7 @@ trait GraphInterpreterSpecKit extends StreamSpec { logger, logics, connections, - onAsyncInput = (_, _, _) ⇒ (), + onAsyncInput = (_, _, _, _) ⇒ (), fuzzingMode = false, context = null) _interpreter.init(null) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala index 0df497012e..6f66c9f9ca 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala @@ -365,6 +365,34 @@ class HubSpec extends StreamSpec { } } + "handle cancelled Sink" in assertAllStagesStopped { + val in = TestPublisher.probe[Int]() + val hubSource = Source.fromPublisher(in).runWith(BroadcastHub.sink(4)) + + val out = TestSubscriber.probe[Int]() + + hubSource.runWith(Sink.cancelled) + hubSource.runWith(Sink.fromSubscriber(out)) + + out.ensureSubscription() + + out.request(10) + in.expectRequest() + in.sendNext(1) + out.expectNext(1) + in.sendNext(2) + out.expectNext(2) + in.sendNext(3) + out.expectNext(3) + in.sendNext(4) + out.expectNext(4) + in.sendNext(5) + out.expectNext(5) + + in.sendComplete() + out.expectComplete() + } + } "PartitionHub" must { diff --git a/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes index 07e820af4b..fe984a3f4a 100644 --- a/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes +++ b/akka-stream/src/main/mima-filters/2.5.6.backwards.excludes @@ -21,3 +21,13 @@ ProblemFilters.exclude[MissingClassProblem]("akka.stream.stage.CallbackWrapper") # Optimize TCP stream writes ProblemFilters.exclude[Problem]("akka.stream.impl.io.*") + +# #23953 fixes to async callback with feedback +ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.fusing.GraphInterpreterShell$AsyncInput$") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.GraphInterpreterShell#AsyncInput.copy") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.stream.impl.fusing.GraphInterpreter.onAsyncInput") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.fusing.GraphInterpreter.this") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.GraphInterpreterShell#AsyncInput.apply") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.stream.impl.fusing.GraphInterpreterShell#AsyncInput.copy$default$4") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.GraphInterpreterShell#AsyncInput.this") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.GraphInterpreter.runAsyncInput") \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/StreamDetachedException.scala b/akka-stream/src/main/scala/akka/stream/StreamDetachedException.scala index 73888b1882..0fb53554e7 100644 --- a/akka-stream/src/main/scala/akka/stream/StreamDetachedException.scala +++ b/akka-stream/src/main/scala/akka/stream/StreamDetachedException.scala @@ -9,6 +9,10 @@ import scala.util.control.NoStackTrace * This exception signals that materialized value is already detached from stream. This usually happens * when stream is completed and an ActorSystem is shut down while materialized object is still available. */ -final class StreamDetachedException - extends RuntimeException("Stream is terminated. Materialized value is detached.") - with NoStackTrace +final class StreamDetachedException(message: String) + extends RuntimeException(message) + with NoStackTrace { + + def this() = this("Stream is terminated. Materialized value is detached.") + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 620b37bab5..9f235c3b1c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -7,6 +7,7 @@ import java.util import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicReference +import akka.Done import akka.actor._ import akka.annotation.InternalApi import akka.event.Logging @@ -20,6 +21,7 @@ import org.reactivestreams.{ Publisher, Subscriber, Subscription } import scala.annotation.tailrec import scala.collection.immutable +import scala.concurrent.Promise import scala.util.control.NonFatal /** @@ -447,15 +449,26 @@ import scala.util.control.NonFatal private var self: ActorRef = _ lazy val log = Logging(mat.system.eventStream, self) - final case class AsyncInput(shell: GraphInterpreterShell, logic: GraphStageLogic, evt: Any, handler: (Any) ⇒ Unit) extends BoundaryEvent { + /** + * @param promise Will be completed upon processing the event, or failed if processing the event throws + * if the event isn't ever processed the promise (the stage stops) is failed elsewhere + */ + final case class AsyncInput( + shell: GraphInterpreterShell, + logic: GraphStageLogic, + evt: Any, + promise: OptionVal[Promise[Done]], + handler: (Any) ⇒ Unit) extends BoundaryEvent { override def execute(eventLimit: Int): Int = { if (!waitingForShutdown) { - interpreter.runAsyncInput(logic, evt, handler) + interpreter.runAsyncInput(logic, evt, promise, handler) if (eventLimit == 1 && interpreter.isSuspended) { sendResume(true) 0 } else runBatch(eventLimit - 1) - } else eventLimit + } else { + eventLimit + } } } @@ -481,8 +494,8 @@ import scala.util.control.NonFatal private var enqueueToShortCircuit: (Any) ⇒ Unit = _ lazy val interpreter: GraphInterpreter = new GraphInterpreter(mat, log, logics, connections, - (logic, event, handler) ⇒ { - val asyncInput = AsyncInput(this, logic, event, handler) + (logic, event, promise, handler) ⇒ { + val asyncInput = AsyncInput(this, logic, event, promise, handler) val currentInterpreter = GraphInterpreter.currentInterpreterOrNull if (currentInterpreter == null || (currentInterpreter.context ne self)) self ! asyncInput diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index 442219591a..f0fa43f950 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -9,8 +9,11 @@ import akka.stream.stage._ import akka.stream._ import java.util.concurrent.ThreadLocalRandom +import akka.Done import akka.annotation.InternalApi +import akka.util.OptionVal +import scala.concurrent.Promise import scala.util.control.NonFatal /** @@ -191,7 +194,7 @@ import scala.util.control.NonFatal val log: LoggingAdapter, val logics: Array[GraphStageLogic], // Array of stage logics val connections: Array[GraphInterpreter.Connection], - val onAsyncInput: (GraphStageLogic, Any, (Any) ⇒ Unit) ⇒ Unit, + val onAsyncInput: (GraphStageLogic, Any, OptionVal[Promise[Done]], (Any) ⇒ Unit) ⇒ Unit, val fuzzingMode: Boolean, val context: ActorRef) { @@ -432,7 +435,7 @@ import scala.util.control.NonFatal eventsRemaining } - def runAsyncInput(logic: GraphStageLogic, evt: Any, handler: (Any) ⇒ Unit): Unit = + def runAsyncInput(logic: GraphStageLogic, evt: Any, promise: OptionVal[Promise[Done]], handler: (Any) ⇒ Unit): Unit = if (!isStageCompleted(logic)) { if (GraphInterpreter.Debug) println(s"$Name ASYNC $evt ($handler) [$logic]") val currentInterpreterHolder = _currentInterpreter.get() @@ -440,9 +443,13 @@ import scala.util.control.NonFatal currentInterpreterHolder(0) = this try { activeStage = logic - try handler(evt) - catch { - case NonFatal(ex) ⇒ logic.failStage(ex) + try { + handler(evt) + if (promise.isDefined) promise.get.success(Done) + } catch { + case NonFatal(ex) ⇒ + if (promise.isDefined) promise.get.failure(ex) + logic.failStage(ex) } afterStageHasRun(logic) } finally currentInterpreterHolder(0) = previousInterpreter diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala index 96edf32dc2..6a93a396e6 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala @@ -7,14 +7,13 @@ import java.util import java.util.concurrent.atomic.{ AtomicLong, AtomicReference } import akka.NotUsed -import akka.dispatch.AbstractNodeQueue +import akka.dispatch.{ AbstractNodeQueue, ExecutionContexts } import akka.stream._ import akka.stream.stage._ import scala.annotation.tailrec import scala.concurrent.{ Future, Promise } import scala.util.{ Failure, Success, Try } -import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicReferenceArray @@ -424,12 +423,19 @@ private[akka] class BroadcastHub[T](bufferSize: Int) extends GraphStageWithMater val startFrom = head activeConsumers += 1 addConsumer(consumer, startFrom) - consumer.callback.invoke(Initialize(startFrom)) + // in case the consumer is already stopped we need to undo registration + implicit val ec = materializer.executionContext + consumer.callback.invokeWithFeedback(Initialize(startFrom)).onFailure { + case _: StreamDetachedException ⇒ + callbackPromise.future.foreach(callback ⇒ + callback.invoke(UnRegister(consumer.id, startFrom, startFrom)) + ) + } } case UnRegister(id, previousOffset, finalOffset) ⇒ - activeConsumers -= 1 - val consumer = findAndRemoveConsumer(id, previousOffset) + if (findAndRemoveConsumer(id, previousOffset) != null) + activeConsumers -= 1 if (activeConsumers == 0) { if (isClosed(in)) completeStage() else if (head != finalOffset) { @@ -443,14 +449,15 @@ private[akka] class BroadcastHub[T](bufferSize: Int) extends GraphStageWithMater if (!hasBeenPulled(in)) pull(in) } } else checkUnblock(previousOffset) + case Advance(id, previousOffset) ⇒ val newOffset = previousOffset + DemandThreshold - // Move the consumer from its last known offest to its new one. Check if we are unblocked. + // Move the consumer from its last known offset to its new one. Check if we are unblocked. val consumer = findAndRemoveConsumer(id, previousOffset) addConsumer(consumer, newOffset) checkUnblock(previousOffset) case NeedWakeup(id, previousOffset, currentOffset) ⇒ - // Move the consumer from its last known offest to its new one. Check if we are unblocked. + // Move the consumer from its last known offset to its new one. Check if we are unblocked. val consumer = findAndRemoveConsumer(id, previousOffset) addConsumer(consumer, currentOffset) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 1e2aaa4909..bda38d6eef 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -1061,8 +1061,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: @tailrec final private[stage] def onStart(): Unit = { (currentState.getAndSet(Initializing): @unchecked) match { - case Pending(l) ⇒ l.reverse.foreach(ack ⇒ { - onAsyncInput(ack.e) + case Pending(l) ⇒ l.reverse.foreach(evt ⇒ { + onAsyncInput(evt.e, evt.handlingPromise) }) } if (!currentState.compareAndSet(Initializing, Initialized)) { @@ -1082,10 +1082,12 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: waitingForProcessing.clear() } - private def onAsyncInput(event: T) = interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit]) + private def onAsyncInput(event: T, promise: OptionVal[Promise[Done]]) = { + interpreter.onAsyncInput(GraphStageLogic.this, event, promise, handler.asInstanceOf[Any ⇒ Unit]) + } private def sendEvent(event: T, promise: Promise[Done]): Promise[Done] = { - onAsyncInput(event) + onAsyncInput(event, OptionVal.Some(promise)) currentState.get() match { case Completed ⇒ failPromiseOnComplete(promise) case _ ⇒ promise @@ -1123,7 +1125,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: private def failPromiseOnComplete(promise: Promise[Done]): Promise[Done] = { waitingForProcessing.remove(promise) - promise.tryFailure(new StreamDetachedException()) + promise.tryFailure(new StreamDetachedException("Stage stopped before async invocation was processed")) promise } @@ -1132,14 +1134,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: @tailrec def internalInvoke(event: T): Unit = currentState.get() match { // started - can just send message to stream - case Initialized ⇒ onAsyncInput(event) + case Initialized ⇒ onAsyncInput(event, OptionVal.None) // not started yet case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) // initializing is in progress in another thread (initializing thread is managed by akka) case Initializing ⇒ if (!currentState.compareAndSet(Initializing, Pending(Event(event, OptionVal.None) :: Nil))) { (currentState.get(): @unchecked) match { case list @ Pending(l) ⇒ if (!currentState.compareAndSet(list, Pending(Event(event, OptionVal.None) :: l))) internalInvoke(event) - case Initialized ⇒ onAsyncInput(event) + case Initialized ⇒ onAsyncInput(event, OptionVal.None) } } case Completed ⇒ // do nothing here as stream is completed @@ -1401,13 +1403,22 @@ trait AsyncCallback[T] { /** * Dispatch an asynchronous notification. This method is thread-safe and * may be invoked from external execution contexts. + * + * For cases where it is important to know if the notification was ever processed or not + * see [AsyncCallback#invokeWithFeedback]] */ def invoke(t: T): Unit /** - * Dispatch an asynchronous notification. - * This method is thread-safe and may be invoked from external execution contexts. - * Promise in `HasCallbackPromise` will fail if stream is already closed or closed before - * being able to process the event + * Dispatch an asynchronous notification. This method is thread-safe and + * may be invoked from external execution contexts. + * + * The method returns directly and the returned future is then completed once the event + * has been handled by the stage, if the event triggers an exception from the handler the future + * is failed with that exception and finally if the stage was stopped before the event has been + * handled the future is failed with `StreamDetachedException`. + * + * The handling of the returned future incurs a slight overhead, so for cases where it does not matter + * to the invoking logic see [[AsyncCallback#invoke]] */ def invokeWithFeedback(t: T): Future[Done] }