diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index 2c61840b43..a32ffad351 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -348,7 +348,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. } class TestSetup { - val requests = TestPublisher.manualProbe[HttpRequest] + val requests = TestPublisher.manualProbe[HttpRequest]() val responses = TestSubscriber.manualProbe[HttpResponse] val remoteAddress = new InetSocketAddress("example.com", 80) @@ -357,7 +357,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. val (netOut, netIn) = { val netOut = TestSubscriber.manualProbe[ByteString] - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() FlowGraph.closed(OutgoingConnectionBlueprint(remoteAddress, settings, NoLogging)) { implicit b ⇒ client ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index aa97aeebe9..2a65b2ac00 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -406,7 +406,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) @@ -429,7 +429,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) @@ -453,7 +453,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ChunkStreamPart] + val data = TestPublisher.manualProbe[ChunkStreamPart]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) @@ -477,7 +477,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Connection: close | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala index f901202d6e..9a1b2cb8a8 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala @@ -29,13 +29,13 @@ abstract class HttpServerTestSetupBase { implicit def materializer: FlowMaterializer val requests = TestSubscriber.manualProbe[HttpRequest] - val responses = TestPublisher.manualProbe[HttpResponse] + val responses = TestPublisher.manualProbe[HttpResponse]() def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) def remoteAddress: Option[InetSocketAddress] = None val (netIn, netOut) = { - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() val netOut = TestSubscriber.manualProbe[ByteString] FlowGraph.closed(HttpServerBluePrint(settings, remoteAddress = remoteAddress, log = NoLogging)) { implicit b ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 6f529fd7a3..67804272ce 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -222,7 +222,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a strict message larger than configured maximum frame size" in pending "for a streamed message" in new ServerTestSetup { val data = ByteString("abcdefg", "ASCII") - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) netOutSub.request(6) pushMessage(msg) @@ -245,7 +245,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a streamed message with a chunk being larger than configured maximum frame size" in pending "and mask input on the client side" in new ClientTestSetup { val data = ByteString("abcdefg", "ASCII") - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) netOutSub.request(7) pushMessage(msg) @@ -278,7 +278,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a strict message larger than configured maximum frame size" in pending "for a streamed message" in new ServerTestSetup { val text = "äbcd€fg" - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(6) pushMessage(msg) @@ -310,7 +310,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { println(half1(0).toInt.toHexString) println(half2(0).toInt.toHexString) - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(6) @@ -327,7 +327,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a streamed message with a chunk being larger than configured maximum frame size" in pending "and mask input on the client side" in new ClientTestSetup { val text = "abcdefg" - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(5) pushMessage(msg) @@ -382,7 +382,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { s.request(2) sub.expectNext(ByteString("123", "ASCII")) - val outPub = TestPublisher.manualProbe[ByteString] + val outPub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(outPub)) netOutSub.request(10) pushMessage(msg) @@ -459,7 +459,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { messageIn.expectComplete() // sending another message is allowed before closing (inherently racy) - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -504,7 +504,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { // sending another message is allowed before closing (inherently racy) - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -549,7 +549,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { messageInSub.request(10) // send half a message - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -765,11 +765,11 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { protected def serverSide: Boolean protected def closeTimeout: FiniteDuration = 1.second - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() val netOut = TestSubscriber.manualProbe[ByteString] val messageIn = TestSubscriber.manualProbe[Message] - val messageOut = TestPublisher.manualProbe[Message] + val messageOut = TestPublisher.manualProbe[Message]() val messageHandler: Flow[Message, Message, Unit] = Flow.wrap { diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index 4f488f5fda..bf512ba4f4 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -50,7 +50,7 @@ object TestPublisher { /** * Probe that implements [[org.reactivestreams.Publisher]] interface. */ - def manualProbe[T]()(implicit system: ActorSystem): ManualProbe[T] = new ManualProbe() + def manualProbe[T](autoOnSubscribe: Boolean = true)(implicit system: ActorSystem): ManualProbe[T] = new ManualProbe(autoOnSubscribe) /** * Probe that implements [[org.reactivestreams.Publisher]] interface and tracks demand. @@ -62,7 +62,7 @@ object TestPublisher { * This probe does not track demand. Therefore you need to expect demand before sending * elements downstream. */ - class ManualProbe[I] private[TestPublisher] ()(implicit system: ActorSystem) extends Publisher[I] { + class ManualProbe[I] private[TestPublisher] (autoOnSubscribe: Boolean = true)(implicit system: ActorSystem) extends Publisher[I] { type Self <: ManualProbe[I] @@ -76,7 +76,7 @@ object TestPublisher { def subscribe(subscriber: Subscriber[_ >: I]): Unit = { val subscription: PublisherProbeSubscription[I] = new PublisherProbeSubscription[I](subscriber, probe) probe.ref ! Subscribe(subscription) - subscriber.onSubscribe(subscription) + if (autoOnSubscribe) subscriber.onSubscribe(subscription) } /** @@ -396,6 +396,8 @@ private[testkit] object StreamTestKit { def sendNext(element: I): Unit = subscriber.onNext(element) def sendComplete(): Unit = subscriber.onComplete() def sendError(cause: Exception): Unit = subscriber.onError(cause) + + def sendOnSubscribe(): Unit = subscriber.onSubscribe(this) } final class ProbeSource[T](val attributes: OperationAttributes, shape: SourceShape[T])(implicit system: ActorSystem) extends SourceModule[T, TestPublisher.Probe[T]](shape) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala index 37e3e3ce29..7d3301a94a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala @@ -17,7 +17,7 @@ class ActorInterpreterLifecycleSpec extends AkkaSpec with InterpreterLifecycleSp implicit val mat = ActorFlowMaterializer() class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int] private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") val actor = system.actorOf(props) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala index eb228395ae..91c0f202d2 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala @@ -22,7 +22,7 @@ class ActorInterpreterSpec extends AkkaSpec { implicit val mat = ActorFlowMaterializer() class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int] private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") val actor = system.actorOf(props) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala index f7aaea5f5f..ba5cb29e9a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala @@ -71,6 +71,29 @@ class FlowConcatAllSpec extends AkkaSpec { subUpstream.expectCancellation() } + "on onError on master stream cancel the currently opening substream and signal error" in assertAllStagesStopped { + val publisher = TestPublisher.manualProbe[Source[Int, _]]() + val subscriber = TestSubscriber.manualProbe[Int]() + Source(publisher).flatten(FlattenStrategy.concat).to(Sink(subscriber)).run() + + val upstream = publisher.expectSubscription() + val downstream = subscriber.expectSubscription() + downstream.request(1000) + + val substreamPublisher = TestPublisher.manualProbe[Int](autoOnSubscribe = false) + val substreamFlow = Source(substreamPublisher) + upstream.expectRequest() + upstream.sendNext(substreamFlow) + val subUpstream = substreamPublisher.expectSubscription() + + upstream.sendError(testException) + + subUpstream.sendOnSubscribe() + + subscriber.expectError(testException) + subUpstream.expectCancellation() + } + "on onError on open substream, cancel the master stream and signal error " in assertAllStagesStopped { val publisher = TestPublisher.manualProbe[Source[Int, _]]() val subscriber = TestSubscriber.manualProbe[Int]() @@ -112,6 +135,29 @@ class FlowConcatAllSpec extends AkkaSpec { upstream.expectCancellation() } + "on cancellation cancel the currently opening substream and the master stream" in assertAllStagesStopped { + val publisher = TestPublisher.manualProbe[Source[Int, _]]() + val subscriber = TestSubscriber.manualProbe[Int]() + Source(publisher).flatten(FlattenStrategy.concat).to(Sink(subscriber)).run() + + val upstream = publisher.expectSubscription() + val downstream = subscriber.expectSubscription() + downstream.request(1000) + + val substreamPublisher = TestPublisher.manualProbe[Int](autoOnSubscribe = false) + val substreamFlow = Source(substreamPublisher) + upstream.expectRequest() + upstream.sendNext(substreamFlow) + val subUpstream = substreamPublisher.expectSubscription() + + downstream.cancel() + + subUpstream.sendOnSubscribe() + + subUpstream.expectCancellation() + upstream.expectCancellation() + } + "pass along early cancellation" in assertAllStagesStopped { val up = TestPublisher.manualProbe[Source[Int, _]]() val down = TestSubscriber.manualProbe[Int]() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala index dc190188a1..c65e6d9fe1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala @@ -258,8 +258,8 @@ class FlowGraphCompileSpec extends AkkaSpec { "build with implicits and variance" in { FlowGraph.closed() { implicit b ⇒ - def appleSource = b.add(Source(TestPublisher.manualProbe[Apple])) - def fruitSource = b.add(Source(TestPublisher.manualProbe[Fruit])) + def appleSource = b.add(Source(TestPublisher.manualProbe[Apple]())) + def fruitSource = b.add(Source(TestPublisher.manualProbe[Fruit]())) val outA = b add Sink(TestSubscriber.manualProbe[Fruit]()) val outB = b add Sink(TestSubscriber.manualProbe[Fruit]()) val merge = b add Merge[Fruit](11) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala index 6a7cfa3113..b13d91efea 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala @@ -418,7 +418,7 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug downstream.cancel() onDownstreamFinishProbe.expectMsg("onDownstreamFinish") - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() up.subscribe(s) val upsub = up.expectSubscription() upsub.expectCancellation() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala index b7437831f7..b7d6843413 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala @@ -570,7 +570,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "propagate failure" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -587,7 +587,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -607,7 +607,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure for user thrown exception" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -626,7 +626,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure for user thrown exception in onComplete" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -670,7 +670,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "support finish from onInput" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala index f2a90a8d5d..a4fd75a4b9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala @@ -145,8 +145,8 @@ class GraphMergeSpec extends TwoStreamsSetup { } "pass along early cancellation" in assertAllStagesStopped { - val up1 = TestPublisher.manualProbe[Int] - val up2 = TestPublisher.manualProbe[Int] + val up1 = TestPublisher.manualProbe[Int]() + val up2 = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int]() val src1 = Source.subscriber[Int] diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index 30167ff7b9..695479c4d1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -122,7 +122,7 @@ class SourceSpec extends AkkaSpec { "Composite Source" must { "merge from many inputs" in { - val probes = Seq.fill(5)(TestPublisher.manualProbe[Int]) + val probes = Seq.fill(5)(TestPublisher.manualProbe[Int]()) val source = Source.subscriber[Int] val out = TestSubscriber.manualProbe[Int] diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index 76650e838d..033d818048 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -11,7 +11,6 @@ import akka.pattern.ask import akka.stream.actor.ActorSubscriber import akka.stream.impl.GenJunctions.ZipWithModule import akka.stream.impl.Junctions._ -import akka.stream.impl.MultiStreamInputProcessor.SubstreamSubscriber import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.ActorInterpreter import akka.stream.impl.io.SslTlsCipherActor diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala index 1ee755fd25..ff65edb8d2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala @@ -297,7 +297,7 @@ private[akka] abstract class TwoStreamInputProcessor(_settings: ActorFlowMateria private[akka] object MultiStreamInputProcessor { case class SubstreamKey(id: Long) - class SubstreamSubscriber[T](val impl: ActorRef, key: SubstreamKey) extends Subscriber[T] { + class SubstreamSubscriber[T](val impl: ActorRef, key: SubstreamKey) extends AtomicReference[Subscription] with Subscriber[T] { override def onError(cause: Throwable): Unit = { ReactiveStreamsCompliance.requireNonNullException(cause) impl ! SubstreamOnError(key, cause) @@ -309,7 +309,8 @@ private[akka] object MultiStreamInputProcessor { } override def onSubscribe(subscription: Subscription): Unit = { ReactiveStreamsCompliance.requireNonNullSubscription(subscription) - impl ! SubstreamStreamOnSubscribe(key, subscription) + if (compareAndSet(null, subscription)) impl ! SubstreamStreamOnSubscribe(key, subscription) + else subscription.cancel() } } @@ -346,16 +347,19 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ protected def inputBufferSize: Int private val substreamInputs = collection.mutable.Map.empty[SubstreamKey, SubstreamInput] + private val waitingForOnSubscribe = collection.mutable.Map.empty[SubstreamKey, SubstreamSubscriber[Any]] val inputSubstreamManagement: Receive = { - case SubstreamStreamOnSubscribe(key, subscription) ⇒ substreamInputs(key).substreamOnSubscribe(subscription) - case SubstreamOnNext(key, element) ⇒ substreamInputs(key).substreamOnNext(element) - case SubstreamOnComplete(key) ⇒ { + case SubstreamStreamOnSubscribe(key, subscription) ⇒ + substreamInputs(key).substreamOnSubscribe(subscription) + waitingForOnSubscribe -= key + case SubstreamOnNext(key, element) ⇒ + substreamInputs(key).substreamOnNext(element) + case SubstreamOnComplete(key) ⇒ substreamInputs(key).substreamOnComplete() substreamInputs -= key - } - case SubstreamOnError(key, e) ⇒ substreamInputs(key).substreamOnError(e) - + case SubstreamOnError(key, e) ⇒ + substreamInputs(key).substreamOnError(e) } def createSubstreamInput(): SubstreamInput = { @@ -367,7 +371,9 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ def createAndSubscribeSubstreamInput(p: Publisher[Any]): SubstreamInput = { val inputs = createSubstreamInput() - p.subscribe(new SubstreamSubscriber(self, inputs.key)) + val sub = new SubstreamSubscriber[Any](self, inputs.key) + waitingForOnSubscribe(inputs.key) = sub + p.subscribe(sub) inputs } @@ -378,13 +384,25 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ } protected def failInputs(e: Throwable): Unit = { + cancelWaitingForOnSubscribe() substreamInputs.values foreach (_.cancel()) } protected def finishInputs(): Unit = { + cancelWaitingForOnSubscribe() substreamInputs.values foreach (_.cancel()) } + private def cancelWaitingForOnSubscribe(): Unit = + waitingForOnSubscribe.valuesIterator.foreach { sub ⇒ + sub.getAndSet(CancelledSubscription) match { + case null ⇒ // we were first + case subscription ⇒ + // SubstreamOnSubscribe is still in flight and will not arrive + subscription.cancel() + } + } + } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala index d43ac5e877..fe0a8634a8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala @@ -161,6 +161,11 @@ private[akka] trait Pump { transferState = WaitingForUpstreamSubscription(waitForUpstream, andThen) } + final def waitForUpstreams(waitForUpstream: Int): Unit = { + require(waitForUpstream >= 1, s"waitForUpstream must be >= 1 (was $waitForUpstream)") + transferState = WaitingForUpstreamSubscription(waitForUpstream, TransferPhase(transferState)(currentAction)) + } + def gotUpstreamSubscription(): Unit = { transferState match { case WaitingForUpstreamSubscription(1, andThen) ⇒