diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlowSpec.scala index 08245999d3..173381081f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlowSpec.scala @@ -155,6 +155,23 @@ class GraphFlowSpec extends AkkaSpec { validateProbe(probe, stdRequests, stdResult) } + "work with a Sink when having KeyedSource inside" in { + val out = UndefinedSink[Int] + val probe = StreamTestKit.SubscriberProbe[Int]() + val subSource = Source.subscriber[Int] + + val source = Source[Int]() { implicit b ⇒ + import FlowGraphImplicits._ + subSource ~> out + out + } + + val mm = source.to(Sink(probe)).run() + source1.to(Sink(mm.get(subSource))).run() + + validateProbe(probe, 4, (0 to 3).toSet) + } + "be transformable with a Pipe" in { val out = UndefinedSink[String] @@ -240,6 +257,23 @@ class GraphFlowSpec extends AkkaSpec { validateProbe(probe, stdRequests, stdResult) } + "work with a Source when having KeyedSink inside" in { + val in = UndefinedSource[Int] + val probe = StreamTestKit.SubscriberProbe[Int]() + val pubSink = Sink.publisher[Int] + + val sink = Sink[Int]() { implicit b ⇒ + import FlowGraphImplicits._ + in ~> pubSink + in + } + + val mm = source1.to(sink).run() + Source(mm.get(pubSink)).to(Sink(probe)).run() + + validateProbe(probe, 4, (0 to 3).toSet) + } + "be transformable with a Pipe" in { val in = UndefinedSource[String] diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala index 811460c940..00d6a66a17 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/FlowGraph.scala @@ -780,6 +780,18 @@ class FlowGraphBuilder private[akka] (_graph: DirectedGraphBuilder[FlowGraphInte uncheckedAddGraphEdge(from, to, pipe, inputPort, outputPort) } + private def addOrReplaceSinkEdge[In, Out](from: Vertex, to: Vertex, pipe: Pipe[In, Out], inputPort: Int, outputPort: Int): Unit = { + checkAddOrReplaceSourceSinkPrecondition(from) + checkAddSourceSinkPrecondition(to) + uncheckedAddGraphEdge(from, to, pipe, inputPort, outputPort) + } + + private def addOrReplaceSourceEdge[In, Out](from: Vertex, to: Vertex, pipe: Pipe[In, Out], inputPort: Int, outputPort: Int): Unit = { + checkAddSourceSinkPrecondition(from) + checkAddOrReplaceSourceSinkPrecondition(to) + uncheckedAddGraphEdge(from, to, pipe, inputPort, outputPort) + } + def attachSink[Out](token: UndefinedSink[Out], sink: Sink[Out]): this.type = { graph.find(token) match { case Some(existing) ⇒ @@ -788,11 +800,11 @@ class FlowGraphBuilder private[akka] (_graph: DirectedGraphBuilder[FlowGraphInte sink match { case spipe: SinkPipe[Out] ⇒ val pipe = edge.label.pipe.appendPipe(Pipe(spipe.ops)) - addGraphEdge(edge.from.label, SinkVertex(spipe.output), pipe, edge.label.inputPort, edge.label.outputPort) + addOrReplaceSinkEdge(edge.from.label, SinkVertex(spipe.output), pipe, edge.label.inputPort, edge.label.outputPort) case gsink: GraphSink[Out, _] ⇒ gsink.importAndConnect(this, token) case sink: Sink[Out] ⇒ - addGraphEdge(edge.from.label, SinkVertex(sink), edge.label.pipe, edge.label.inputPort, edge.label.outputPort) + addOrReplaceSinkEdge(edge.from.label, SinkVertex(sink), edge.label.pipe, edge.label.inputPort, edge.label.outputPort) } case None ⇒ throw new IllegalArgumentException(s"No matching UndefinedSink [${token}]") @@ -808,11 +820,11 @@ class FlowGraphBuilder private[akka] (_graph: DirectedGraphBuilder[FlowGraphInte source match { case spipe: SourcePipe[In] ⇒ val pipe = Pipe(spipe.ops).appendPipe(edge.label.pipe) - addGraphEdge(SourceVertex(spipe.input), edge.to.label, pipe, edge.label.inputPort, edge.label.outputPort) + addOrReplaceSourceEdge(SourceVertex(spipe.input), edge.to.label, pipe, edge.label.inputPort, edge.label.outputPort) case gsource: GraphSource[_, In] ⇒ gsource.importAndConnect(this, token) case source: Source[In] ⇒ - addGraphEdge(SourceVertex(source), edge.to.label, edge.label.pipe, edge.label.inputPort, edge.label.outputPort) + addOrReplaceSourceEdge(SourceVertex(source), edge.to.label, edge.label.pipe, edge.label.inputPort, edge.label.outputPort) case x ⇒ throwUnsupportedValue(x) }