diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/ActorRefBackpressureSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/ActorRefBackpressureSinkSpec.scala new file mode 100644 index 0000000000..0df9f951d7 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/ActorRefBackpressureSinkSpec.scala @@ -0,0 +1,115 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.actor.{ Actor, ActorRef, Props } +import akka.stream.ActorMaterializer +import akka.stream.testkit.Utils._ +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl._ + +object ActorRefBackpressureSinkSpec { + val initMessage = "start" + val completeMessage = "done" + val ackMessage = "ack" + + class Fw(ref: ActorRef) extends Actor { + def receive = { + case `initMessage` ⇒ + sender() ! ackMessage + ref forward initMessage + case `completeMessage` ⇒ + ref forward completeMessage + case msg: Int ⇒ + sender() ! ackMessage + ref forward msg + } + } + + case object TriggerAckMessage + + class Fw2(ref: ActorRef) extends Actor { + var actorRef: ActorRef = Actor.noSender + + def receive = { + case TriggerAckMessage ⇒ + actorRef ! ackMessage + case msg ⇒ + actorRef = sender() + ref forward msg + } + } + +} + +class ActorRefBackpressureSinkSpec extends AkkaSpec { + import ActorRefBackpressureSinkSpec._ + implicit val mat = ActorMaterializer() + + def createActor[T](c: Class[T]) = + system.actorOf(Props(c, testActor).withDispatcher("akka.test.stream-dispatcher")) + + "An ActorRefSink" must { + + "send the elements to the ActorRef" in assertAllStagesStopped { + val fw = createActor(classOf[Fw]) + Source(List(1, 2, 3)).runWith(Sink.actorRefWithAck(fw, + initMessage, ackMessage, completeMessage)) + expectMsg("start") + expectMsg(1) + expectMsg(2) + expectMsg(3) + expectMsg(completeMessage) + } + + "send the elements to the ActorRef2" in assertAllStagesStopped { + val fw = createActor(classOf[Fw]) + val probe = TestSource.probe[Int].to(Sink.actorRefWithAck(fw, + initMessage, ackMessage, completeMessage)).run() + probe.sendNext(1) + expectMsg("start") + expectMsg(1) + probe.sendNext(2) + expectMsg(2) + probe.sendNext(3) + expectMsg(3) + probe.sendComplete() + expectMsg(completeMessage) + } + + "cancel stream when actor terminates" in assertAllStagesStopped { + val fw = createActor(classOf[Fw]) + val publisher = TestSource.probe[Int].to(Sink.actorRefWithAck(fw, + initMessage, ackMessage, completeMessage)).run().sendNext(1) + expectMsg(initMessage) + expectMsg(1) + system.stop(fw) + publisher.expectCancellation() + } + + "send message only when backpressure received" in assertAllStagesStopped { + val fw = createActor(classOf[Fw2]) + val publisher = TestSource.probe[Int].to(Sink.actorRefWithAck(fw, + initMessage, ackMessage, completeMessage)).run() + expectMsg(initMessage) + + publisher.sendNext(1) + expectNoMsg() + fw ! TriggerAckMessage + expectMsg(1) + + publisher.sendNext(2) + publisher.sendNext(3) + publisher.sendComplete() + fw ! TriggerAckMessage + expectMsg(2) + fw ! TriggerAckMessage + expectMsg(3) + + expectMsg(completeMessage) + } + + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIntersperseSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIntersperseSpec.scala index 0f1360692e..fa25b1644f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIntersperseSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowIntersperseSpec.scala @@ -4,7 +4,8 @@ package akka.stream.scaladsl import akka.stream.testkit._ -import akka.stream.testkit.scaladsl.TestSink +import akka.stream.testkit.Utils.assertAllStagesStopped +import akka.stream.testkit.scaladsl.{ TestSource, TestSink } import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } import org.scalatest.concurrent.ScalaFutures @@ -18,7 +19,6 @@ class FlowIntersperseSpec extends AkkaSpec with ScalaFutures { implicit val materializer = ActorMaterializer(settings) "A Intersperse" must { - "inject element between existing elements" in { val probe = Source(List(1, 2, 3)) .map(_.toString) @@ -74,6 +74,29 @@ class FlowIntersperseSpec extends AkkaSpec with ScalaFutures { probe.expectSubscription() probe.toStrict(1.second).mkString("") should ===(List(1).mkString("[", ",", "]")) } + + "complete the stage when the Source has been completed" in { + val (p1, p2) = TestSource.probe[String].intersperse(",").toMat(TestSink.probe[String])(Keep.both).run + p2.request(10) + p1.sendNext("a") + .sendNext("b") + .sendComplete() + p2.expectNext("a") + .expectNext(",") + .expectNext("b") + .expectComplete() + } + + "complete the stage when the Sink has been cancelled" in { + val (p1, p2) = TestSource.probe[String].intersperse(",").toMat(TestSink.probe[String])(Keep.both).run + p2.request(10) + p1.sendNext("a") + .sendNext("b") + p2.expectNext("a") + .expectNext(",") + .cancel() + p1.expectCancellation() + } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala new file mode 100644 index 0000000000..b75c97343e --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala @@ -0,0 +1,93 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl + +import java.util + +import akka.actor._ +import akka.dispatch.sysmsg.{ DeathWatchNotification, SystemMessage, Watch } +import akka.stream.stage.GraphStageLogic.StageActorRef +import akka.stream.{ Inlet, SinkShape, ActorMaterializer, Attributes } +import akka.stream.Attributes.InputBuffer +import akka.stream.stage._ + +/** + * INTERNAL API + */ +private[akka] class ActorRefBackpressureSinkStage[In](ref: ActorRef, onInitMessage: Any, + ackMessage: Any, + onCompleteMessage: Any, + onFailureMessage: (Throwable) ⇒ Any) + extends GraphStage[SinkShape[In]] { + val in: Inlet[In] = Inlet[In]("ActorRefBackpressureSink.in") + override val shape: SinkShape[In] = SinkShape(in) + + val maxBuffer = module.attributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max + require(maxBuffer > 0, "Buffer size must be greater than 0") + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + implicit var self: StageActorRef = _ + + val buffer: util.Deque[In] = new util.ArrayDeque[In]() + var acknowledgementReceived = false + var completeReceived = false + + override def keepGoingAfterAllPortsClosed: Boolean = true + + private val callback: AsyncCallback[Unit] = getAsyncCallback((_: Unit) ⇒ { + if (!buffer.isEmpty) sendData() + else acknowledgementReceived = true + }) + + private val deathWatchCallback: AsyncCallback[Unit] = + getAsyncCallback((Unit) ⇒ completeStage()) + + private def receive(evt: (ActorRef, Any)): Unit = { + evt._2 match { + case `ackMessage` ⇒ callback.invoke(()) + case Terminated(`ref`) ⇒ deathWatchCallback.invoke(()) + case _ ⇒ //ignore all other messages + } + } + + override def preStart() = { + self = getStageActorRef(receive) + self.watch(ref) + ref ! onInitMessage + pull(in) + } + + private def sendData(): Unit = { + if (!buffer.isEmpty) { + ref ! buffer.poll() + acknowledgementReceived = false + } + if (buffer.isEmpty && completeReceived) finish() + } + + private def finish(): Unit = { + ref ! onCompleteMessage + completeStage() + } + + setHandler(in, new InHandler { + override def onPush(): Unit = { + buffer offer grab(in) + if (acknowledgementReceived) sendData() + if (buffer.size() < maxBuffer) pull(in) + } + override def onUpstreamFinish(): Unit = { + if (buffer.isEmpty) finish() + else completeReceived = true + } + override def onUpstreamFailure(ex: Throwable): Unit = { + ref ! onFailureMessage(ex) + failStage(ex) + } + }) + } + + override def toString = "ActorRefBackpressureSink" +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index c95a2cb445..3e18d37177 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -188,10 +188,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) } - final case class Intersperse[T](start: Option[T], inject: T, end: Option[T], attributes: Attributes = intersperse) extends SymbolicStage[T, T] { - override def create(attr: Attributes) = fusing.Intersperse(start, inject, end) - } - final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index b2d8cfe8c6..8cb1e8650c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -240,43 +240,43 @@ private[akka] final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, de /** * INTERNAL API */ -private[akka] final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) extends StatefulStage[T, T] { - private var needsToEmitStart = start.isDefined +final case class Intersperse[T](start: Option[T], inject: T, end: Option[T]) extends GraphStage[FlowShape[T, T]] { - override def initial: StageState[T, T] = - start match { - case Some(initial) ⇒ firstWithInitial(initial) - case _ ⇒ first + private val in = Inlet[T]("in") + private val out = Outlet[T]("out") + + override val shape = FlowShape(in, out) + + override def createLogic(attr: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + val startInHandler = new InHandler { + override def onPush(): Unit = { + // if else (to avoid using Iterator[T].flatten in hot code) + if (start.isDefined) emitMultiple(out, Iterator(start.get, grab(in))) + else emit(out, grab(in)) + setHandler(in, restInHandler) // switch handler + } + + override def onUpstreamFinish(): Unit = { + emitMultiple(out, Iterator(start, end).flatten) + completeStage() + } } - def firstWithInitial(initial: T) = new StageState[T, T] { - override def onPush(elem: T, ctx: Context[T]) = { - needsToEmitStart = false - emit(Iterator(initial, elem), ctx, running) - } - } + val restInHandler = new InHandler { + override def onPush(): Unit = emitMultiple(out, Iterator(inject, grab(in))) - def first = new StageState[T, T] { - override def onPush(elem: T, ctx: Context[T]) = { - become(running) - ctx.push(elem) + override def onUpstreamFinish(): Unit = { + if (end.isDefined) emit(out, end.get) + completeStage() + } } - } - def running = new StageState[T, T] { - override def onPush(elem: T, ctx: Context[T]): SyncDirective = - emit(Iterator(inject, elem), ctx) - } - - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { - end match { - case Some(e) if needsToEmitStart ⇒ - terminationEmit(Iterator(start.get, end.get), ctx) - case Some(e) ⇒ - terminationEmit(Iterator(end.get), ctx) - case _ ⇒ - terminationEmit(Iterator(), ctx) + val outHandler = new OutHandler { + override def onPull(): Unit = pull(in) } + + setHandler(in, startInHandler) + setHandler(out, outHandler) } } diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala index 0c40e382ac..b190a514a8 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala @@ -8,6 +8,7 @@ import java.io.{ InputStream, OutputStream, File } import akka.actor.{ ActorRef, Props } import akka.dispatch.ExecutionContexts import akka.japi.function +import akka.stream.impl.Stages.DefaultAttributes import akka.stream.impl.StreamLayout import akka.stream.{ javadsl, scaladsl, _ } import akka.util.ByteString @@ -153,6 +154,23 @@ object Sink { def actorRef[In](ref: ActorRef, onCompleteMessage: Any): Sink[In, Unit] = new Sink(scaladsl.Sink.actorRef[In](ref, onCompleteMessage)) + /** + * Sends the elements of the stream to the given `ActorRef` that sends back back-pressure signal. + * First element is always `onInitMessage`, then stream is waiting for acknowledgement message + * `ackMessage` from the given actor which means that it is ready to process + * elements. It also requires `ackMessage` message after each stream element + * to make backpressure work. + * + * If the target actor terminates the stream will be canceled. + * When the stream is completed successfully the given `onCompleteMessage` + * will be sent to the destination actor. + * When the stream is completed with failure - result of `onFailureMessage(throwable)` + * message will be sent to the destination actor. + */ + def actorRefWithAck[In](ref: ActorRef, onInitMessage: Any, ackMessage: Any, onCompleteMessage: Any, + onFailureMessage: function.Function[Throwable, Any]): Sink[In, Unit] = + new Sink(scaladsl.Sink.actorRefWithAck[In](ref, onInitMessage, ackMessage, onCompleteMessage, onFailureMessage.apply)) + /** * Creates a `Sink` that is materialized to an [[akka.actor.ActorRef]] which points to an Actor * created according to the passed in [[akka.actor.Props]]. Actor created by the `props` should diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 4e09889c3e..ef4b67bbb7 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -665,7 +665,7 @@ trait FlowOps[+Out, +Mat] { ReactiveStreamsCompliance.requireNonNullElement(start) ReactiveStreamsCompliance.requireNonNullElement(inject) ReactiveStreamsCompliance.requireNonNullElement(end) - andThen(Intersperse(Some(start), inject, Some(end))) + via(Intersperse(Some(start), inject, Some(end))) } /** @@ -692,7 +692,7 @@ trait FlowOps[+Out, +Mat] { */ def intersperse[T >: Out](inject: T): Repr[T] = { ReactiveStreamsCompliance.requireNonNullElement(inject) - andThen(Intersperse(None, inject, None)) + via(Intersperse(None, inject, None)) } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala index 799db83ecc..03e98123c3 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala @@ -4,9 +4,8 @@ package akka.stream.scaladsl import java.io.{ InputStream, OutputStream, File } - -import akka.actor.{ ActorRef, Props } import akka.dispatch.ExecutionContexts +import akka.actor.{ Status, ActorRef, Props } import akka.stream.actor.ActorSubscriber import akka.stream.impl.Stages.DefaultAttributes import akka.stream.impl.StreamLayout.Module @@ -239,6 +238,23 @@ object Sink { def actorRef[T](ref: ActorRef, onCompleteMessage: Any): Sink[T, Unit] = new Sink(new ActorRefSink(ref, onCompleteMessage, DefaultAttributes.actorRefSink, shape("ActorRefSink"))) + /** + * Sends the elements of the stream to the given `ActorRef` that sends back back-pressure signal. + * First element is always `onInitMessage`, then stream is waiting for acknowledgement message + * `ackMessage` from the given actor which means that it is ready to process + * elements. It also requires `ackMessage` message after each stream element + * to make backpressure work. + * + * If the target actor terminates the stream will be canceled. + * When the stream is completed successfully the given `onCompleteMessage` + * will be sent to the destination actor. + * When the stream is completed with failure - result of `onFailureMessage(throwable)` + * function will be sent to the destination actor. + */ + def actorRefWithAck[T](ref: ActorRef, onInitMessage: Any, ackMessage: Any, onCompleteMessage: Any, + onFailureMessage: (Throwable) ⇒ Any = Status.Failure): Sink[T, Unit] = + Sink.fromGraph(new ActorRefBackpressureSinkStage(ref, onInitMessage, ackMessage, onCompleteMessage, onFailureMessage)) + /** * Creates a `Sink` that is materialized to an [[akka.actor.ActorRef]] which points to an Actor * created according to the passed in [[akka.actor.Props]]. Actor created by the `props` must diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 86b134f4ee..2925a62ba6 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -3,23 +3,21 @@ */ package akka.stream.stage -import java.util -import java.util.concurrent.atomic.{ AtomicReferenceFieldUpdater, AtomicReference } +import java.util.concurrent.atomic.AtomicReference import akka.actor._ -import akka.actor.dungeon.DeathWatch -import akka.dispatch.sysmsg.{ Unwatch, Watch, DeathWatchNotification, SystemMessage } -import akka.event.{ LoggingAdapter, Logging } -import akka.event.Logging.{ Warning, Debug } +import akka.dispatch.sysmsg.{ DeathWatchNotification, SystemMessage, Unwatch, Watch } +import akka.event.LoggingAdapter import akka.stream._ -import akka.stream.impl.{ SeqActorName, ActorMaterializerImpl, ReactiveStreamsCompliance } import akka.stream.impl.StreamLayout.Module -import akka.stream.impl.fusing.{ GraphModule, GraphInterpreter } import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly +import akka.stream.impl.fusing.{ GraphInterpreter, GraphModule } +import akka.stream.impl.{ ReactiveStreamsCompliance, SeqActorName } + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration -import scala.collection.mutable.ArrayBuffer -import scala.annotation.tailrec abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, M] {