diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index 06a0060251..0bc6c15cb5 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -121,10 +121,7 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set.empty) source2.onNext("Meaning of life") - lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")))) - - sink.requestOne() - lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")), RequestOne(source1), RequestOne(source2))) } "implement Broadcast" in new TestSetup { @@ -169,13 +166,11 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set(RequestOne(source))) source.onNext(1) - lastEvents() should ===(Set(OnNext(sink, (1, 1)))) + lastEvents() should ===(Set(OnNext(sink, (1, 1)), RequestOne(source))) sink.requestOne() - lastEvents() should ===(Set(RequestOne(source))) - source.onNext(2) - lastEvents() should ===(Set(OnNext(sink, (2, 2)))) + lastEvents() should ===(Set(OnNext(sink, (2, 2)), RequestOne(source))) } @@ -198,16 +193,15 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should ===(Set.empty) sink1.requestOne() - lastEvents() should ===(Set.empty) + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) sink2.requestOne() - lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) source1.onNext(1) lastEvents() should ===(Set.empty) source2.onNext(2) - lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)))) + lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)), RequestOne(source1), RequestOne(source2))) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala index 061826c017..f9e824431c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowJoinSpec.scala @@ -3,13 +3,18 @@ */ package akka.stream.scaladsl -import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings } +import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings, OverflowStrategy } +import akka.stream.impl.fusing.GraphStages.Detacher import akka.stream.testkit._ +import akka.stream.testkit.Utils._ +import akka.stream.testkit.scaladsl._ import com.typesafe.config.ConfigFactory +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time._ + +import scala.collection.immutable import scala.concurrent.Await import scala.concurrent.duration._ -import akka.stream.OverflowStrategy -import org.scalatest.concurrent.ScalaFutures class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INFO")) with ScalaFutures { @@ -18,8 +23,11 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF implicit val materializer = ActorMaterializer(settings) + implicit val defaultPatience = + PatienceConfig(timeout = Span(2, Seconds), interval = Span(200, Millis)) + "A Flow using join" must { - "allow for cycles" in { + "allow for cycles" in assertAllStagesStopped { val end = 47 val (even, odd) = (0 to end).partition(_ % 2 == 0) val result = Set() ++ even ++ odd ++ odd.map(_ * 10) @@ -51,14 +59,14 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF sub.cancel() } - "propagate one element" in { + "allow for merge cycle" in assertAllStagesStopped { val source = Source.single("lonely traveler") val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b ⇒ sink ⇒ import GraphDSL.Implicits._ val merge = b.add(Merge[String](2)) - val broadcast = b.add(Broadcast[String](2)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) source ~> merge.in(0) merge.out ~> broadcast.in broadcast.out(0) ~> sink @@ -68,5 +76,55 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler") } + + "allow for merge preferred cycle" in assertAllStagesStopped { + val source = Source.single("lonely traveler") + + val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b ⇒ + sink ⇒ + import GraphDSL.Implicits._ + val merge = b.add(MergePreferred[String](1)) + val broadcast = b.add(Broadcast[String](2, eagerCancel = true)) + source ~> merge.preferred + merge.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(merge.in(0), broadcast.out(1)) + }) + + whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler") + } + + "allow for zip cycle" in assertAllStagesStopped { + val source = Source(immutable.Seq("traveler1", "traveler2")) + + val flow = Flow.fromGraph(GraphDSL.create(TestSink.probe[(String, String)]) { implicit b ⇒ + sink ⇒ + import GraphDSL.Implicits._ + val zip = b.add(Zip[String, String]) + val broadcast = b.add(Broadcast[(String, String)](2)) + source ~> zip.in0 + zip.out ~> broadcast.in + broadcast.out(0) ~> sink + + FlowShape(zip.in1, broadcast.out(1)) + }) + + val feedback = Flow.fromGraph(GraphDSL.create(Source.single("ignition")) { implicit b ⇒ + ignition ⇒ + import GraphDSL.Implicits._ + val flow = b.add(Flow[(String, String)].map(_._1)) + val merge = b.add(Merge[String](2)) + + ignition ~> merge.in(0) + flow ~> merge.in(1) + + FlowShape(flow.in, merge.out) + }) + + val probe = flow.join(feedback).run() + probe.requestNext(("traveler1", "ignition")) + probe.requestNext(("traveler2", "traveler1")) + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala index e4987e386c..0b4108984d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphZipSpec.scala @@ -129,9 +129,6 @@ class GraphZipSpec extends TwoStreamsSetup { downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } "complete if one side complete before requested with elements pending" in { @@ -159,9 +156,6 @@ class GraphZipSpec extends TwoStreamsSetup { downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } "complete if one side complete before requested with elements pending 2" in { @@ -190,9 +184,6 @@ class GraphZipSpec extends TwoStreamsSetup { upstream2.sendComplete() downstream.requestNext((1, "A")) downstream.expectComplete() - - upstream1.expectNoMsg(500.millis) - upstream2.expectNoMsg(500.millis) } commonTests() diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template index 4afdec4342..4cad1f31a6 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template @@ -31,13 +31,22 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape ] override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var pending = 1 + var pending = ##0 // Without this field the completion signalling would take one extra pull var willShutDown = false private def pushAll(): Unit = { push(out, zipper([#grab(in0)#])) if (willShutDown) completeStage() + else { + [#pull(in0)# + ] + } + } + + override def preStart(): Unit = { + [#pull(in0)# + ] } [#setHandler(in0, new InHandler { @@ -56,17 +65,13 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape setHandler(out, new OutHandler { override def onPull(): Unit = { - pending = shape.inlets.size - if (willShutDown) completeStage() - else { - [#pull(in0)# - ] - } + pending += shape.inlets.size + if (pending == ##0) pushAll() } }) } - override def toString = "Zip" + override def toString = "ZipWith1" } # diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index ea70aaf273..bb0e4bf5e8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -65,14 +65,16 @@ object GraphStages { def identity[T] = Identity.asInstanceOf[SimpleLinearGraphStage[T]] - private class Detacher[T] extends GraphStage[FlowShape[T, T]] { + /** + * INERNAL API + */ + private[stream] final class Detacher[T] extends GraphStage[FlowShape[T, T]] { val in = Inlet[T]("in") val out = Outlet[T]("out") override def initialAttributes = Attributes.name("Detacher") override val shape = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var initialized = false setHandler(in, new InHandler { override def onPush(): Unit = { diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 71767d94fb..86a09fb05f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -159,16 +159,12 @@ final class MergePreferred[T] private (val secondaryPorts: Int, val eagerComplet if (eagerComplete || openInputs == 0) completeStage() } - setHandler(out, new OutHandler { - private var first = true - override def onPull(): Unit = { - if (first) { - first = false - tryPull(preferred) - shape.inSeq.foreach(tryPull) - } - } - }) + override def preStart(): Unit = { + tryPull(preferred) + shape.inSeq.foreach(tryPull) + } + + setHandler(out, eagerTerminateOutput) val pullMe = Array.tabulate(secondaryPorts)(i ⇒ { val port = in(i)