diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index 06a0060251..0bc6c15cb5 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -121,10 +121,7 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set.empty) source2.onNext("Meaning of life") - lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")))) - - sink.requestOne() - lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")), RequestOne(source1), RequestOne(source2))) } "implement Broadcast" in new TestSetup { @@ -169,13 +166,11 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set(RequestOne(source))) source.onNext(1) - lastEvents() should ===(Set(OnNext(sink, (1, 1)))) + lastEvents() should ===(Set(OnNext(sink, (1, 1)), RequestOne(source))) sink.requestOne() - lastEvents() should ===(Set(RequestOne(source))) - source.onNext(2) - lastEvents() should ===(Set(OnNext(sink, (2, 2)))) + lastEvents() should ===(Set(OnNext(sink, (2, 2)), RequestOne(source))) } @@ -198,16 +193,15 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set.empty) sink1.requestOne() - lastEvents() should ===(Set.empty) + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) sink2.requestOne() - lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) source1.onNext(1) lastEvents() should ===(Set.empty) source2.onNext(2) - lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)))) + lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)), RequestOne(source1), RequestOne(source2))) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala index 061826c017..51c684dcc8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala @@ -3,13 +3,18 @@ */ package akka.stream.scaladsl -import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings } +import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings, OverflowStrategy } +import akka.stream.impl.fusing.GraphStages.Detacher import akka.stream.testkit._ +import akka.stream.testkit.Utils._ +import akka.stream.testkit.scaladsl._ import com.typesafe.config.ConfigFactory +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time._ + +import scala.collection.immutable import scala.concurrent.Await import scala.concurrent.duration._ -import akka.stream.OverflowStrategy -import org.scalatest.concurrent.ScalaFutures class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INFO")) with ScalaFutures { @@ -18,8 +23,11 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF implicit val materializer = ActorMaterializer(settings) + implicit val defaultPatience = + PatienceConfig(timeout = Span(2, Seconds), interval = Span(200, Millis)) + "A Flow using join" must { - "allow for cycles" in { + "allow for cycles" in assertAllStagesStopped { val end = 47 val (even, odd) = (0 to end).partition(_ % 2 == 0) val result = Set() ++ even ++ odd ++ odd.map(_ * 10) @@ -51,14 +59,103 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF sub.cancel() } - "propagate one element" in { + "allow for merge cycle" in assertAllStagesStopped { val source = Source.single("lonely traveler") val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b ⇒ sink ⇒ import GraphDSL.Implicits._ val merge = b.add(Merge[String](2)) - val broadcast = b.add(Broadcast[String](2)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) + source ~> merge.in(0) + merge.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(merge.in(1), broadcast.out(1)) + }) + + whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler") + } + + "allow for merge preferred cycle" in assertAllStagesStopped { + val source = Source.single("lonely traveler") + + val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b ⇒ + sink ⇒ + import GraphDSL.Implicits._ + val merge = b.add(MergePreferred[String](1)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) + source ~> merge.preferred + merge.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(merge.in(0), broadcast.out(1)) + }) + + whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler") + } + + "allow for zip cycle" in assertAllStagesStopped { + val source = Source(immutable.Seq("traveler1", "traveler2")) + + val flow = Flow.fromGraph(GraphDSL.create(TestSink.probe[(String, String)]) { implicit b ⇒ + sink ⇒ + import GraphDSL.Implicits._ + val zip = b.add(Zip[String, String]) + val broadcast = b.add(Broadcast[(String, String)](2)) + source ~> zip.in0 + zip.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(zip.in1, broadcast.out(1)) + }) + + val feedback = Flow.fromGraph(GraphDSL.create(Source.single("ignition")) { implicit b ⇒ + ignition ⇒ + import GraphDSL.Implicits._ + val flow = b.add(Flow[(String, String)].map(_._1)) + val merge = b.add(Merge[String](2)) + + ignition ~> merge.in(0) + flow ~> merge.in(1) + + FlowShape(flow.in, merge.out) + }) + + val probe = flow.join(feedback).run() + probe.requestNext(("traveler1", "ignition")) + probe.requestNext(("traveler2", "traveler1")) + } + + "allow for concat cycle" in assertAllStagesStopped { + val flow = Flow.fromGraph(GraphDSL.create(TestSource.probe[String](system), Sink.head[String])(Keep.both) { implicit b ⇒ + (source, sink) ⇒ + import GraphDSL.Implicits._ + val concat = b.add(Concat[String](2)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) + source ~> concat.in(0) + concat.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(concat.in(1), broadcast.out(1)) + }) + + val (probe, result) = flow.join(Flow[String]).run() + probe.sendNext("lonely traveler") + whenReady(result) { r ⇒ + r shouldBe "lonely traveler" + probe.sendComplete() + } + } + + "allow for interleave cycle" in assertAllStagesStopped { + val source = Source.single("lonely traveler") + + val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b ⇒ + sink ⇒ + import GraphDSL.Implicits._ + val merge = b.add(Interleave[String](2, 1)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) source ~> merge.in(0) merge.out ~> broadcast.in broadcast.out(0) ~> sink diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala index e4987e386c..0b4108984d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala @@ -129,9 +129,6 @@ class GraphZipSpec extends TwoStreamsSetup { downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } "complete if one side complete before requested with elements pending" in { @@ -159,9 +156,6 @@ class GraphZipSpec extends TwoStreamsSetup { downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } "complete if one side complete before requested with elements pending 2" in { @@ -190,9 +184,6 @@ class GraphZipSpec extends TwoStreamsSetup { upstream2.sendComplete() downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } commonTests() diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template index 4afdec4342..4cad1f31a6 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template @@ -31,13 +31,22 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape ] override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var pending = 1 + var pending = ##0 // Without this field the completion signalling would take one extra pull var willShutDown = false private def pushAll(): Unit = { push(out, zipper([#grab(in0)#])) if (willShutDown) completeStage() + else { + [#pull(in0)# + ] + } + } + + override def preStart(): Unit = { + [#pull(in0)# + ] } [#setHandler(in0, new InHandler { @@ -56,17 +65,13 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape setHandler(out, new OutHandler { override def onPull(): Unit = { - pending = shape.inlets.size - if (willShutDown) completeStage() - else { - [#pull(in0)# - ] - } + pending += shape.inlets.size + if (pending == ##0) pushAll() } }) } - override def toString = "Zip" + override def toString = "ZipWith1" } # diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index ea70aaf273..53d67b645c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -8,6 +8,7 @@ import akka.actor.Cancellable import akka.dispatch.ExecutionContexts import akka.event.Logging import akka.stream._ +import akka.stream.scaladsl._ import akka.stream.impl.Stages.DefaultAttributes import akka.stream.stage._ import scala.concurrent.{ Future, Promise } @@ -65,14 +66,16 @@ object GraphStages { def identity[T] = Identity.asInstanceOf[SimpleLinearGraphStage[T]] - private class Detacher[T] extends GraphStage[FlowShape[T, T]] { + /** + * INERNAL API + */ + private[stream] final class Detacher[T] extends GraphStage[FlowShape[T, T]] { val in = Inlet[T]("in") val out = Outlet[T]("out") override def initialAttributes = Attributes.name("Detacher") override val shape = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var initialized = false setHandler(in, new InHandler { override def onPush(): Unit = { @@ -220,4 +223,27 @@ object GraphStages { } override def toString: String = "FutureSource" } + + /** + * INTERNAL API. + * + * Fusing graphs that have cycles involving FanIn stages might lead to deadlocks if + * demand is not carefully managed. + * + * This means that FanIn stages need to early pull every relevant input on startup. + * This can either be implemented inside the stage itself, or this method can be used, + * which adds a detacher stage to every input. + */ + private[stream] def withDetachedInputs[T](stage: GraphStage[UniformFanInShape[T, T]]) = + GraphDSL.create() { implicit builder ⇒ + import GraphDSL.Implicits._ + val concat = builder.add(stage) + val ds = concat.inSeq.map { inlet ⇒ + val detacher = builder.add(GraphStages.detacher[T]) + detacher ~> inlet + detacher.in + } + UniformFanInShape(concat.out, ds: _*) + } + } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 71767d94fb..d5ec09c67f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -3,15 +3,16 @@ */ package akka.stream.scaladsl -import akka.stream.impl.Stages.{ StageModule, SymbolicStage } -import akka.stream.impl._ -import akka.stream.impl.StreamLayout._ import akka.stream._ +import akka.stream.impl._ +import akka.stream.impl.fusing.GraphStages +import akka.stream.impl.fusing.GraphStages.MaterializedValueSource +import akka.stream.impl.Stages.{ StageModule, SymbolicStage } +import akka.stream.impl.StreamLayout._ import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage } import scala.annotation.unchecked.uncheckedVariance import scala.annotation.tailrec import scala.collection.immutable -import akka.stream.impl.fusing.GraphStages.MaterializedValueSource object Merge { /** @@ -159,16 +160,12 @@ final class MergePreferred[T] private (val secondaryPorts: Int, val eagerComplet if (eagerComplete || openInputs == 0) completeStage() } - setHandler(out, new OutHandler { - private var first = true - override def onPull(): Unit = { - if (first) { - first = false - tryPull(preferred) - shape.inSeq.foreach(tryPull) - } - } - }) + override def preStart(): Unit = { + tryPull(preferred) + shape.inSeq.foreach(tryPull) + } + + setHandler(out, eagerTerminateOutput) val pullMe = Array.tabulate(secondaryPorts)(i ⇒ { val port = in(i) @@ -240,8 +237,8 @@ object Interleave { * @param segmentSize number of elements to send downstream before switching to next input port * @param eagerClose if true, interleave completes upstream if any of its upstream completes. */ - def apply[T](inputPorts: Int, segmentSize: Int, eagerClose: Boolean = false): Interleave[T] = - new Interleave(inputPorts, segmentSize, eagerClose) + def apply[T](inputPorts: Int, segmentSize: Int, eagerClose: Boolean = false): Graph[UniformFanInShape[T, T], Unit] = + GraphStages.withDetachedInputs(new Interleave[T](inputPorts, segmentSize, eagerClose)) } /** @@ -644,7 +641,8 @@ object Concat { /** * Create a new `Concat`. */ - def apply[T](inputPorts: Int = 2): Concat[T] = new Concat(inputPorts) + def apply[T](inputPorts: Int = 2): Graph[UniformFanInShape[T, T], Unit] = + GraphStages.withDetachedInputs(new Concat[T](inputPorts)) } /**