diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index f258974095..51f36689f8 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -348,7 +348,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. } class TestSetup { - val requests = TestPublisher.manualProbe[HttpRequest] + val requests = TestPublisher.manualProbe[HttpRequest]() val responses = TestSubscriber.manualProbe[HttpResponse] val remoteAddress = new InetSocketAddress("example.com", 0) @@ -357,7 +357,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. val (netOut, netIn) = { val netOut = TestSubscriber.manualProbe[ByteString] - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() FlowGraph.closed(OutgoingConnectionBlueprint(Host(remoteAddress), settings, NoLogging)) { implicit b ⇒ client ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index aa97aeebe9..2a65b2ac00 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -406,7 +406,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) @@ -429,7 +429,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) @@ -453,7 +453,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ChunkStreamPart] + val data = TestPublisher.manualProbe[ChunkStreamPart]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) @@ -477,7 +477,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Connection: close | |""".stripMarginWithNewline("\r\n")) - val data = TestPublisher.manualProbe[ByteString] + val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala index f901202d6e..9a1b2cb8a8 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala @@ -29,13 +29,13 @@ abstract class HttpServerTestSetupBase { implicit def materializer: FlowMaterializer val requests = TestSubscriber.manualProbe[HttpRequest] - val responses = TestPublisher.manualProbe[HttpResponse] + val responses = TestPublisher.manualProbe[HttpResponse]() def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) def remoteAddress: Option[InetSocketAddress] = None val (netIn, netOut) = { - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() val netOut = TestSubscriber.manualProbe[ByteString] FlowGraph.closed(HttpServerBluePrint(settings, remoteAddress = remoteAddress, log = NoLogging)) { implicit b ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 6f529fd7a3..67804272ce 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -222,7 +222,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a strict message larger than configured maximum frame size" in pending "for a streamed message" in new ServerTestSetup { val data = ByteString("abcdefg", "ASCII") - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) netOutSub.request(6) pushMessage(msg) @@ -245,7 +245,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a streamed message with a chunk being larger than configured maximum frame size" in pending "and mask input on the client side" in new ClientTestSetup { val data = ByteString("abcdefg", "ASCII") - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) netOutSub.request(7) pushMessage(msg) @@ -278,7 +278,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a strict message larger than configured maximum frame size" in pending "for a streamed message" in new ServerTestSetup { val text = "äbcd€fg" - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(6) pushMessage(msg) @@ -310,7 +310,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { println(half1(0).toInt.toHexString) println(half2(0).toInt.toHexString) - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(6) @@ -327,7 +327,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { "for a streamed message with a chunk being larger than configured maximum frame size" in pending "and mask input on the client side" in new ClientTestSetup { val text = "abcdefg" - val pub = TestPublisher.manualProbe[String] + val pub = TestPublisher.manualProbe[String]() val msg = TextMessage.Streamed(Source(pub)) netOutSub.request(5) pushMessage(msg) @@ -382,7 +382,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { s.request(2) sub.expectNext(ByteString("123", "ASCII")) - val outPub = TestPublisher.manualProbe[ByteString] + val outPub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(outPub)) netOutSub.request(10) pushMessage(msg) @@ -459,7 +459,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { messageIn.expectComplete() // sending another message is allowed before closing (inherently racy) - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -504,7 +504,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { // sending another message is allowed before closing (inherently racy) - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -549,7 +549,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { messageInSub.request(10) // send half a message - val pub = TestPublisher.manualProbe[ByteString] + val pub = TestPublisher.manualProbe[ByteString]() val msg = BinaryMessage.Streamed(Source(pub)) pushMessage(msg) expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) @@ -765,11 +765,11 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { protected def serverSide: Boolean protected def closeTimeout: FiniteDuration = 1.second - val netIn = TestPublisher.manualProbe[ByteString] + val netIn = TestPublisher.manualProbe[ByteString]() val netOut = TestSubscriber.manualProbe[ByteString] val messageIn = TestSubscriber.manualProbe[Message] - val messageOut = TestPublisher.manualProbe[Message] + val messageOut = TestPublisher.manualProbe[Message]() val messageHandler: Flow[Message, Message, Unit] = Flow.wrap { diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala index b8c1b62b57..f6783982a1 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala @@ -30,7 +30,10 @@ abstract class AkkaIdentityProcessorVerification[T](env: TestEnvironment, publis def processorFromFlow(flow: Flow[T, T, _])(implicit mat: ActorFlowMaterializer): Processor[T, T] = { val (sub: Subscriber[T], pub: Publisher[T]) = flow.runWith(Source.subscriber[T], Sink.publisher[T]) + processorFromSubscriberAndPublisher(sub, pub) + } + def processorFromSubscriberAndPublisher(sub: Subscriber[T], pub: Publisher[T]): Processor[T, T] = { new Processor[T, T] { override def onSubscribe(s: Subscription): Unit = sub.onSubscribe(s) override def onError(t: Throwable): Unit = sub.onError(t) diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala new file mode 100644 index 0000000000..2bf5dad809 --- /dev/null +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala @@ -0,0 +1,34 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.tck + +import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl.Flow +import org.reactivestreams.Processor +import akka.stream.impl.VirtualProcessor + +class VirtualProcessorTest extends AkkaIdentityProcessorVerification[Int] { + + override def createIdentityProcessor(maxBufferSize: Int): Processor[Int, Int] = { + implicit val materializer = ActorFlowMaterializer()(system) + + val identity = processorFromFlow(Flow[Int].map(elem ⇒ elem).named("identity")) + val left, right = new VirtualProcessor[Int] + left.subscribe(identity) + identity.subscribe(right) + processorFromSubscriberAndPublisher(left, right) + } + + override def createElement(element: Int): Int = element + +} + +class VirtualProcessorSingleTest extends AkkaIdentityProcessorVerification[Int] { + + override def createIdentityProcessor(maxBufferSize: Int): Processor[Int, Int] = + new VirtualProcessor[Int] + + override def createElement(element: Int): Int = element + +} diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index 4f488f5fda..756c1e836d 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -50,7 +50,7 @@ object TestPublisher { /** * Probe that implements [[org.reactivestreams.Publisher]] interface. */ - def manualProbe[T]()(implicit system: ActorSystem): ManualProbe[T] = new ManualProbe() + def manualProbe[T](autoOnSubscribe: Boolean = true)(implicit system: ActorSystem): ManualProbe[T] = new ManualProbe(autoOnSubscribe) /** * Probe that implements [[org.reactivestreams.Publisher]] interface and tracks demand. @@ -62,7 +62,7 @@ object TestPublisher { * This probe does not track demand. Therefore you need to expect demand before sending * elements downstream. */ - class ManualProbe[I] private[TestPublisher] ()(implicit system: ActorSystem) extends Publisher[I] { + class ManualProbe[I] private[TestPublisher] (autoOnSubscribe: Boolean = true)(implicit system: ActorSystem) extends Publisher[I] { type Self <: ManualProbe[I] @@ -76,7 +76,7 @@ object TestPublisher { def subscribe(subscriber: Subscriber[_ >: I]): Unit = { val subscription: PublisherProbeSubscription[I] = new PublisherProbeSubscription[I](subscriber, probe) probe.ref ! Subscribe(subscription) - subscriber.onSubscribe(subscription) + if (autoOnSubscribe) subscriber.onSubscribe(subscription) } /** @@ -315,6 +315,12 @@ object TestSubscriber { def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(f: PartialFunction[SubscriberEvent, T]): immutable.Seq[T] = probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) + def receiveWithin(max: FiniteDuration, messages: Int = Int.MaxValue): immutable.Seq[I] = + probe.receiveWhile(max, max, messages) { + case OnNext(i) ⇒ Some(i.asInstanceOf[I]) + case _ ⇒ None + }.flatten + def within[T](max: FiniteDuration)(f: ⇒ T): T = probe.within(0.seconds, max)(f) def onSubscribe(subscription: Subscription): Unit = probe.ref ! OnSubscribe(subscription) @@ -396,6 +402,8 @@ private[testkit] object StreamTestKit { def sendNext(element: I): Unit = subscriber.onNext(element) def sendComplete(): Unit = subscriber.onComplete() def sendError(cause: Exception): Unit = subscriber.onError(cause) + + def sendOnSubscribe(): Unit = subscriber.onSubscribe(this) } final class ProbeSource[T](val attributes: OperationAttributes, shape: SourceShape[T])(implicit system: ActorSystem) extends SourceModule[T, TestPublisher.Probe[T]](shape) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala index 37e3e3ce29..7d3301a94a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala @@ -17,7 +17,7 @@ class ActorInterpreterLifecycleSpec extends AkkaSpec with InterpreterLifecycleSp implicit val mat = ActorFlowMaterializer() class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int] private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") val actor = system.actorOf(props) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala index 8166f2c29d..7352be1a2c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala @@ -24,7 +24,7 @@ class ActorInterpreterSpec extends AkkaSpec { implicit val mat = ActorFlowMaterializer() class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int] private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") val actor = system.actorOf(props) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala index 9fd0f0fe09..e645e1f7d3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala @@ -186,7 +186,6 @@ class StreamLayoutSpec extends AkkaSpec { assignPort(outPort, publisher) } } - override protected def createIdentityProcessor: Processor[Any, Any] = null // Not used in test } def checkMaterialized(topLevel: Module): (Set[TestPublisher], Set[TestSubscriber]) = { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala index f7aaea5f5f..ba5cb29e9a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala @@ -71,6 +71,29 @@ class FlowConcatAllSpec extends AkkaSpec { subUpstream.expectCancellation() } + "on onError on master stream cancel the currently opening substream and signal error" in assertAllStagesStopped { + val publisher = TestPublisher.manualProbe[Source[Int, _]]() + val subscriber = TestSubscriber.manualProbe[Int]() + Source(publisher).flatten(FlattenStrategy.concat).to(Sink(subscriber)).run() + + val upstream = publisher.expectSubscription() + val downstream = subscriber.expectSubscription() + downstream.request(1000) + + val substreamPublisher = TestPublisher.manualProbe[Int](autoOnSubscribe = false) + val substreamFlow = Source(substreamPublisher) + upstream.expectRequest() + upstream.sendNext(substreamFlow) + val subUpstream = substreamPublisher.expectSubscription() + + upstream.sendError(testException) + + subUpstream.sendOnSubscribe() + + subscriber.expectError(testException) + subUpstream.expectCancellation() + } + "on onError on open substream, cancel the master stream and signal error " in assertAllStagesStopped { val publisher = TestPublisher.manualProbe[Source[Int, _]]() val subscriber = TestSubscriber.manualProbe[Int]() @@ -112,6 +135,29 @@ class FlowConcatAllSpec extends AkkaSpec { upstream.expectCancellation() } + "on cancellation cancel the currently opening substream and the master stream" in assertAllStagesStopped { + val publisher = TestPublisher.manualProbe[Source[Int, _]]() + val subscriber = TestSubscriber.manualProbe[Int]() + Source(publisher).flatten(FlattenStrategy.concat).to(Sink(subscriber)).run() + + val upstream = publisher.expectSubscription() + val downstream = subscriber.expectSubscription() + downstream.request(1000) + + val substreamPublisher = TestPublisher.manualProbe[Int](autoOnSubscribe = false) + val substreamFlow = Source(substreamPublisher) + upstream.expectRequest() + upstream.sendNext(substreamFlow) + val subUpstream = substreamPublisher.expectSubscription() + + downstream.cancel() + + subUpstream.sendOnSubscribe() + + subUpstream.expectCancellation() + upstream.expectCancellation() + } + "pass along early cancellation" in assertAllStagesStopped { val up = TestPublisher.manualProbe[Source[Int, _]]() val down = TestSubscriber.manualProbe[Int]() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala index dc190188a1..c65e6d9fe1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala @@ -258,8 +258,8 @@ class FlowGraphCompileSpec extends AkkaSpec { "build with implicits and variance" in { FlowGraph.closed() { implicit b ⇒ - def appleSource = b.add(Source(TestPublisher.manualProbe[Apple])) - def fruitSource = b.add(Source(TestPublisher.manualProbe[Fruit])) + def appleSource = b.add(Source(TestPublisher.manualProbe[Apple]())) + def fruitSource = b.add(Source(TestPublisher.manualProbe[Fruit]())) val outA = b add Sink(TestSubscriber.manualProbe[Fruit]()) val outB = b add Sink(TestSubscriber.manualProbe[Fruit]()) val merge = b add Merge[Fruit](11) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala index 6a7cfa3113..c39eb36068 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala @@ -15,9 +15,10 @@ import akka.stream.testkit._ import akka.stream.testkit.Utils._ import akka.testkit.{ EventFilter, TestProbe } import com.typesafe.config.ConfigFactory - import scala.concurrent.duration._ import scala.util.control.NoStackTrace +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.testkit.scaladsl.TestSource class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.receive=off\nakka.loglevel=INFO")) { @@ -222,8 +223,7 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } "allow early finish" in assertAllStagesStopped { - val p = TestPublisher.manualProbe[Int]() - val p2 = Source(p). + val (p1, p2) = TestSource.probe[Int]. transform(() ⇒ new PushStage[Int, Int] { var s = "" override def onPush(element: Int, ctx: Context[Int]) = { @@ -233,18 +233,14 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug else ctx.push(element) } - }). - runWith(Sink.publisher) - val proc = p.expectSubscription - val c = TestSubscriber.manualProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - proc.sendNext(1) - proc.sendNext(2) - c.expectNext(1) - c.expectComplete() - proc.expectCancellation() + }) + .toMat(TestSink.probe[Int])(Keep.both).run + p2.request(10) + p1.sendNext(1) + .sendNext(2) + p2.expectNext(1) + .expectComplete() + p1.expectCancellation() } "report error when exception is thrown" in assertAllStagesStopped { @@ -261,16 +257,13 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } } }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() + runWith(TestSink.probe[Int]) EventFilter[IllegalArgumentException]("two not allowed") intercept { - subscription.request(100) - subscriber.expectNext(1) - subscriber.expectNext(1) - subscriber.expectError().getMessage should be("two not allowed") - subscriber.expectNoMsg(200.millis) + p2.request(100) + .expectNext(1) + .expectNext(1) + .expectError().getMessage should be("two not allowed") + p2.expectNoMsg(200.millis) } } @@ -288,65 +281,56 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } }). filter(elem ⇒ elem != 1). // it's undefined if element 1 got through before the error or not - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() + runWith(TestSink.probe[Int]) EventFilter[IllegalArgumentException]("two not allowed") intercept { - subscription.request(100) - subscriber.expectNext(100) - subscriber.expectNext(101) - subscriber.expectComplete() - subscriber.expectNoMsg(200.millis) + p2.request(100) + .expectNext(100) + .expectNext(101) + .expectComplete() + .expectNoMsg(200.millis) } } "support cancel as expected" in assertAllStagesStopped { - val p = Source(List(1, 2, 3)).runWith(Sink.publisher) - val p2 = Source(p). + val p = Source(1 to 100).runWith(Sink.publisher) + val received = Source(p). transform(() ⇒ new StatefulStage[Int, Int] { override def initial = new State { override def onPush(elem: Int, ctx: Context[Int]) = emit(Iterator(elem, elem), ctx) } - }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(2) - subscriber.expectNext(1) - subscription.cancel() - subscriber.expectNext(1) - subscriber.expectNoMsg(500.millis) - subscription.request(2) - subscriber.expectNoMsg(200.millis) + }) + .runWith(TestSink.probe[Int]()) + .request(1000) + .expectNext(1) + .cancel() + .receiveWithin(1.second) + received.size should be < 200 + received.foldLeft((true, 1)) { + case ((flag, last), next) ⇒ (flag && (last == next || last == next - 1), next) + }._1 should be(true) } "support producing elements from empty inputs" in assertAllStagesStopped { val p = Source(List.empty[Int]).runWith(Sink.publisher) - val p2 = Source(p). + Source(p). transform(() ⇒ new StatefulStage[Int, Int] { override def initial = new State { override def onPush(elem: Int, ctx: Context[Int]) = ctx.pull() } override def onUpstreamFinish(ctx: Context[Int]) = terminationEmit(Iterator(1, 2, 3), ctx) - }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(4) - subscriber.expectNext(1) - subscriber.expectNext(2) - subscriber.expectNext(3) - subscriber.expectComplete() + }) + .runWith(TestSink.probe[Int]) + .request(4) + .expectNext(1) + .expectNext(2) + .expectNext(3) + .expectComplete() } "support converting onComplete into onError" in { - val subscriber = TestSubscriber.manualProbe[Int]() Source(List(5, 1, 2, 3)).transform(() ⇒ new PushStage[Int, Int] { var expectedNumberOfElements: Option[Int] = None var count = 0 @@ -365,15 +349,12 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug throw new RuntimeException(s"Expected $expected, got $count") with NoStackTrace case _ ⇒ ctx.finish() } - }).to(Sink(subscriber)).run() - - val subscription = subscriber.expectSubscription() - subscription.request(10) - - subscriber.expectNext(1) - subscriber.expectNext(2) - subscriber.expectNext(3) - subscriber.expectError().getMessage should be("Expected 5, got 3") + }).runWith(TestSink.probe[Int]) + .request(10) + .expectNext(1) + .expectNext(2) + .expectNext(3) + .expectError().getMessage should be("Expected 5, got 3") } "be safe to reuse" in { @@ -387,17 +368,15 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } }) - val s1 = TestSubscriber.manualProbe[Int]() - flow.to(Sink(s1)).run() - s1.expectSubscription().request(3) - s1.expectNext(1, 2, 3) - s1.expectComplete() + flow.runWith(TestSink.probe[Int]) + .request(3) + .expectNext(1, 2, 3) + .expectComplete() - val s2 = TestSubscriber.manualProbe[Int]() - flow.to(Sink(s2)).run() - s2.expectSubscription().request(3) - s2.expectNext(1, 2, 3) - s2.expectComplete() + flow.runWith(TestSink.probe[Int]) + .request(3) + .expectNext(1, 2, 3) + .expectComplete() } "handle early cancelation" in assertAllStagesStopped { @@ -418,7 +397,7 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug downstream.cancel() onDownstreamFinishProbe.expectMsg("onDownstreamFinish") - val up = TestPublisher.manualProbe[Int] + val up = TestPublisher.manualProbe[Int]() up.subscribe(s) val upsub = up.expectSubscription() upsub.expectCancellation() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala index b7437831f7..b7d6843413 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiMergeSpec.scala @@ -570,7 +570,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "propagate failure" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -587,7 +587,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -607,7 +607,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure for user thrown exception" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -626,7 +626,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "emit failure for user thrown exception in onComplete" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ @@ -670,7 +670,7 @@ class GraphFlexiMergeSpec extends AkkaSpec { } "support finish from onInput" in assertAllStagesStopped { - val publisher = TestPublisher.manualProbe[String] + val publisher = TestPublisher.manualProbe[String]() val completionProbe = TestProbe() val p = FlowGraph.closed(out) { implicit b ⇒ o ⇒ diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala index f2a90a8d5d..a4fd75a4b9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSpec.scala @@ -145,8 +145,8 @@ class GraphMergeSpec extends TwoStreamsSetup { } "pass along early cancellation" in assertAllStagesStopped { - val up1 = TestPublisher.manualProbe[Int] - val up2 = TestPublisher.manualProbe[Int] + val up1 = TestPublisher.manualProbe[Int]() + val up2 = TestPublisher.manualProbe[Int]() val down = TestSubscriber.manualProbe[Int]() val src1 = Source.subscriber[Int] diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala index 1c3c49dd7b..83eb079ad7 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala @@ -43,6 +43,13 @@ class PublisherSinkSpec extends AkkaSpec { Source(1 to 100).to(Sink(sub)).run() Await.result(Source(pub).grouped(1000).runWith(Sink.head), 3.seconds) should ===(1 to 100) } + + "be able to use Publisher in materialized value transformation" in { + val f = Source(1 to 3).runWith( + Sink.publisher[Int].mapMaterializedValue(p ⇒ Source(p).runFold(0)(_ + _))) + + Await.result(f, 3.seconds) should be(6) + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index 30167ff7b9..695479c4d1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -122,7 +122,7 @@ class SourceSpec extends AkkaSpec { "Composite Source" must { "merge from many inputs" in { - val probes = Seq.fill(5)(TestPublisher.manualProbe[Int]) + val probes = Seq.fill(5)(TestPublisher.manualProbe[Int]()) val source = Source.subscriber[Int] val out = TestSubscriber.manualProbe[Int] diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala new file mode 100644 index 0000000000..33d3dde68c --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala @@ -0,0 +1,29 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream.ActorFlowMaterializer + +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.Utils._ +import scala.concurrent.duration._ + +import scala.concurrent.Await + +class SubscriberSourceSpec extends AkkaSpec("akka.loglevel=DEBUG\nakka.actor.debug.lifecycle=on") { + + implicit val materializer = ActorFlowMaterializer() + + "A SubscriberSource" must { + + "be able to use Subscriber in materialized value transformation" in { + val f = + Source.subscriber[Int].mapMaterializedValue(s ⇒ Source(1 to 3).runWith(Sink(s))) + .runWith(Sink.fold[Int, Int](0)(_ + _)) + + Await.result(f, 3.seconds) should be(6) + } + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index baecbadc69..91056e6908 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -11,7 +11,6 @@ import akka.pattern.ask import akka.stream.actor.ActorSubscriber import akka.stream.impl.GenJunctions.ZipWithModule import akka.stream.impl.Junctions._ -import akka.stream.impl.MultiStreamInputProcessor.SubstreamSubscriber import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.ActorInterpreter import akka.stream.impl.io.SslTlsCipherActor @@ -111,13 +110,11 @@ private[akka] case class ActorFlowMaterializerImpl( } } - override protected def createIdentityProcessor: Processor[Any, Any] = - processorFor(Identity(OperationAttributes.none), OperationAttributes.none, settings)._1 - private def processorFor(op: StageModule, effectiveAttributes: OperationAttributes, effectiveSettings: ActorFlowMaterializerSettings): (Processor[Any, Any], Any) = op match { case DirectProcessor(processorFactory, _) ⇒ processorFactory() + case Identity(attr) ⇒ (new VirtualProcessor, ()) case _ ⇒ val (opprops, mat) = ActorProcessorFactory.props(ActorFlowMaterializerImpl.this, op, effectiveAttributes) val processor = ActorProcessorFactory[Any, Any](actorOf( @@ -294,7 +291,7 @@ private[akka] object ActorProcessorFactory { // Also, otherwise the attributes will not affect the settings properly! val settings = materializer.effectiveSettings(att) op match { - case Identity(_) ⇒ (ActorInterpreter.props(settings, List(fusing.Map(_identity, settings.supervisionDecider)), materializer, att), ()) + case Identity(_) ⇒ throw new AssertionError("Identity cannot end up in ActorProcessorFactory") case Fused(ops, _) ⇒ (ActorInterpreter.props(settings, ops, materializer, att), ()) case Map(f, _) ⇒ (ActorInterpreter.props(settings, List(fusing.Map(f, settings.supervisionDecider)), materializer, att), ()) case Filter(p, _) ⇒ (ActorInterpreter.props(settings, List(fusing.Filter(p, settings.supervisionDecider)), materializer, att), ()) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala index 3414d375fe..48c1f4c22c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala @@ -55,7 +55,7 @@ private[akka] abstract class SourceModule[+Out, +Mat](val shape: SourceShape[Out private[akka] final class SubscriberSource[Out](val attributes: OperationAttributes, shape: SourceShape[Out]) extends SourceModule[Out, Subscriber[Out]](shape) { override def create(context: MaterializationContext): (Publisher[Out], Subscriber[Out]) = { - val processor = new SubscriberSourceVirtualProcessor[Out] + val processor = new VirtualProcessor[Out] (processor, processor) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index c1ce5cfe27..51d7a5def9 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -60,9 +60,8 @@ private[akka] class PublisherSink[In](val attributes: OperationAttributes, shape override def toString: String = "PublisherSink" override def create(context: MaterializationContext): (Subscriber[In], Publisher[In]) = { - val pub = new PublisherSinkVirtualPublisher[In] - val sub = new PublisherSinkVirtualSubscriber[In](pub) - (sub, pub) + val proc = new VirtualProcessor[In] + (proc, proc) } override protected def newInstance(shape: SinkShape[In]): SinkModule[In, Publisher[In]] = new PublisherSink[In](attributes, shape) diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index a8967360d9..473ef4bce4 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -4,7 +4,6 @@ package akka.stream.impl import java.util.concurrent.atomic.{ AtomicInteger, AtomicBoolean, AtomicReference } - import akka.stream.impl.StreamLayout.Module import akka.stream.scaladsl.Keep import akka.stream._ @@ -12,6 +11,8 @@ import org.reactivestreams.{ Processor, Subscription, Publisher, Subscriber } import scala.collection.mutable import scala.util.control.NonFatal import akka.event.Logging.simpleName +import scala.annotation.tailrec +import java.util.concurrent.atomic.AtomicLong /** * INTERNAL API @@ -106,9 +107,8 @@ private[akka] object StreamLayout { AmorphousShape(shape.inlets ++ that.shape.inlets, shape.outlets ++ that.shape.outlets), downstreams ++ that.downstreams, upstreams ++ that.upstreams, - if (f eq Keep.left) matComputation1 - else if (f eq Keep.right) matComputation2 - else Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), + // would like to optimize away this allocation for Keep.{left,right} but that breaks side-effecting transformations + Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), attributes) } @@ -293,38 +293,108 @@ private[akka] object StreamLayout { } } -private[stream] final class SubscriberSourceVirtualProcessor[T] extends Processor[T, T] { - @volatile private var subscriber: Subscriber[_ >: T] = null +private[stream] object VirtualProcessor { + sealed trait Termination + case object Allowed extends Termination + case object Completed extends Termination + case class Failed(ex: Throwable) extends Termination - override def subscribe(s: Subscriber[_ >: T]): Unit = subscriber = s - - override def onError(t: Throwable): Unit = subscriber.onError(t) - override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) - override def onComplete(): Unit = subscriber.onComplete() - override def onNext(t: T): Unit = subscriber.onNext(t) + private object InertSubscriber extends Subscriber[Any] { + override def onSubscribe(s: Subscription): Unit = s.cancel() + override def onNext(elem: Any): Unit = () + override def onError(thr: Throwable): Unit = () + override def onComplete(): Unit = () + } } -/** - * INTERNAL API - */ -private[stream] final class PublisherSinkVirtualSubscriber[T](val owner: PublisherSinkVirtualPublisher[T]) extends Subscriber[T] { - override def onSubscribe(s: Subscription): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onError(t: Throwable): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onComplete(): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onNext(t: T): Unit = throw new UnsupportedOperationException("This method should not be called") -} +private[stream] final class VirtualProcessor[T] extends Processor[T, T] { + import VirtualProcessor._ + import ReactiveStreamsCompliance._ + + private val subscriptionStatus = new AtomicReference[AnyRef] + private val terminationStatus = new AtomicReference[Termination] -/** - * INTERNAL API - */ -private[stream] final class PublisherSinkVirtualPublisher[T]() extends Publisher[T] { - @volatile var realPublisher: Publisher[T] = null override def subscribe(s: Subscriber[_ >: T]): Unit = { - val sub = realPublisher.subscribe(s) - // unreference the realPublisher to facilitate GC and - // Sink.publisher is supposed to reject additional subscribers anyway - realPublisher = RejectAdditionalSubscribers[T] - sub + requireNonNullSubscriber(s) + if (subscriptionStatus.compareAndSet(null, s)) () // wait for onSubscribe + else + subscriptionStatus.get match { + case sub: Subscriber[_] ⇒ rejectAdditionalSubscriber(s, "VirtualProcessor") + case sub: Sub ⇒ + try { + subscriptionStatus.set(s) + tryOnSubscribe(s, sub) + sub.closeLatch() // allow onNext only now + terminationStatus.getAndSet(Allowed) match { + case null ⇒ // nothing happened yet + case Completed ⇒ tryOnComplete(s) + case Failed(ex) ⇒ tryOnError(s, ex) + case Allowed ⇒ // all good + } + } catch { + case NonFatal(ex) ⇒ sub.cancel() + } + } + } + + override def onSubscribe(s: Subscription): Unit = { + requireNonNullSubscription(s) + val wrapped = new Sub(s) + if (subscriptionStatus.compareAndSet(null, wrapped)) () // wait for Subscriber + else + subscriptionStatus.get match { + case sub: Subscriber[_] ⇒ + terminationStatus.get match { + case Allowed ⇒ + /* + * There is a race condition here: if this thread reads the subscriptionStatus after + * set set() in subscribe() but then sees the terminationStatus before the getAndSet() + * is published then we will rely upon the downstream Subscriber for cancelling this + * Subscription. I only mention this because the TCK requires that we handle this here + * (since the manualSubscriber used there does not expose this behavior). + */ + s.cancel() + case _ ⇒ + tryOnSubscribe(sub, wrapped) + wrapped.closeLatch() // allow onNext only now + terminationStatus.set(Allowed) + } + case sub: Subscription ⇒ + s.cancel() // reject further Subscriptions + } + } + + override def onError(t: Throwable): Unit = { + requireNonNullException(t) + if (terminationStatus.compareAndSet(null, Failed(t))) () // let it be picked up by subscribe() + else tryOnError(subscriptionStatus.get.asInstanceOf[Subscriber[T]], t) + } + + override def onComplete(): Unit = + if (terminationStatus.compareAndSet(null, Completed)) () // let it be picked up by subscribe() + else tryOnComplete(subscriptionStatus.get.asInstanceOf[Subscriber[T]]) + + override def onNext(t: T): Unit = { + requireNonNullElement(t) + tryOnNext(subscriptionStatus.get.asInstanceOf[Subscriber[T]], t) + } + + private final class Sub(s: Subscription) extends AtomicLong with Subscription { + override def cancel(): Unit = { + subscriptionStatus.set(InertSubscriber) + s.cancel() + } + @tailrec + override def request(n: Long): Unit = { + val current = get + if (current < 0) s.request(n) + else if (compareAndSet(current, current + n)) () + else request(n) + } + def closeLatch(): Unit = { + val requested = getAndSet(-1) + if (requested > 0) s.request(requested) + } } } @@ -550,6 +620,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo case mv: MaterializedValueSource[_] ⇒ val pub = new MaterializedValuePublisher materializedValuePublishers ::= pub + materializedValues.put(mv, ()) assignPort(mv.shape.outlet, pub) case atomic if atomic.isAtomic ⇒ materializedValues.put(atomic, materializeAtomic(atomic, subEffectiveAttributes)) @@ -573,8 +644,6 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo protected def materializeAtomic(atomic: Module, effectiveAttributes: OperationAttributes): Any - protected def createIdentityProcessor: Processor[Any, Any] - private def resolveMaterialized(matNode: MaterializedValueNode, materializedValues: collection.Map[Module, Any]): Any = matNode match { case Atomic(m) ⇒ materializedValues(m) case Combine(f, d1, d2) ⇒ f(resolveMaterialized(d1, materializedValues), resolveMaterialized(d2, materializedValues)) @@ -582,24 +651,12 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo case Ignore ⇒ () } - private def attach(p: Publisher[Any], s: Subscriber[Any]) = s match { - case v: PublisherSinkVirtualSubscriber[Any] ⇒ - if (p.isInstanceOf[SubscriberSourceVirtualProcessor[Any]]) { - val injectedProcessor = createIdentityProcessor - v.owner.realPublisher = injectedProcessor - p.subscribe(injectedProcessor) - } else - v.owner.realPublisher = p - case _ ⇒ - p.subscribe(s) - } - final protected def assignPort(in: InPort, subscriber: Subscriber[Any]): Unit = { subscribers(in) = subscriber // Interface (unconnected) ports of the current scope will be wired when exiting the scope if (!currentLayout.inPorts(in)) { val publisher = publishers(currentLayout.upstreams(in)) - if (publisher ne null) attach(publisher, subscriber) + if (publisher ne null) publisher.subscribe(subscriber) } } @@ -608,7 +665,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo // Interface (unconnected) ports of the current scope will be wired when exiting the scope if (!currentLayout.outPorts(out)) { val subscriber = subscribers(currentLayout.downstreams(out)) - if (subscriber ne null) attach(publisher, subscriber) + if (subscriber ne null) publisher.subscribe(subscriber) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala index 1ee755fd25..ff65edb8d2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala @@ -297,7 +297,7 @@ private[akka] abstract class TwoStreamInputProcessor(_settings: ActorFlowMateria private[akka] object MultiStreamInputProcessor { case class SubstreamKey(id: Long) - class SubstreamSubscriber[T](val impl: ActorRef, key: SubstreamKey) extends Subscriber[T] { + class SubstreamSubscriber[T](val impl: ActorRef, key: SubstreamKey) extends AtomicReference[Subscription] with Subscriber[T] { override def onError(cause: Throwable): Unit = { ReactiveStreamsCompliance.requireNonNullException(cause) impl ! SubstreamOnError(key, cause) @@ -309,7 +309,8 @@ private[akka] object MultiStreamInputProcessor { } override def onSubscribe(subscription: Subscription): Unit = { ReactiveStreamsCompliance.requireNonNullSubscription(subscription) - impl ! SubstreamStreamOnSubscribe(key, subscription) + if (compareAndSet(null, subscription)) impl ! SubstreamStreamOnSubscribe(key, subscription) + else subscription.cancel() } } @@ -346,16 +347,19 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ protected def inputBufferSize: Int private val substreamInputs = collection.mutable.Map.empty[SubstreamKey, SubstreamInput] + private val waitingForOnSubscribe = collection.mutable.Map.empty[SubstreamKey, SubstreamSubscriber[Any]] val inputSubstreamManagement: Receive = { - case SubstreamStreamOnSubscribe(key, subscription) ⇒ substreamInputs(key).substreamOnSubscribe(subscription) - case SubstreamOnNext(key, element) ⇒ substreamInputs(key).substreamOnNext(element) - case SubstreamOnComplete(key) ⇒ { + case SubstreamStreamOnSubscribe(key, subscription) ⇒ + substreamInputs(key).substreamOnSubscribe(subscription) + waitingForOnSubscribe -= key + case SubstreamOnNext(key, element) ⇒ + substreamInputs(key).substreamOnNext(element) + case SubstreamOnComplete(key) ⇒ substreamInputs(key).substreamOnComplete() substreamInputs -= key - } - case SubstreamOnError(key, e) ⇒ substreamInputs(key).substreamOnError(e) - + case SubstreamOnError(key, e) ⇒ + substreamInputs(key).substreamOnError(e) } def createSubstreamInput(): SubstreamInput = { @@ -367,7 +371,9 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ def createAndSubscribeSubstreamInput(p: Publisher[Any]): SubstreamInput = { val inputs = createSubstreamInput() - p.subscribe(new SubstreamSubscriber(self, inputs.key)) + val sub = new SubstreamSubscriber[Any](self, inputs.key) + waitingForOnSubscribe(inputs.key) = sub + p.subscribe(sub) inputs } @@ -378,13 +384,25 @@ private[akka] trait MultiStreamInputProcessorLike extends Pump { this: Actor ⇒ } protected def failInputs(e: Throwable): Unit = { + cancelWaitingForOnSubscribe() substreamInputs.values foreach (_.cancel()) } protected def finishInputs(): Unit = { + cancelWaitingForOnSubscribe() substreamInputs.values foreach (_.cancel()) } + private def cancelWaitingForOnSubscribe(): Unit = + waitingForOnSubscribe.valuesIterator.foreach { sub ⇒ + sub.getAndSet(CancelledSubscription) match { + case null ⇒ // we were first + case subscription ⇒ + // SubstreamOnSubscribe is still in flight and will not arrive + subscription.cancel() + } + } + } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala index d43ac5e877..fe0a8634a8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala @@ -161,6 +161,11 @@ private[akka] trait Pump { transferState = WaitingForUpstreamSubscription(waitForUpstream, andThen) } + final def waitForUpstreams(waitForUpstream: Int): Unit = { + require(waitForUpstream >= 1, s"waitForUpstream must be >= 1 (was $waitForUpstream)") + transferState = WaitingForUpstreamSubscription(waitForUpstream, TransferPhase(transferState)(currentAction)) + } + def gotUpstreamSubscription(): Unit = { transferState match { case WaitingForUpstreamSubscription(1, andThen) ⇒