diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index 8f3421c6aa..0f76d57e3d 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -4,6 +4,7 @@ package akka.http.impl.engine.client +import akka.stream.impl.fusing.GraphInterpreter import language.existentials import scala.annotation.tailrec import scala.collection.mutable.ListBuffer @@ -22,6 +23,7 @@ import akka.http.impl.util._ import akka.stream.stage.GraphStage import akka.stream.stage.GraphStageLogic import akka.stream.stage.InHandler +import akka.stream.impl.fusing.SubSource /** * INTERNAL API @@ -69,11 +71,21 @@ private[http] object OutgoingConnectionBlueprint { .mapConcat(conforms) .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .prefixAndTail(1) - .collect { + .filter { + case (Seq(MessageEnd), remaining) ⇒ + SubSource.kill(remaining) + false + case _ ⇒ + true + } + .map { case (Seq(ResponseStart(statusCode, protocol, headers, createEntity, _)), entityParts) ⇒ val entity = createEntity(entityParts) withSizeLimit parserSettings.maxContentLength HttpResponse(statusCode, headers, entity, protocol) - case (Seq(MessageStartError(_, info)), _) ⇒ throw IllegalResponseException(info) + case (Seq(MessageStartError(_, info)), tail) ⇒ + // Tails can be empty, but still need one pull to figure that out -- never drop tails. + SubSource.kill(tail) + throw IllegalResponseException(info) }.concatSubstreams val core = BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒ @@ -198,7 +210,8 @@ private[http] object OutgoingConnectionBlueprint { val getNextData = () ⇒ { waitingForMethod = false - pull(dataInput) + if (!isClosed(dataInput)) pull(dataInput) + else completeStage() } @tailrec def drainParser(current: ResponseOutput, b: ListBuffer[ResponseOutput] = ListBuffer.empty): Unit = { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala index 5315478750..13b2f4fa11 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala @@ -5,18 +5,18 @@ package akka.http.impl.engine.parsing import akka.http.ParserSettings - +import akka.stream.impl.fusing.GraphInterpreter import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.parboiled2.CharPredicate -import akka.stream.scaladsl.Source +import akka.stream.scaladsl.{ Sink, Source } import akka.stream.stage._ import akka.util.ByteString import akka.http.scaladsl.model._ import akka.http.impl.util._ import headers._ - import scala.collection.mutable.ListBuffer +import akka.stream.impl.fusing.SubSource /** * INTERNAL API @@ -173,7 +173,11 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, emit(bytes) }, emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = { - (headers, ct, bytes) ⇒ emit(BodyPartStart(headers, _ ⇒ HttpEntity.Strict(ct, bytes))) + (headers, ct, bytes) ⇒ + emit(BodyPartStart(headers, { rest ⇒ + SubSource.kill(rest) + HttpEntity.Strict(ct, bytes) + })) })(input: ByteString, offset: Int): StateResult = try { @tailrec def rec(index: Int): StateResult = { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala index fe6a813a8b..26d83b1852 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala @@ -5,8 +5,10 @@ package akka.http.impl.engine.parsing import akka.http.scaladsl.model._ +import akka.stream.impl.fusing.GraphInterpreter +import akka.stream.scaladsl.{ Sink, Source } import akka.util.ByteString -import akka.stream.scaladsl.Source +import akka.stream.impl.fusing.SubSource /** * INTERNAL API @@ -64,7 +66,11 @@ private[http] object ParserOutput { sealed abstract class EntityCreator[-A <: ParserOutput, +B >: HttpEntity.Strict <: HttpEntity] extends (Source[A, Unit] ⇒ B) final case class StrictEntityCreator(entity: HttpEntity.Strict) extends EntityCreator[ParserOutput, HttpEntity.Strict] { - def apply(parts: Source[ParserOutput, Unit]) = entity + def apply(parts: Source[ParserOutput, Unit]) = { + // We might need to drain stray empty tail streams which will be read by no one. + SubSource.kill(parts) + entity + } } final case class StreamedEntityCreator[-A <: ParserOutput, +B >: HttpEntity.Strict <: HttpEntity](creator: Source[A, Unit] ⇒ B) extends EntityCreator[A, B] { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 653b48e321..2c54a09db4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -6,6 +6,7 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress import java.util.Random +import akka.stream.impl.fusing.GraphInterpreter import scala.collection.immutable import org.reactivestreams.{ Publisher, Subscriber } import scala.util.control.NonFatal @@ -25,6 +26,8 @@ import akka.stream.io._ import akka.stream.scaladsl._ import akka.stream.stage._ import akka.util.ByteString +import akka.http.scaladsl.model.ws.Message +import akka.stream.impl.fusing.SubSource /** * INTERNAL API @@ -48,13 +51,13 @@ import akka.util.ByteString * +----------+ +-------------+ Context +-----------+ */ private[http] object HttpServerBluePrint { - def apply(settings: ServerSettings, remoteAddress: Option[InetSocketAddress], log: LoggingAdapter)(implicit mat: Materializer): Http.ServerLayer = { + def apply(settings: ServerSettings, remoteAddress: Option[InetSocketAddress], log: LoggingAdapter): Http.ServerLayer = { val theStack = userHandlerGuard(settings.pipeliningLimit) atop requestPreparation(settings) atop controller(settings, log) atop parsingRendering(settings, log) atop - websocketSupport(settings, log) atop + new ProtocolSwitchStage(settings, log) atop unwrapTls theStack.withAttributes(HttpAttributes.remoteAddress(remoteAddress)) @@ -63,28 +66,13 @@ private[http] object HttpServerBluePrint { val unwrapTls: BidiFlow[ByteString, SslTlsOutbound, SslTlsInbound, ByteString, Unit] = BidiFlow.fromFlows(Flow[ByteString].map(SendBytes), Flow[SslTlsInbound].collect { case x: SessionBytes ⇒ x.bytes }) - /** Wrap an HTTP implementation with support for switching to Websocket */ - def websocketSupport(settings: ServerSettings, log: LoggingAdapter)(implicit mat: Materializer): BidiFlow[ResponseRenderingOutput, ByteString, ByteString, ByteString, Unit] = { - val ws = websocketSetup - - BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒ - import GraphDSL.Implicits._ - - val switch = b.add(new ProtocolSwitchStage(ws.installHandler, settings.websocketRandomFactory, log)) - - switch.toWs ~> ws.websocketFlow ~> switch.fromWs - - BidiShape(switch.fromHttp, switch.toNet, switch.fromNet, switch.toHttp) - }) - } - def parsingRendering(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingContext, ResponseRenderingOutput, ByteString, RequestOutput, Unit] = BidiFlow.fromFlows(rendering(settings, log), parsing(settings, log)) def controller(settings: ServerSettings, log: LoggingAdapter): BidiFlow[HttpResponse, ResponseRenderingContext, RequestOutput, RequestOutput, Unit] = BidiFlow.fromGraph(new ControllerStage(settings, log)).reversed - def requestPreparation(settings: ServerSettings)(implicit mat: Materializer): BidiFlow[HttpResponse, HttpResponse, RequestOutput, HttpRequest, Unit] = + def requestPreparation(settings: ServerSettings): BidiFlow[HttpResponse, HttpResponse, RequestOutput, HttpRequest, Unit] = BidiFlow.fromFlows(Flow[HttpResponse], Flow[RequestOutput] .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) @@ -93,7 +81,7 @@ private[http] object HttpServerBluePrint { .concatSubstreams .via(requestStartOrRunIgnore(settings))) - def requestStartOrRunIgnore(settings: ServerSettings)(implicit mat: Materializer): Flow[(ParserOutput.RequestOutput, Source[ParserOutput.RequestOutput, Unit]), HttpRequest, Unit] = + def requestStartOrRunIgnore(settings: ServerSettings): Flow[(ParserOutput.RequestOutput, Source[ParserOutput.RequestOutput, Unit]), HttpRequest, Unit] = Flow.fromGraph(new GraphStage[FlowShape[(RequestOutput, Source[RequestOutput, Unit]), HttpRequest]] { val in = Inlet[(RequestOutput, Source[RequestOutput, Unit])]("RequestStartThenRunIgnore.in") val out = Outlet[HttpRequest]("RequestStartThenRunIgnore.out") @@ -115,7 +103,7 @@ private[http] object HttpServerBluePrint { push(out, HttpRequest(effectiveMethod, uri, effectiveHeaders, entity, protocol)) case (wat, src) ⇒ - src.runWith(Sink.ignore) + SubSource.kill(src) pull(in) } }) @@ -355,124 +343,111 @@ private[http] object HttpServerBluePrint { def userHandlerGuard(pipeliningLimit: Int): BidiFlow[HttpResponse, HttpResponse, HttpRequest, HttpRequest, Unit] = One2OneBidiFlow[HttpRequest, HttpResponse](pipeliningLimit).reversed - private trait WebsocketSetup { - def websocketFlow: Flow[ByteString, ByteString, Any] - def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: Materializer): Unit - } - private def websocketSetup: WebsocketSetup = { - val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]] - val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]] + private class ProtocolSwitchStage(settings: ServerSettings, log: LoggingAdapter) + extends GraphStage[BidiShape[ResponseRenderingOutput, ByteString, ByteString, ByteString]] { - val sink = StreamUtils.oneTimePublisherSink[FrameEvent](sinkCell, "frameHandler.in") - val source = StreamUtils.oneTimeSubscriberSource[FrameEvent](sourceCell, "frameHandler.out") - - val flow = Websocket.framing.join(Flow.fromSinkAndSourceMat(sink, source)(Keep.none)) - - new WebsocketSetup { - def websocketFlow: Flow[ByteString, ByteString, Any] = flow - - def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: Materializer): Unit = - Source.fromPublisher(sinkCell.value) - .via(handlerFlow) - .to(Sink.fromSubscriber(sourceCell.value)) - .run() - } - } - - private case class ProtocolSwitchShape( - fromNet: Inlet[ByteString], - toNet: Outlet[ByteString], - fromHttp: Inlet[ResponseRenderingOutput], - toHttp: Outlet[ByteString], - fromWs: Inlet[ByteString], - toWs: Outlet[ByteString]) extends Shape { - def inlets: immutable.Seq[Inlet[_]] = Vector(fromNet, fromHttp, fromWs) - def outlets: immutable.Seq[Outlet[_]] = Vector(toNet, toHttp, toWs) - - def deepCopy(): Shape = - ProtocolSwitchShape(fromNet.carbonCopy(), toNet.carbonCopy(), fromHttp.carbonCopy(), toHttp.carbonCopy(), fromWs.carbonCopy(), toWs.carbonCopy()) - - def copyFromPorts(inlets: immutable.Seq[Inlet[_]], outlets: immutable.Seq[Outlet[_]]): Shape = { - require(inlets.size == 3 && outlets.size == 3, s"ProtocolSwitchShape must have 3 inlets and outlets but had ${inlets.size} / ${outlets.size}") - ProtocolSwitchShape( - inlets(0).asInstanceOf[Inlet[ByteString]], - outlets(0).asInstanceOf[Outlet[ByteString]], - inlets(1).asInstanceOf[Inlet[ResponseRenderingOutput]], - outlets(1).asInstanceOf[Outlet[ByteString]], - inlets(2).asInstanceOf[Inlet[ByteString]], - outlets(2).asInstanceOf[Outlet[ByteString]]) - } - } - - private class ProtocolSwitchStage(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit, - websocketRandomFactory: () ⇒ Random, log: LoggingAdapter) extends GraphStage[ProtocolSwitchShape] { private val fromNet = Inlet[ByteString]("fromNet") private val toNet = Outlet[ByteString]("toNet") private val toHttp = Outlet[ByteString]("toHttp") private val fromHttp = Inlet[ResponseRenderingOutput]("fromHttp") - private val toWs = Outlet[ByteString]("toWs") - private val fromWs = Inlet[ByteString]("fromWs") - override def initialAttributes = Attributes.name("ProtocolSwitchStage") - def shape: ProtocolSwitchShape = ProtocolSwitchShape(fromNet, toNet, fromHttp, toHttp, fromWs, toWs) + override val shape = BidiShape(fromHttp, toNet, fromNet, toHttp) - def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { import akka.http.impl.engine.rendering.ResponseRenderingOutput._ - var websocketHandlerWasInstalled = false + setHandler(fromHttp, new InHandler { + override def onPush(): Unit = + grab(fromHttp) match { + case HttpData(b) ⇒ push(toNet, b) + case SwitchToWebsocket(bytes, handlerFlow) ⇒ + push(toNet, bytes) + complete(toHttp) + cancel(fromHttp) + switchToWebsocket(handlerFlow) + } + }) + setHandler(toNet, new OutHandler { + override def onPull(): Unit = pull(fromHttp) + }) - setHandler(fromHttp, ignoreTerminateInput) - setHandler(toHttp, ignoreTerminateOutput) - setHandler(fromWs, ignoreTerminateInput) - setHandler(toWs, ignoreTerminateOutput) - - val pullNet = () ⇒ pull(fromNet) setHandler(fromNet, new InHandler { - def onPush(): Unit = emit(target, grab(fromNet), pullNet) + def onPush(): Unit = push(toHttp, grab(fromNet)) // propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled // too eagerly - override def onUpstreamFailure(ex: Throwable): Unit = fail(target, ex) + override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex) }) - - val shutdown: () ⇒ Unit = () ⇒ completeStage() - val httpToNet: ResponseRenderingOutput ⇒ Unit = { - case HttpData(b) ⇒ push(toNet, b) - case SwitchToWebsocket(bytes, handlerFlow) ⇒ - push(toNet, bytes) - val frameHandler = handlerFlow match { - case Left(frameHandler) ⇒ frameHandler - case Right(messageHandler) ⇒ - Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory, log = log).join(messageHandler) - } - installHandler(frameHandler) - websocketHandlerWasInstalled = true - } - val wsToNet: ByteString ⇒ Unit = push(toNet, _) - - setHandler(toNet, new OutHandler { - def onPull(): Unit = - if (isHttp) read(fromHttp)(httpToNet, shutdown) - else read(fromWs)(wsToNet, shutdown) - - // toNet cancellation isn't allowed to stop this stage + setHandler(toHttp, new OutHandler { + override def onPull(): Unit = pull(fromNet) override def onDownstreamFinish(): Unit = () }) - def isHttp = !websocketHandlerWasInstalled - def isWS = websocketHandlerWasInstalled - def target = if (websocketHandlerWasInstalled) toWs else toHttp + private var activeTimers = 0 + private def timeout = ActorMaterializer.downcast(materializer).settings.subscriptionTimeoutSettings.timeout + private def addTimeout(s: SubscriptionTimeout): Unit = { + if (activeTimers == 0) setKeepGoing(true) + activeTimers += 1 + scheduleOnce(s, timeout) + } + private def cancelTimeout(s: SubscriptionTimeout): Unit = + if (isTimerActive(s)) { + activeTimers -= 1 + if (activeTimers == 0) setKeepGoing(false) + cancelTimer(s) + } + override def onTimer(timerKey: Any): Unit = timerKey match { + case SubscriptionTimeout(f) ⇒ + activeTimers -= 1 + if (activeTimers == 0) setKeepGoing(false) + f() + } - override def preStart(): Unit = pull(fromNet) + /* + * Websocket support + */ + def switchToWebsocket(handlerFlow: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]): Unit = { + val frameHandler = handlerFlow match { + case Left(frameHandler) ⇒ frameHandler + case Right(messageHandler) ⇒ + Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler) + } + val sinkIn = new SubSinkInlet[ByteString]("FrameSink") + val sourceOut = new SubSourceOutlet[ByteString]("FrameSource") - override def postStop(): Unit = { - // Install a dummy handler to make sure no processors leak because they have - // never been subscribed to, see #17494 and #17551. - if (!websocketHandlerWasInstalled) installHandler(Flow[FrameEvent]) + val timeoutKey = SubscriptionTimeout(() ⇒ { + sourceOut.timeout(timeout) + if (sourceOut.isClosed) completeStage() + }) + addTimeout(timeoutKey) + + sinkIn.setHandler(new InHandler { + override def onPush(): Unit = push(toNet, sinkIn.grab()) + }) + setHandler(toNet, new OutHandler { + override def onPull(): Unit = sinkIn.pull() + }) + + setHandler(fromNet, new InHandler { + override def onPush(): Unit = sourceOut.push(grab(fromNet)) + }) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = { + if (!hasBeenPulled(fromNet)) pull(fromNet) + cancelTimeout(timeoutKey) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = pull(fromNet) + }) + } + }) + + Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) } } } + + private case class SubscriptionTimeout(andThen: () ⇒ Unit) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index 21ef789361..136245b3f4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -5,18 +5,16 @@ package akka.http.impl.engine.ws import java.util.Random - import akka.event.LoggingAdapter +import akka.stream.impl.fusing.GraphInterpreter import akka.util.ByteString - import scala.concurrent.duration._ - import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage._ - import akka.http.impl.util._ import akka.http.scaladsl.model.ws._ +import akka.stream.impl.fusing.SubSource /** * INTERNAL API @@ -91,6 +89,7 @@ private[http] object Websocket { .map { case (seq, remaining) ⇒ seq.head match { case TextMessagePart(text, true) ⇒ + SubSource.kill(remaining) TextMessage.Strict(text) case first @ TextMessagePart(text, false) ⇒ TextMessage( @@ -99,6 +98,7 @@ private[http] object Websocket { case t: TextMessagePart if t.data.nonEmpty ⇒ t.data }) case BinaryMessagePart(data, true) ⇒ + SubSource.kill(remaining) BinaryMessage.Strict(data) case first @ BinaryMessagePart(data, false) ⇒ BinaryMessage( diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala index 7a79a1726a..887032f78b 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala @@ -150,26 +150,6 @@ private[http] object StreamUtils { } } - /** - * Applies a sequence of transformers on one source and returns a sequence of sources with the result. The input source - * will only be traversed once. - */ - def transformMultiple(input: Source[ByteString, Any], transformers: immutable.Seq[Flow[ByteString, ByteString, Any]])(implicit materializer: Materializer): immutable.Seq[Source[ByteString, Any]] = - transformers match { - case Nil ⇒ Nil - case Seq(one) ⇒ Vector(input.via(one)) - case multiple ⇒ - val (fanoutSub, fanoutPub) = Source.asSubscriber[ByteString].toMat(Sink.asPublisher(true))(Keep.both).run() - val sources = transformers.map { flow ⇒ - // Doubly wrap to ensure that subscription to the running publisher happens before the final sources - // are exposed, so there is no race - Source.fromPublisher(Source.fromPublisher(fanoutPub).viaMat(flow)(Keep.right).runWith(Sink.asPublisher(false))) - } - // The fanout publisher must be wired to the original source after all fanout subscribers have been subscribed - input.runWith(Sink.fromSubscriber(fanoutSub)) - sources - } - def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity = _.transformDataBytes(mapErrorTransformer(f)) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala index 2cbdd5b0cd..e050999a88 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala @@ -16,7 +16,7 @@ import akka.event.{ NoLogging, LoggingAdapter } import akka.stream.impl.ConstantFun import akka.stream.Materializer import akka.stream.javadsl.{ Source ⇒ JSource } -import akka.stream.scaladsl.Source +import akka.stream.scaladsl._ import akka.http.scaladsl.util.FastFuture import akka.http.scaladsl.model.headers._ import akka.http.impl.engine.rendering.BodyPartRenderer @@ -187,10 +187,7 @@ object Multipart { private def strictify[BP <: Multipart.BodyPart, BPS <: Multipart.BodyPart.Strict](parts: Source[BP, Any])(f: BP ⇒ Future[BPS])(implicit fm: Materializer): Future[Vector[BPS]] = { import fm.executionContext - // TODO: move to Vector `:+` when https://issues.scala-lang.org/browse/SI-8930 is fixed - parts.runFold(new VectorBuilder[Future[BPS]]) { - case (builder, part) ⇒ builder += f(part) - }.fast.flatMap(builder ⇒ FastFuture.sequence(builder.result())) + parts.mapAsync(Int.MaxValue)(f).runWith(Sink.seq).fast.map(_.toVector) } //////////////////////// CONCRETE multipart types ///////////////////////// @@ -574,4 +571,4 @@ object Multipart { } } } -} \ No newline at end of file +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala index ae88bb343c..ff2a9b3961 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala @@ -26,6 +26,7 @@ import akka.http.scaladsl.model._ class ConnectionPoolSpec extends AkkaSpec(""" akka.loggers = [] akka.loglevel = OFF + akka.io.tcp.windows-connection-abort-workaround-enabled = auto akka.io.tcp.trace-logging = off""") { implicit val materializer = ActorMaterializer() diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index 8c172608f4..ef2f72b147 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -104,6 +104,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. sendWireData("0\n\n") sub.request(1) probe.expectNext(HttpEntity.LastChunk) + sub.request(1) probe.expectComplete() requestsSub.sendComplete() @@ -165,6 +166,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. sendWireData("0\n\n") sub.request(1) probe.expectNext(HttpEntity.LastChunk) + sub.request(1) probe.expectComplete() // simulate that response is received before method bypass reaches response parser diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala index dd875a3ca4..23c69a1529 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala @@ -26,7 +26,7 @@ class TlsEndpointVerificationSpec extends AkkaSpec(""" val timeout = Timeout(Span(3, Seconds)) "The client implementation" should { - "not accept certificates signed by unknown CA" in EventFilter[SSLException](occurrences = 1).intercept { + "not accept certificates signed by unknown CA" in { val pipe = pipeline(Http().defaultClientHttpsContext, hostname = "akka.example.org") // default context doesn't include custom CA whenReady(pipe(HttpRequest(uri = "https://akka.example.org/")).failed, timeout) { e ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala index ae45bc59cb..631fc039af 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala @@ -4,6 +4,7 @@ package akka.http.impl.engine.parsing +import akka.stream.impl.fusing.GraphInterpreter import com.typesafe.config.{ Config, ConfigFactory } import scala.concurrent.Future import scala.concurrent.duration._ @@ -23,7 +24,7 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.util.FastFuture import akka.http.scaladsl.util.FastFuture._ -import akka.stream.ActorMaterializer +import akka.stream.{ OverflowStrategy, ActorMaterializer } import akka.stream.scaladsl._ import akka.util.ByteString @@ -480,7 +481,9 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { case (Seq(RequestStart(method, uri, protocol, headers, createEntity, _, close)), entityParts) ⇒ closeAfterResponseCompletion :+= close Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) - case (Seq(x @ (MessageStartError(_, _) | EntityStreamError(_))), _) ⇒ Left(x) + case (Seq(x @ (MessageStartError(_, _) | EntityStreamError(_))), rest) ⇒ + rest.runWith(Sink.cancelled) + Left(x) } .concatSubstreams .flatMapConcat { x ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala index 4f5e3960fb..29a4aeab2c 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala @@ -298,7 +298,9 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { case (Seq(ResponseStart(statusCode, protocol, headers, createEntity, close)), entityParts) ⇒ closeAfterResponseCompletion :+= close Right(HttpResponse(statusCode, headers, createEntity(entityParts), protocol)) - case (Seq(x @ (MessageStartError(_, _) | EntityStreamError(_))), _) ⇒ Left(x) + case (Seq(x @ (MessageStartError(_, _) | EntityStreamError(_))), tail) ⇒ + tail.runWith(Sink.ignore) + Left(x) }.concatSubstreams def collectBlocking[T](source: Source[T, Any]): Seq[T] = diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index ad6687a147..007a5469a0 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -66,6 +66,8 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") + requests.request(1) + expectResponseWithWipedDate( """HTTP/1.1 505 HTTP Version Not Supported |Server: akka-http/test @@ -504,6 +506,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") dataProbe.expectNoMsg(50.millis) send("0123456789ABCDEF") dataProbe.expectNext(ByteString("0123456789ABCDEF")) + dataSub.request(1) dataProbe.expectComplete() responses.sendNext(HttpResponse(entity = "Yeah")) expectResponseWithWipedDate( @@ -545,6 +548,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |""") dataProbe.expectNext(Chunk(ByteString("0123456789ABCDEF"))) dataProbe.expectNext(LastChunk) + dataSub.request(1) dataProbe.expectComplete() responses.sendNext(HttpResponse(entity = "Yeah")) expectResponseWithWipedDate( @@ -663,6 +667,8 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") + requests.request(1) + expectResponseWithWipedDate( """|HTTP/1.1 400 Bad Request |Server: akka-http/test @@ -701,6 +707,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") val HttpRequest(POST, _, _, entity, _) = expectRequest() responses.sendNext(HttpResponse(status = StatusCodes.InsufficientStorage)) + entity.dataBytes.runWith(Sink.ignore) expectResponseWithWipedDate( """HTTP/1.1 507 Insufficient Storage diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index c0ed97a449..0dd471c12e 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -68,6 +68,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { val data2 = ByteString("def", "ASCII") pushInput(data2) sub.expectNext(data2) + s.request(1) sub.expectComplete() } @@ -87,6 +88,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { val data2 = ByteString("defg", "ASCII") pushInput(header2 ++ data2) sub.expectNext(data2) + s.request(1) sub.expectComplete() } "for several messages" in new ClientTestSetup { @@ -107,6 +109,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { val data3 = ByteString("h") pushInput(header2 ++ data2 ++ header3 ++ data3) sub.expectNext(data2) + s.request(1) sub.expectComplete() val dataSource2 = expectBinaryMessage().dataStream @@ -119,6 +122,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { val data4 = ByteString("i") pushInput(data4) sub2.expectNext(data4) + s2.request(1) sub2.expectComplete() } "unmask masked input on the server side" in new ServerTestSetup { @@ -138,6 +142,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { pushInput(data2) sub.expectNext(ByteString("def", "ASCII")) + s.request(1) sub.expectComplete() } "unmask masked input on the server side for empty frame" in new ServerTestSetup { @@ -218,6 +223,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { pushInput(data2) sub.expectNext(ByteString("cdef€", "UTF-8")) + s.request(1) sub.expectComplete() } "unmask masked input on the server side for empty frame" in new ServerTestSetup { @@ -430,6 +436,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1 pushInput(input2) sub.expectNext(ByteString("456", "ASCII")) + s.request(1) sub.expectComplete() } "don't respond to unsolicited pong frames" in new ClientTestSetup { @@ -770,6 +777,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { pushInput(frameHeader(Opcode.Text, 0, fin = false)) pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data) + // Kids, always drain your entities + messageIn.requestNext() match { + case b: TextMessage ⇒ + b.textStream.runWith(Sink.ignore) + case _ ⇒ + } + expectError(messageIn) expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) @@ -927,10 +941,12 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { def expectComplete[T](probe: TestSubscriber.Probe[T]): Unit = { probe.ensureSubscription() + probe.request(1) probe.expectComplete() } def expectError[T](probe: TestSubscriber.Probe[T]): Throwable = { probe.ensureSubscription() + probe.request(1) probe.expectError() } } diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala index ec9643da4c..f79657359a 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala @@ -33,6 +33,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll wit akka.loggers = ["akka.testkit.TestEventListener"] akka.loglevel = ERROR akka.stdout-loglevel = ERROR + windows-connection-abort-workaround-enabled = auto akka.log-dead-letters = OFF """) implicit val system = ActorSystem(getClass.getSimpleName, testConf) @@ -89,7 +90,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll wit } "properly terminate client when server is not running" in Utils.assertAllStagesStopped { - for (i ← 1 to 100) + for (i ← 1 to 10) withClue(s"iterator $i: ") { Source.single(HttpRequest(HttpMethods.POST, "/test", List.empty, HttpEntity(MediaTypes.`text/plain`.withCharset(HttpCharsets.`UTF-8`), "buh"))) .via(Http(actorSystem).outgoingConnection("localhost", 7777)) diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallersSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallersSpec.scala index 023e601397..ee29686d45 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallersSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallersSpec.scala @@ -305,7 +305,7 @@ class MultipartUnmarshallersSpec extends FreeSpec with Matchers with BeforeAndAf Await.result(x .fast.flatMap { _.parts - .mapAsync(1)(_ toStrict 1.second) + .mapAsync(Int.MaxValue)(_ toStrict 1.second) .grouped(100) .runWith(Sink.head) } diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala index 9901ceccf9..c2dc41b51d 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala @@ -10,9 +10,11 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.RouteResult.Complete import akka.http.impl.util._ -import akka.stream.scaladsl.Source - +import akka.stream.scaladsl._ import scala.collection.immutable +import akka.util.ByteString +import akka.stream.SourceShape +import akka.stream.OverflowStrategy trait RangeDirectives { import akka.http.scaladsl.server.directives.BasicDirectives._ @@ -73,12 +75,31 @@ trait RangeDirectives { // Therefore, ranges need to be sorted to prevent that some selected ranges already start to accumulate data // but cannot be sent out because another range is blocking the queue. val coalescedRanges = coalesceRanges(iRanges).sortBy(_.start) - val bodyPartTransformers = coalescedRanges.map(ir ⇒ StreamUtils.sliceBytesTransformer(ir.start, ir.length)).toVector - val bodyPartByteStreams = StreamUtils.transformMultiple(entity.dataBytes, bodyPartTransformers) - val bodyParts = (coalescedRanges, bodyPartByteStreams).zipped.map { (range, bytes) ⇒ - Multipart.ByteRanges.BodyPart(range.contentRange(length), HttpEntity(entity.contentType, range.length, bytes)) + val source = coalescedRanges.size match { + case 0 ⇒ Source.empty + case 1 ⇒ + val range = coalescedRanges.head + val flow = StreamUtils.sliceBytesTransformer(range.start, range.length) + val bytes = entity.dataBytes.via(flow) + val part = Multipart.ByteRanges.BodyPart(range.contentRange(length), HttpEntity(entity.contentType, range.length, bytes)) + Source.single(part) + case n ⇒ + Source fromGraph GraphDSL.create() { implicit b ⇒ + import GraphDSL.Implicits._ + val bcast = b.add(Broadcast[ByteString](n)) + val merge = b.add(Concat[Multipart.ByteRanges.BodyPart](n)) + for (range ← coalescedRanges) { + val flow = StreamUtils.sliceBytesTransformer(range.start, range.length) + bcast ~> flow.buffer(16, OverflowStrategy.backpressure).prefixAndTail(0).map { + case (_, bytes) ⇒ + Multipart.ByteRanges.BodyPart(range.contentRange(length), HttpEntity(entity.contentType, range.length, bytes)) + } ~> merge + } + entity.dataBytes ~> bcast + SourceShape(merge.out) + } } - Multipart.ByteRanges(Source(bodyParts.toVector)) + Multipart.ByteRanges(source) } def rangeResponse(range: ByteRange, entity: UniversalEntity, length: Long, headers: immutable.Seq[HttpHeader]) = { @@ -133,4 +154,3 @@ object RangeDirectives extends RangeDirectives { private val respondWithAcceptByteRangesHeader: Directive0 = RespondWithDirectives.respondWithHeader(`Accept-Ranges`(RangeUnits.Bytes)) } - diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala index 53b60e011a..8410bca780 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala @@ -8,7 +8,8 @@ import scala.collection.immutable import scala.collection.immutable.VectorBuilder import akka.util.ByteString import akka.event.{ NoLogging, LoggingAdapter } -import akka.stream.impl.fusing.IteratorInterpreter +import akka.stream.OverflowStrategy +import akka.stream.impl.fusing.{ GraphInterpreter, IteratorInterpreter } import akka.stream.scaladsl._ import akka.http.impl.engine.parsing.BodyPartParser import akka.http.impl.util._ @@ -17,6 +18,7 @@ import akka.http.scaladsl.util.FastFuture import MediaRanges._ import MediaTypes._ import HttpCharsets._ +import akka.stream.impl.fusing.SubSource trait MultipartUnmarshallers { @@ -88,10 +90,14 @@ trait MultipartUnmarshallers { val bodyParts = entity.dataBytes .transform(() ⇒ parser) .splitWhen(_.isInstanceOf[PartStart]) + .buffer(100, OverflowStrategy.backpressure) // FIXME remove (#19240) .prefixAndTail(1) .collect { - case (Seq(BodyPartStart(headers, createEntity)), entityParts) ⇒ createBodyPart(createEntity(entityParts), headers) - case (Seq(ParseError(errorInfo)), _) ⇒ throw ParsingException(errorInfo) + case (Seq(BodyPartStart(headers, createEntity)), entityParts) ⇒ + createBodyPart(createEntity(entityParts), headers) + case (Seq(ParseError(errorInfo)), rest) ⇒ + SubSource.kill(rest) + throw ParsingException(errorInfo) } .concatSubstreams createStreamed(entity.contentType.mediaType.asInstanceOf[MediaType.Multipart], bodyParts) diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/Utils.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/Utils.scala index 341b55ed52..561df33d19 100644 --- a/akka-stream-testkit/src/test/scala/akka/stream/testkit/Utils.scala +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/Utils.scala @@ -25,12 +25,18 @@ object Utils { probe.expectMsg(StreamSupervisor.StoppedChildren) val result = block probe.within(5.seconds) { - probe.awaitAssert { + var children = Set.empty[ActorRef] + try probe.awaitAssert { impl.supervisor.tell(StreamSupervisor.GetChildren, probe.ref) - val children = probe.expectMsgType[StreamSupervisor.Children].children + children = probe.expectMsgType[StreamSupervisor.Children].children assert(children.isEmpty, s"expected no StreamSupervisor children, but got [${children.mkString(", ")}]") } + catch { + case ex: Throwable ⇒ + children.foreach(_ ! StreamSupervisor.PrintDebugDump) + throw ex + } } result case _ ⇒ block diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala index e0be7ba52f..74c0276647 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/KeepGoingStageSpec.scala @@ -48,6 +48,7 @@ class KeepGoingStageSpec extends AkkaSpec { private var listener: Option[ActorRef] = None override def preStart(): Unit = { + setKeepGoing(keepAlive) promise.trySuccess(PingRef(getAsyncCallback(onCommand))) } @@ -73,8 +74,6 @@ class KeepGoingStageSpec extends AkkaSpec { override def onUpstreamFinish(): Unit = listener.foreach(_ ! UpstreamCompleted) }) - override def keepGoingAfterAllPortsClosed: Boolean = keepAlive - override def postStop(): Unit = listener.foreach(_ ! PostStop) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala index 7c4881d9eb..f853248862 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowConcatAllSpec.scala @@ -43,7 +43,7 @@ class FlowConcatAllSpec extends AkkaSpec { } "work together with SplitWhen" in { - val subscriber = TestSubscriber.manualProbe[Int]() + val subscriber = TestSubscriber.probe[Int]() Source(1 to 10) .splitWhen(_ % 2 == 0) .prefixAndTail(0) @@ -51,11 +51,11 @@ class FlowConcatAllSpec extends AkkaSpec { .concatSubstreams .flatMapConcat(ConstantFun.scalaIdentityFunction) .runWith(Sink.fromSubscriber(subscriber)) - val subscription = subscriber.expectSubscription() - subscription.request(10) - for (i ← (1 to 10)) - subscriber.expectNext() shouldBe i - subscription.request(1) + + for (i ← 1 to 10) + subscriber.requestNext() shouldBe i + + subscriber.request(1) subscriber.expectComplete() } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala index 331d0c9048..97c029edd8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala @@ -86,7 +86,7 @@ class FlowFlattenMergeSpec extends AkkaSpec with ScalaFutures with ConversionChe "bubble up substream exceptions" in { val ex = new Exception("buh") - intercept[TestFailedException] { + val result = intercept[TestFailedException] { Source(List(blocked, blocked, Source.failed(ex))) .flatMapMerge(10, identity) .runWith(Sink.head) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala index 7f98e9365a..753550e093 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowPrefixAndTailSpec.scala @@ -101,7 +101,7 @@ class FlowPrefixAndTailSpec extends AkkaSpec { val subscriber2 = TestSubscriber.probe[Int]() tail.to(Sink.fromSubscriber(subscriber2)).run() - subscriber2.expectSubscriptionAndError().getMessage should ===("Tail Source cannot be materialized more than once.") + subscriber2.expectSubscriptionAndError().getMessage should ===("Substream Source cannot be materialized more than once") subscriber1.requestNext(2).expectComplete() @@ -122,7 +122,7 @@ class FlowPrefixAndTailSpec extends AkkaSpec { Thread.sleep(1000) tail.to(Sink.fromSubscriber(subscriber)).run()(tightTimeoutMaterializer) - subscriber.expectSubscriptionAndError().getMessage should ===("Tail Source has not been materialized in 500 milliseconds.") + subscriber.expectSubscriptionAndError().getMessage should ===("Substream Source has not been materialized in 500 milliseconds") } "shut down main stage if substream is empty, even when not subscribed" in assertAllStagesStopped { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitAfterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitAfterSpec.scala index 86ecc240ed..ca14b2c9ee 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitAfterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitAfterSpec.scala @@ -12,6 +12,7 @@ import akka.stream.testkit.TestPublisher import akka.stream.testkit.TestSubscriber import akka.stream.testkit.Utils._ import org.reactivestreams.Publisher +import scala.concurrent.Await import scala.concurrent.duration._ import akka.stream.StreamSubscriptionTimeoutSettings import akka.stream.StreamSubscriptionTimeoutTerminationMode @@ -115,6 +116,14 @@ class FlowSplitAfterSpec extends AkkaSpec { } } + "work with single elem splits" in assertAllStagesStopped { + Await.result( + Source(1 to 10).splitAfter(_ ⇒ true).lift + .mapAsync(1)(_.runWith(Sink.head)) // Please note that this line *also* implicitly asserts nonempty substreams + .grouped(10).runWith(Sink.head), + 3.second) should ===(1 to 10) + } + "support cancelling substreams" in assertAllStagesStopped { new SubstreamsSupport(splitAfter = 5, elementCount = 8) { val s1 = StreamPuppet(expectSubFlow().runWith(Sink.asPublisher(false))) @@ -181,6 +190,8 @@ class FlowSplitAfterSpec extends AkkaSpec { } "resume stream when splitAfter function throws" in assertAllStagesStopped { + info("Supervision is not supported fully by GraphStages yet") + pending val publisherProbeProbe = TestPublisher.manualProbe[Int]() val exc = TE("test") val publisher = Source.fromPublisher(publisherProbeProbe) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala index 8b9667dbf3..dd5423e77f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSplitWhenSpec.scala @@ -3,16 +3,14 @@ */ package akka.stream.scaladsl -import akka.stream.ActorMaterializer -import akka.stream.ActorMaterializerSettings -import akka.stream.ActorAttributes +import akka.stream._ import akka.stream.Supervision.resumingDecider +import akka.stream.impl.SubscriptionTimeoutException import akka.stream.testkit.Utils._ import akka.stream.testkit._ import org.reactivestreams.Publisher +import scala.concurrent.Await import scala.concurrent.duration._ -import akka.stream.StreamSubscriptionTimeoutSettings -import akka.stream.StreamSubscriptionTimeoutTerminationMode class FlowSplitWhenSpec extends AkkaSpec { import FlowSplitAfterSpec._ @@ -85,6 +83,16 @@ class FlowSplitWhenSpec extends AkkaSpec { } } + "not emit substreams if the parent stream is empty" in assertAllStagesStopped { + + Await.result( + Source.empty[Int] + .splitWhen(_ ⇒ true).lift + .mapAsync(1)(_.runWith(Sink.headOption)).grouped(10).runWith(Sink.headOption), + 3.seconds) should ===(None) // rather tricky way of saying that no empty substream should be emitted (vs. Some(None)) + + } + "work when first element is split-by" in assertAllStagesStopped { new SubstreamsSupport(1, elementCount = 3) { val s1 = StreamPuppet(getSubFlow().runWith(Sink.asPublisher(false))) @@ -137,7 +145,7 @@ class FlowSplitWhenSpec extends AkkaSpec { substream.cancel() masterStream.expectNext(()) - masterStream.expectNoMsg(1.second) + masterStream.expectNoMsg(100.millis) masterStream.cancel() inputs.expectCancellation() @@ -173,126 +181,166 @@ class FlowSplitWhenSpec extends AkkaSpec { src2.runWith(Sink.fromSubscriber(substream4)) substream4.requestNext(2) - substream4.expectNoMsg(1.second) - masterStream3.expectNoMsg(1.second) + substream4.expectNoMsg(100.millis) + masterStream3.expectNoMsg(100.millis) inputs3.expectRequest() inputs3.expectRequest() - inputs3.expectNoMsg(1.second) + inputs3.expectNoMsg(100.millis) substream4.cancel() - inputs3.expectNoMsg(1.second) - masterStream3.expectNoMsg(1.second) + inputs3.expectNoMsg(100.millis) + masterStream3.expectNoMsg(100.millis) masterStream3.cancel() inputs3.expectCancellation() } - } - "support cancelling the master stream" in assertAllStagesStopped { - new SubstreamsSupport(splitWhen = 5, elementCount = 8) { - val s1 = StreamPuppet(getSubFlow().runWith(Sink.asPublisher(false))) - masterSubscription.cancel() - s1.request(4) - s1.expectNext(1) - s1.expectNext(2) - s1.expectNext(3) - s1.expectNext(4) - s1.request(1) - s1.expectComplete() + "support cancelling the master stream" in assertAllStagesStopped { + new SubstreamsSupport(splitWhen = 5, elementCount = 8) { + val s1 = StreamPuppet(getSubFlow().runWith(Sink.asPublisher(false))) + masterSubscription.cancel() + s1.request(4) + s1.expectNext(1) + s1.expectNext(2) + s1.expectNext(3) + s1.expectNext(4) + s1.request(1) + s1.expectComplete() + } } - } - "fail stream when splitWhen function throws" in assertAllStagesStopped { - val publisherProbeProbe = TestPublisher.manualProbe[Int]() - val exc = TE("test") - val publisher = Source.fromPublisher(publisherProbeProbe) - .splitWhen(elem ⇒ if (elem == 3) throw exc else elem % 3 == 0) - .lift - .runWith(Sink.asPublisher(false)) - val subscriber = TestSubscriber.manualProbe[Source[Int, Unit]]() - publisher.subscribe(subscriber) + "fail stream when splitWhen function throws" in assertAllStagesStopped { + val publisherProbeProbe = TestPublisher.manualProbe[Int]() + val exc = TE("test") + val publisher = Source.fromPublisher(publisherProbeProbe) + .splitWhen(elem ⇒ if (elem == 3) throw exc else elem % 3 == 0) + .lift + .runWith(Sink.asPublisher(false)) + val subscriber = TestSubscriber.manualProbe[Source[Int, Unit]]() + publisher.subscribe(subscriber) - val upstreamSubscription = publisherProbeProbe.expectSubscription() + val upstreamSubscription = publisherProbeProbe.expectSubscription() - val downstreamSubscription = subscriber.expectSubscription() - downstreamSubscription.request(100) + val downstreamSubscription = subscriber.expectSubscription() + downstreamSubscription.request(100) - upstreamSubscription.sendNext(1) + upstreamSubscription.sendNext(1) - val substream = subscriber.expectNext() - val substreamPuppet = StreamPuppet(substream.runWith(Sink.asPublisher(false))) + val substream = subscriber.expectNext() + val substreamPuppet = StreamPuppet(substream.runWith(Sink.asPublisher(false))) - substreamPuppet.request(10) - substreamPuppet.expectNext(1) + substreamPuppet.request(10) + substreamPuppet.expectNext(1) - upstreamSubscription.sendNext(2) - substreamPuppet.expectNext(2) + upstreamSubscription.sendNext(2) + substreamPuppet.expectNext(2) - upstreamSubscription.sendNext(3) + upstreamSubscription.sendNext(3) - subscriber.expectError(exc) - substreamPuppet.expectError(exc) - upstreamSubscription.expectCancellation() - } + subscriber.expectError(exc) + substreamPuppet.expectError(exc) + upstreamSubscription.expectCancellation() + } - "resume stream when splitWhen function throws" in assertAllStagesStopped { - val publisherProbeProbe = TestPublisher.manualProbe[Int]() - val exc = TE("test") - val publisher = Source.fromPublisher(publisherProbeProbe) - .splitWhen(elem ⇒ if (elem == 3) throw exc else elem % 3 == 0) - .lift - .withAttributes(ActorAttributes.supervisionStrategy(resumingDecider)) - .runWith(Sink.asPublisher(false)) - val subscriber = TestSubscriber.manualProbe[Source[Int, Unit]]() - publisher.subscribe(subscriber) + "work with single elem splits" in assertAllStagesStopped { + Await.result( + Source(1 to 100).splitWhen(_ ⇒ true).lift + .mapAsync(1)(_.runWith(Sink.head)) // Please note that this line *also* implicitly asserts nonempty substreams + .grouped(200).runWith(Sink.head), + 3.second) should ===(1 to 100) + } - val upstreamSubscription = publisherProbeProbe.expectSubscription() + "fail substream if materialized twice" in assertAllStagesStopped { + an[IllegalStateException] mustBe thrownBy { + Await.result( + Source.single(1).splitWhen(_ ⇒ true).lift + .mapAsync(1) { src ⇒ src.runWith(Sink.ignore); src.runWith(Sink.ignore) } // Sink.ignore+mapAsync pipes error back + .runWith(Sink.ignore), + 3.seconds) + } + } - val downstreamSubscription = subscriber.expectSubscription() - downstreamSubscription.request(100) + "fail stream if substream not materialized in time" in assertAllStagesStopped { + val tightTimeoutMaterializer = + ActorMaterializer(ActorMaterializerSettings(system) + .withSubscriptionTimeoutSettings( + StreamSubscriptionTimeoutSettings(StreamSubscriptionTimeoutTerminationMode.cancel, 500.millisecond))) - upstreamSubscription.sendNext(1) + val testSource = Source.single(1).concat(Source.maybe).splitWhen(_ ⇒ true) - val substream1 = subscriber.expectNext() - val substreamPuppet1 = StreamPuppet(substream1.runWith(Sink.asPublisher(false))) + a[SubscriptionTimeoutException] mustBe thrownBy { + Await.result( + testSource.lift + .delay(1.second) + .flatMapConcat(identity) + .runWith(Sink.ignore)(tightTimeoutMaterializer), + 3.seconds) + } + } - substreamPuppet1.request(10) - substreamPuppet1.expectNext(1) + "resume stream when splitWhen function throws" in assertAllStagesStopped { + info("Supervision is not supported fully by GraphStages yet") + pending - upstreamSubscription.sendNext(2) - substreamPuppet1.expectNext(2) + val publisherProbeProbe = TestPublisher.manualProbe[Int]() + val exc = TE("test") + val publisher = Source.fromPublisher(publisherProbeProbe) + .splitWhen(elem ⇒ if (elem == 3) throw exc else elem % 3 == 0) + .lift + .withAttributes(ActorAttributes.supervisionStrategy(resumingDecider)) + .runWith(Sink.asPublisher(false)) + val subscriber = TestSubscriber.manualProbe[Source[Int, Unit]]() + publisher.subscribe(subscriber) - upstreamSubscription.sendNext(3) - upstreamSubscription.sendNext(4) - substreamPuppet1.expectNext(4) // note that 3 was dropped + val upstreamSubscription = publisherProbeProbe.expectSubscription() - upstreamSubscription.sendNext(5) - substreamPuppet1.expectNext(5) + val downstreamSubscription = subscriber.expectSubscription() + downstreamSubscription.request(100) - upstreamSubscription.sendNext(6) - substreamPuppet1.expectComplete() - val substream2 = subscriber.expectNext() - val substreamPuppet2 = StreamPuppet(substream2.runWith(Sink.asPublisher(false))) - substreamPuppet2.request(10) - substreamPuppet2.expectNext(6) + upstreamSubscription.sendNext(1) - upstreamSubscription.sendComplete() - subscriber.expectComplete() - substreamPuppet2.expectComplete() - } + val substream1 = subscriber.expectNext() + val substreamPuppet1 = StreamPuppet(substream1.runWith(Sink.asPublisher(false))) - "pass along early cancellation" in assertAllStagesStopped { - val up = TestPublisher.manualProbe[Int]() - val down = TestSubscriber.manualProbe[Source[Int, Unit]]() + substreamPuppet1.request(10) + substreamPuppet1.expectNext(1) - val flowSubscriber = Source.asSubscriber[Int].splitWhen(_ % 3 == 0).lift.to(Sink.fromSubscriber(down)).run() + upstreamSubscription.sendNext(2) + substreamPuppet1.expectNext(2) + + upstreamSubscription.sendNext(3) + upstreamSubscription.sendNext(4) + substreamPuppet1.expectNext(4) // note that 3 was dropped + + upstreamSubscription.sendNext(5) + substreamPuppet1.expectNext(5) + + upstreamSubscription.sendNext(6) + substreamPuppet1.expectComplete() + val substream2 = subscriber.expectNext() + val substreamPuppet2 = StreamPuppet(substream2.runWith(Sink.asPublisher(false))) + substreamPuppet2.request(10) + substreamPuppet2.expectNext(6) + + upstreamSubscription.sendComplete() + subscriber.expectComplete() + substreamPuppet2.expectComplete() + } + + "pass along early cancellation" in assertAllStagesStopped { + val up = TestPublisher.manualProbe[Int]() + val down = TestSubscriber.manualProbe[Source[Int, Unit]]() + + val flowSubscriber = Source.asSubscriber[Int].splitWhen(_ % 3 == 0).lift.to(Sink.fromSubscriber(down)).run() + + val downstream = down.expectSubscription() + downstream.cancel() + up.subscribe(flowSubscriber) + val upsub = up.expectSubscription() + upsub.expectCancellation() + } - val downstream = down.expectSubscription() - downstream.cancel() - up.subscribe(flowSubscriber) - val upsub = up.expectSubscription() - upsub.expectCancellation() } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala index b5870d9f3c..875a265702 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala @@ -84,7 +84,7 @@ class SinkForeachParallelSpec extends AkkaSpec { }).withAttributes(supervisionStrategy(resumingDecider))) latch.countDown() - probe.expectMsgAllOf(1, 2, 4) + probe.expectMsgAllOf(1, 2, 4, 5) Await.result(p, 5.seconds) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala index 54d36bb7b1..f028a5c8cc 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala @@ -72,7 +72,7 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { val (_, s3) = subscriber.expectNext() // sleep long enough for it to be cleaned up - Thread.sleep(1000) + Thread.sleep(1500) val f = s3.runWith(Sink.head).recover { case _: SubscriptionTimeoutException ⇒ "expected" } Await.result(f, 300.millis) should equal("expected") diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index 6f352746fa..db6091ca57 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -8,6 +8,7 @@ import scala.annotation.tailrec import scala.reflect.{ classTag, ClassTag } import akka.japi.function import akka.stream.impl.StreamLayout._ +import java.net.URLEncoder /** * Holds attributes which can be used to alter [[akka.stream.scaladsl.Flow]] / [[akka.stream.javadsl.Flow]] @@ -142,11 +143,12 @@ final case class Attributes(attributeList: List[Attributes.Attribute] = Nil) { if (i.hasNext) i.next() match { case Name(n) ⇒ - if (buf ne null) concatNames(i, null, buf.append('-').append(n)) + val nn = URLEncoder.encode(n, "UTF-8") + if (buf ne null) concatNames(i, null, buf.append('-').append(nn)) else if (first ne null) { - val b = new StringBuilder((first.length() + n.length()) * 2) - concatNames(i, null, b.append(first).append('-').append(n)) - } else concatNames(i, n, null) + val b = new StringBuilder((first.length() + nn.length()) * 2) + concatNames(i, null, b.append(first).append('-').append(nn)) + } else concatNames(i, nn, null) case _ ⇒ concatNames(i, first, buf) } else if (buf eq null) first diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index 4e946fe6d6..f9360aed43 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -270,6 +270,8 @@ private[akka] object StreamSupervisor { case object StopChildren /** Testing purpose */ case object StoppedChildren + /** Testing purpose */ + case object PrintDebugDump } private[akka] class StreamSupervisor(settings: ActorMaterializerSettings, haveShutDown: AtomicBoolean) extends Actor { @@ -303,7 +305,6 @@ private[akka] object ActorProcessorFactory { val settings = materializer.effectiveSettings(att) op match { case GroupBy(maxSubstreams, f, _) ⇒ (GroupByProcessorImpl.props(settings, maxSubstreams, f), ()) - case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) case DirectProcessor(p, m) ⇒ throw new AssertionError("DirectProcessor cannot end up in ActorProcessorFactory") } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala index 9391ccc0e8..8e7aa10f32 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala @@ -34,8 +34,6 @@ private[akka] class ActorRefBackpressureSinkStage[In](ref: ActorRef, onInitMessa var acknowledgementReceived = false var completeReceived = false - override def keepGoingAfterAllPortsClosed: Boolean = true - private def receive(evt: (ActorRef, Any)): Unit = { evt._2 match { case `ackMessage` ⇒ @@ -47,6 +45,7 @@ private[akka] class ActorRefBackpressureSinkStage[In](ref: ActorRef, onInitMessa } override def preStart() = { + setKeepGoing(true) self = getStageActorRef(receive) self.watch(ref) ref ! onInitMessage diff --git a/akka-stream/src/main/scala/akka/stream/impl/BoundedBuffer.scala b/akka-stream/src/main/scala/akka/stream/impl/BoundedBuffer.scala new file mode 100644 index 0000000000..ab50babe0b --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/BoundedBuffer.scala @@ -0,0 +1,104 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl + +import java.{ util ⇒ ju } + +/** + * INTERNAL API + */ +private[akka] trait Buffer[T] { + def used: Int + def isFull: Boolean + def isEmpty: Boolean + def nonEmpty: Boolean + + def enqueue(elem: T): Unit + def dequeue(): T + + def peek(): T + def clear(): Unit + def dropHead(): Unit + def dropTail(): Unit +} + +/** + * INTERNAL API + */ +private[akka] final class BoundedBuffer[T](val capacity: Int) extends Buffer[T] { + + def used: Int = q.used + def isFull: Boolean = q.isFull + def isEmpty: Boolean = q.isEmpty + def nonEmpty: Boolean = q.nonEmpty + + def enqueue(elem: T): Unit = q.enqueue(elem) + def dequeue(): T = q.dequeue() + + def peek(): T = q.peek() + def clear(): Unit = q.clear() + def dropHead(): Unit = q.dropHead() + def dropTail(): Unit = q.dropTail() + + private final class FixedQueue extends Buffer[T] { + final val Size = 16 + final val Mask = 15 + + private val queue = new Array[AnyRef](Size) + private var head = 0 + private var tail = 0 + + override def used = tail - head + override def isFull = used == capacity + override def isEmpty = tail == head + override def nonEmpty = tail != head + + override def enqueue(elem: T): Unit = + if (tail - head == Size) { + val queue = new DynamicQueue(head) + while (nonEmpty) { + queue.enqueue(dequeue()) + } + q = queue + queue.enqueue(elem) + } else { + queue(tail & Mask) = elem.asInstanceOf[AnyRef] + tail += 1 + } + override def dequeue(): T = { + val pos = head & Mask + val ret = queue(pos).asInstanceOf[T] + queue(pos) = null + head += 1 + ret + } + + override def peek(): T = + if (tail == head) null.asInstanceOf[T] + else queue(head & Mask).asInstanceOf[T] + override def clear(): Unit = + while (nonEmpty) { + dequeue() + } + override def dropHead(): Unit = dequeue() + override def dropTail(): Unit = { + tail -= 1 + queue(tail & Mask) = null + } + } + + private final class DynamicQueue(startIdx: Int) extends ju.LinkedList[T] with Buffer[T] { + override def used = size + override def isFull = size == capacity + override def nonEmpty = !isEmpty() + + override def enqueue(elem: T): Unit = add(elem) + override def dequeue(): T = remove() + + override def dropHead(): Unit = remove() + override def dropTail(): Unit = removeLast() + } + + private var q: Buffer[T] = new FixedQueue +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala index c46b2db8d0..8ac3dcbe7c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala @@ -25,7 +25,7 @@ private[akka] object FixedSizeBuffer { else if (((size - 1) & size) == 0) new PowerOfTwoFixedSizeBuffer(size) else new ModuloFixedSizeBuffer(size) - sealed abstract class FixedSizeBuffer[T](val size: Int) { + sealed abstract class FixedSizeBuffer[T](val size: Int) extends Buffer[T] { override def toString = s"Buffer($size, $readIdx, $writeIdx)(${(readIdx until writeIdx).map(get).mkString(", ")})" private val buffer = new Array[AnyRef](size) @@ -35,12 +35,11 @@ private[akka] object FixedSizeBuffer { def isFull: Boolean = used == size def isEmpty: Boolean = used == 0 + def nonEmpty: Boolean = used != 0 - def enqueue(elem: T): Int = { + def enqueue(elem: T): Unit = { put(writeIdx, elem) - val ret = writeIdx writeIdx += 1 - ret } protected def toOffset(idx: Int): Int diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index 9cb592d024..2bc7da30ed 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -261,9 +261,9 @@ private[akka] class QueueSink[T]() extends GraphStageWithMaterializedValue[SinkS var currentRequest: Option[Requested[T]] = None val stageLogic = new GraphStageLogic(shape) with RequestElementCallback[Requested[T]] { - override def keepGoingAfterAllPortsClosed = true override def preStart(): Unit = { + setKeepGoing(true) val list = requestElement.getAndSet(callback.invoke _).asInstanceOf[List[Requested[T]]] list.reverse.foreach(callback.invoke) pull(in) diff --git a/akka-stream/src/main/scala/akka/stream/impl/SplitWhereProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/SplitWhereProcessorImpl.scala deleted file mode 100644 index a9cc9bfb56..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/SplitWhereProcessorImpl.scala +++ /dev/null @@ -1,153 +0,0 @@ -/** - * Copyright (C) 2009-2015 Typesafe Inc. - */ -package akka.stream.impl - -import akka.actor.{ Deploy, Props } -import akka.stream.impl.SplitDecision.SplitDecision -import akka.stream.scaladsl.Source -import akka.stream.{ ActorMaterializerSettings, Supervision } - -import scala.util.control.NonFatal - -/** INTERNAL API */ -private[akka] object SplitDecision { - sealed abstract class SplitDecision - - /** Splits before the current element. The current element will be the first element in the new substream. */ - case object SplitBefore extends SplitDecision - - /** Splits after the current element. The current element will be the last element in the current substream. */ - case object SplitAfter extends SplitDecision - - /** Emit this element into the current substream. */ - case object Continue extends SplitDecision - - /** - * Drop this element without signalling it to any substream. - * TODO: Dropping is currently not exposed in an usable way - we would have to expose splitWhen(x => SplitDecision), to be decided if we want this - */ - private[impl] case object Drop extends SplitDecision -} - -/** - * INTERNAL API - */ -private[akka] object SplitWhereProcessorImpl { - def props(settings: ActorMaterializerSettings, splitPredicate: Any ⇒ SplitDecision): Props = - Props(new SplitWhereProcessorImpl(settings, in ⇒ splitPredicate(in))).withDeploy(Deploy.local) -} - -/** - * INTERNAL API - */ -private[akka] class SplitWhereProcessorImpl(_settings: ActorMaterializerSettings, val splitPredicate: Any ⇒ SplitDecision) - extends MultiStreamOutputProcessor(_settings) { - - import MultiStreamOutputProcessor._ - import SplitDecision._ - - /** - * `firstElement` is needed in case a SplitBefore is signalled, and the first element matches - * We do not want to emit an "empty stream" then followed by the "split", but we do want to start the stream - * from the first element (as if no split was applied): [0,1,2,0].splitWhen(_ == 0) => [0,1,2], [0] - */ - var firstElement = true - - val decider = settings.supervisionDecider - var currentSubstream: SubstreamOutput = _ - - val waitFirst = TransferPhase(primaryInputs.NeedsInput && primaryOutputs.NeedsDemand) { () ⇒ - val elem = primaryInputs.dequeueInputElement() - decideSplit(elem) match { - case Continue ⇒ nextPhase(openSubstream(serveSubstreamFirst(_, elem))) - case SplitAfter ⇒ nextPhase(openSubstream(completeSubstream(_, elem))) - case SplitBefore ⇒ nextPhase(openSubstream(serveSubstreamFirst(_, elem))) - case Drop ⇒ // stay in waitFirst - } - } - - private def openSubstream(andThen: SubstreamOutput ⇒ TransferPhase): TransferPhase = TransferPhase(primaryOutputs.NeedsDemand) { () ⇒ - val substreamOutput = createSubstreamOutput() - val substreamFlow = Source.fromPublisher(substreamOutput) - primaryOutputs.enqueueOutputElement(substreamFlow) - currentSubstream = substreamOutput - - nextPhase(andThen(currentSubstream)) - } - - // Serving the substream is split into two phases to minimize elements "held in hand" - private def serveSubstreamFirst(substream: SubstreamOutput, elem: Any) = TransferPhase(substream.NeedsDemand) { () ⇒ - firstElement = false - substream.enqueueOutputElement(elem) - nextPhase(serveSubstreamRest(substream)) - } - - // Signal given element to substream and complete it - private def completeSubstream(substream: SubstreamOutput, elem: Any): TransferPhase = TransferPhase(substream.NeedsDemand) { () ⇒ - substream.enqueueOutputElement(elem) - completeSubstreamOutput(currentSubstream.key) - nextPhase(waitFirst) - } - - // Note that this phase is allocated only once per _slice_ and not per element - private def serveSubstreamRest(substream: SubstreamOutput): TransferPhase = TransferPhase(primaryInputs.NeedsInput && substream.NeedsDemand) { () ⇒ - val elem = primaryInputs.dequeueInputElement() - decideSplit(elem) match { - case Continue ⇒ - substream.enqueueOutputElement(elem) - - case SplitAfter ⇒ - substream.enqueueOutputElement(elem) - completeSubstreamOutput(currentSubstream.key) - currentSubstream = null - nextPhase(openSubstream(serveSubstreamRest)) - - case SplitBefore if firstElement ⇒ - currentSubstream.enqueueOutputElement(elem) - completeSubstreamOutput(currentSubstream.key) - currentSubstream = null - nextPhase(openSubstream(serveSubstreamRest)) - - case SplitBefore ⇒ - completeSubstreamOutput(currentSubstream.key) - currentSubstream = null - nextPhase(openSubstream(serveSubstreamFirst(_, elem))) - - case Drop ⇒ - // drop elem and continue - } - firstElement = false - } - - // Ignore elements for a cancelled substream until a new substream needs to be opened - val ignoreUntilNewSubstream = TransferPhase(primaryInputs.NeedsInput && primaryOutputs.NeedsDemand) { () ⇒ - val elem = primaryInputs.dequeueInputElement() - decideSplit(elem) match { - case Continue | Drop ⇒ // ignore elem - case SplitBefore ⇒ nextPhase(openSubstream(serveSubstreamFirst(_, elem))) - case SplitAfter ⇒ nextPhase(openSubstream(serveSubstreamRest)) - } - } - - private def decideSplit(elem: Any): SplitDecision = - try splitPredicate(elem) catch { - case NonFatal(e) if decider(e) != Supervision.Stop ⇒ - if (settings.debugLogging) - log.debug("Dropped element [{}] due to exception from splitWhen function: {}", elem, e.getMessage) - Drop - } - - initialPhase(1, waitFirst) - - override def completeSubstreamOutput(substream: SubstreamKey): Unit = { - if ((currentSubstream ne null) && substream == currentSubstream.key) nextPhase(ignoreUntilNewSubstream) - super.completeSubstreamOutput(substream) - } - - override def cancelSubstreamOutput(substream: SubstreamKey): Unit = { - if ((currentSubstream ne null) && substream == currentSubstream.key) nextPhase(ignoreUntilNewSubstream) - super.cancelSubstreamOutput(substream) - } - -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index b513a40316..088ed69981 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -8,7 +8,6 @@ import akka.stream.ActorAttributes.SupervisionStrategy import akka.stream.Attributes._ import akka.stream.Supervision.Decider import akka.stream._ -import akka.stream.impl.SplitDecision.{ Continue, SplitAfter, SplitBefore, SplitDecision } import akka.stream.impl.StreamLayout._ import akka.stream.scaladsl.Source import akka.stream.stage.AbstractStage.PushPullGraphStage @@ -214,16 +213,6 @@ private[stream] object Stages { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } - final case class Split(p: Any ⇒ SplitDecision, attributes: Attributes = split) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - - object Split { - def when(f: Any ⇒ Boolean) = Split(el ⇒ if (f(el)) SplitBefore else Continue, name("splitWhen")) - - def after(f: Any ⇒ Boolean) = Split(el ⇒ if (f(el)) SplitAfter else Continue, name("splitAfter")) - } - final case class DirectProcessor(p: () ⇒ (Processor[Any, Any], Any), attributes: Attributes = processor) extends StageModule { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 6267d26e2b..5bb63b7a3e 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -18,6 +18,8 @@ import scala.util.control.NonFatal import akka.event.LoggingAdapter import akka.stream.impl.ActorMaterializerImpl import akka.stream.impl.SubFusingActorMaterializerImpl +import scala.annotation.tailrec +import akka.stream.impl.StreamSupervisor /** * INTERNAL API @@ -310,7 +312,7 @@ private[stream] final class GraphInterpreterShell( logics: Array[GraphStageLogic], shape: Shape, settings: ActorMaterializerSettings, - mat: ActorMaterializerImpl) { + val mat: ActorMaterializerImpl) { import ActorGraphInterpreter._ @@ -326,6 +328,8 @@ private[stream] final class GraphInterpreterShell( private var subscribesPending = inputs.length private var publishersPending = outputs.length + def dumpWaits(): Unit = interpreter.dumpWaits() + /* * Limits the number of events processed by the interpreter before scheduling * a self-message for fairness with other actors. The basic assumption here is @@ -339,7 +343,9 @@ private[stream] final class GraphInterpreterShell( private val abortLimit = eventLimit * 2 private var resumeScheduled = false - def init(self: ActorRef, registerShell: GraphInterpreterShell ⇒ ActorRef): Unit = { + def isInitialized: Boolean = self != null + + def init(self: ActorRef, subMat: SubFusingActorMaterializerImpl): Unit = { this.self = self var i = 0 while (i < inputs.length) { @@ -356,7 +362,7 @@ private[stream] final class GraphInterpreterShell( interpreter.attachDownstreamBoundary(i + offset, out) i += 1 } - interpreter.init(new SubFusingActorMaterializerImpl(mat, registerShell)) + interpreter.init(subMat) runBatch() } @@ -390,14 +396,7 @@ private[stream] final class GraphInterpreterShell( resumeScheduled = false if (interpreter.isSuspended) runBatch() case AsyncInput(_, logic, event, handler) ⇒ - if (GraphInterpreter.Debug) println(s"${interpreter.Name} ASYNC $event ($handler) [$logic]") - if (!interpreter.isStageCompleted(logic)) { - try handler(event) - catch { - case NonFatal(e) ⇒ logic.failStage(e) - } - interpreter.afterStageHasRun(logic) - } + interpreter.runAsyncInput(logic, event, handler) runBatch() // Initialization and completion messages @@ -493,31 +492,75 @@ private[stream] final class GraphInterpreterShell( /** * INTERNAL API */ -private[stream] class ActorGraphInterpreter(_initial: GraphInterpreterShell) extends Actor { +private[stream] class ActorGraphInterpreter(_initial: GraphInterpreterShell) extends Actor with ActorLogging { import ActorGraphInterpreter._ - var activeInterpreters = Set(_initial) + var activeInterpreters = Set.empty[GraphInterpreterShell] + var newShells: List[GraphInterpreterShell] = Nil + val subFusingMaterializerImpl = new SubFusingActorMaterializerImpl(_initial.mat, registerShell) + + def tryInit(shell: GraphInterpreterShell): Boolean = + try { + shell.init(self, subFusingMaterializerImpl) + if (GraphInterpreter.Debug) println(s"registering new shell in ${_initial}\n ${shell.toString.replace("\n", "\n ")}") + if (shell.isTerminated) false + else { + activeInterpreters += shell + true + } + } catch { + case NonFatal(e) ⇒ + log.error(e, "initialization of GraphInterpreterShell failed for {}", shell) + false + } def registerShell(shell: GraphInterpreterShell): ActorRef = { - shell.init(self, registerShell) - if (GraphInterpreter.Debug) println(s"registering new shell in ${_initial}\n ${shell.toString.replace("\n", "\n ")}") - activeInterpreters += shell + newShells ::= shell + self ! Resume self } + /* + * Avoid performing the initialization (which start the first runBatch()) + * within registerShell in order to avoid unbounded recursion. + */ + @tailrec private def finishShellRegistration(): Unit = + newShells match { + case Nil ⇒ if (activeInterpreters.isEmpty) context.stop(self) + case shell :: tail ⇒ + newShells = tail + if (shell.isInitialized) { + // yes, this steals another shell’s Resume, but that’s okay because extra ones will just not do anything + finishShellRegistration() + } else tryInit(shell) + } + override def preStart(): Unit = { - activeInterpreters.foreach(_.init(self, registerShell)) + tryInit(_initial) + if (activeInterpreters.isEmpty) context.stop(self) } override def receive: Receive = { case b: BoundaryEvent ⇒ val shell = b.shell - if (GraphInterpreter.Debug) - if (!activeInterpreters.contains(shell)) println(s"RECEIVED EVENT $b FOR UNKNOWN SHELL $shell") - shell.receive(b) - if (shell.isTerminated) { - activeInterpreters -= shell - if (activeInterpreters.isEmpty) context.stop(self) + if (!shell.isTerminated && (shell.isInitialized || tryInit(shell))) { + shell.receive(b) + if (shell.isTerminated) { + activeInterpreters -= shell + if (activeInterpreters.isEmpty && newShells.isEmpty) context.stop(self) + } + } + case Resume ⇒ finishShellRegistration() + case StreamSupervisor.PrintDebugDump ⇒ + println(s"activeShells:") + activeInterpreters.foreach { shell ⇒ + println(" " + shell.toString.replace("\n", "\n ")) + shell.interpreter.dumpWaits() + } + println(s"newShells:") + newShells.foreach { shell ⇒ + println(" " + shell.toString.replace("\n", "\n ")) + shell.interpreter.dumpWaits() } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala index 5032b13fb1..185f9ccb09 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala @@ -34,7 +34,7 @@ private[stream] object Fusing { * information in the BuildStructuralInfo instance. */ val matValue = - try descend(g.module, Attributes.none, struct, struct.newGroup(""), "") + try descend(g.module, Attributes.none, struct, struct.newGroup(0), 0) catch { case NonFatal(ex) ⇒ if (Debug) struct.dump() @@ -245,7 +245,7 @@ private[stream] object Fusing { inheritedAttributes: Attributes, struct: BuildStructuralInfo, openGroup: ju.Set[Module], - indent: String): List[(Module, MaterializedValueNode)] = { + indent: Int): List[(Module, MaterializedValueNode)] = { def log(msg: String): Unit = println(indent + msg) val async = m match { case _: GraphStageModule ⇒ m.attributes.contains(AsyncBoundary) @@ -262,9 +262,9 @@ private[stream] object Fusing { m match { case gm: GraphModule if !async ⇒ // need to dissolve previously fused GraphStages to allow further fusion - if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + indent)}") + if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + " " * indent)}") val attributes = inheritedAttributes and m.attributes - gm.matValIDs.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut) + gm.matValIDs.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + 1))(collection.breakOut) case gm @ GraphModule(_, oldShape, _, mvids) ⇒ /* * Importing a GraphModule that has an AsyncBoundary attribute is a little more work: @@ -337,7 +337,7 @@ private[stream] object Fusing { m match { case CopiedModule(shape, _, copyOf) ⇒ val ret = - descend(copyOf, attributes, struct, localGroup, indent + " ") match { + descend(copyOf, attributes, struct, localGroup, indent + 1) match { case xs @ (_, mat) :: _ ⇒ (m -> mat) :: xs case _ ⇒ throw new IllegalArgumentException("cannot happen") } @@ -348,8 +348,14 @@ private[stream] object Fusing { // computation context (i.e. that need the same value). struct.enterMatCtx() // now descend into submodules and collect their computations (plus updates to `struct`) - val subMat: Predef.Map[Module, MaterializedValueNode] = - m.subModules.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut) + val subMatBuilder = Predef.Map.newBuilder[Module, MaterializedValueNode] + val subIterator = m.subModules.iterator + while (subIterator.hasNext) { + val sub = subIterator.next() + val res = descend(sub, attributes, struct, localGroup, indent + 1) + subMatBuilder ++= res + } + val subMat = subMatBuilder.result() if (Debug) log(subMat.map(p ⇒ s"${p._1.getClass.getName}[${p._1.hashCode}] -> ${p._2}").mkString("subMat\n " + indent, "\n " + indent, "")) // we need to remove all wirings that this module copied from nested modules so that we // don’t do wirings twice @@ -553,8 +559,8 @@ private[stream] object Fusing { * connections within imported (and not dissolved) GraphModules. * See also the comment in addModule where this is partially undone. */ - def registerInteral(s: Shape, indent: String): Unit = { - if (Debug) println(indent + s"registerInternals(${s.outlets.map(hash)})") + def registerInteral(s: Shape, indent: Int): Unit = { + if (Debug) println(" " * indent + s"registerInternals(${s.outlets.map(hash)})") internalOuts.addAll(s.outlets.asJava) } @@ -585,9 +591,9 @@ private[stream] object Fusing { /** * Create and return a new grouping (i.e. an AsyncBoundary-delimited context) */ - def newGroup(indent: String): ju.Set[Module] = { + def newGroup(indent: Int): ju.Set[Module] = { val group = new ju.HashSet[Module] - if (Debug) println(indent + s"creating new group ${hash(group)}") + if (Debug) println(" " * indent + s"creating new group ${hash(group)}") groups.add(group) group } @@ -595,13 +601,13 @@ private[stream] object Fusing { /** * Add a module to the given group, performing normalization (i.e. giving it a unique port identity). */ - def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String, + def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: Int, _oldShape: Shape = null): Atomic = { val copy = if (_oldShape == null) CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m)) else m val oldShape = if (_oldShape == null) m.shape else _oldShape - if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}") + if (Debug) println(" " * indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}") group.add(copy) modules.add(copy) copy.shape.outlets.foreach(o ⇒ outGroup.put(o, group)) @@ -648,8 +654,8 @@ private[stream] object Fusing { * Record a wiring between two copied ports, using (and reducing) the port * mappings. */ - def wire(out: OutPort, in: InPort, indent: String): Unit = { - if (Debug) println(indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})") + def wire(out: OutPort, in: InPort, indent: Int): Unit = { + if (Debug) println(" " * indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})") val newOut = removeMapping(out, newOuts) nonNull s"$out (${hash(out)})" val newIn = removeMapping(in, newIns) nonNull s"$in (${hash(in)})" downstreams.put(newOut, newIn) @@ -659,8 +665,8 @@ private[stream] object Fusing { /** * Replace all mappings for a given shape with its new (copied) form. */ - def rewire(oldShape: Shape, newShape: Shape, indent: String): Unit = { - if (Debug) println(indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}") + def rewire(oldShape: Shape, newShape: Shape, indent: Int): Unit = { + if (Debug) println(" " * indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}") oldShape.inlets.iterator.zip(newShape.inlets.iterator).foreach { case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull s"$oldIn (${hash(oldIn)})", newIns) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index 3c653a2341..45215b6202 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -21,7 +21,7 @@ import akka.stream.impl.fusing.GraphStages.MaterializedValueSource * * (See the class for the documentation of the internals) */ -private[stream] object GraphInterpreter { +private[akka] object GraphInterpreter { /** * Compile time constant, enable it for debug logging to the console. */ @@ -44,6 +44,9 @@ private[stream] object GraphInterpreter { final val PushStartFlip = 12 //1100 final val PushEndFlip = 5 //0101 + final val KeepGoingFlag = 0x4000000 + final val KeepGoingMask = 0x3ffffff + /** * Marker object that indicates that a port holds no element since it was already grabbed. The port is still pullable, * but there is no more element to grab. @@ -250,14 +253,14 @@ private[stream] object GraphInterpreter { /** * INTERNAL API */ - private[stream] def currentInterpreter: GraphInterpreter = + private[akka] def currentInterpreter: GraphInterpreter = _currentInterpreter.get()(0).asInstanceOf[GraphInterpreter].nonNull // nonNull is just a debug helper to find nulls more timely /** * INTERNAL API */ - private[stream] def currentInterpreterOrNull: GraphInterpreter = + private[akka] def currentInterpreterOrNull: GraphInterpreter = _currentInterpreter.get()(0).asInstanceOf[GraphInterpreter] } @@ -368,8 +371,7 @@ private[stream] final class GraphInterpreter( // Counts how many active connections a stage has. Once it reaches zero, the stage is automatically stopped. private[this] val shutdownCounter = Array.tabulate(assembly.stages.length) { i ⇒ val shape = assembly.stages(i).shape - val keepGoing = if (logics(i).keepGoingAfterAllPortsClosed) 1 else 0 - shape.inlets.size + shape.outlets.size + keepGoing + shape.inlets.size + shape.outlets.size } private var _subFusingMaterializer: Materializer = _ @@ -512,12 +514,15 @@ private[stream] final class GraphInterpreter( case owner ⇒ logics(owner).toString } + private def shutdownCounters: String = + shutdownCounter.map(x ⇒ if (x >= KeepGoingFlag) s"${x & KeepGoingMask}(KeepGoing)" else x.toString).mkString(",") + /** * Executes pending events until the given limit is met. If there were remaining events, isSuspended will return * true. */ def execute(eventLimit: Int): Unit = { - if (Debug) println(s"$Name ---------------- EXECUTE (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") + if (Debug) println(s"$Name ---------------- EXECUTE $queueStatus (running=$runningStages, shutdown=$shutdownCounters)") val currentInterpreterHolder = _currentInterpreter.get() val previousInterpreter = currentInterpreterHolder(0) currentInterpreterHolder(0) = this @@ -537,10 +542,26 @@ private[stream] final class GraphInterpreter( } finally { currentInterpreterHolder(0) = previousInterpreter } - if (Debug) println(s"$Name ---------------- $queueStatus (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") + if (Debug) println(s"$Name ---------------- $queueStatus (running=$runningStages, shutdown=$shutdownCounters)") // TODO: deadlock detection } + def runAsyncInput(logic: GraphStageLogic, evt: Any, handler: (Any) ⇒ Unit): Unit = + if (!isStageCompleted(logic)) { + if (GraphInterpreter.Debug) println(s"$Name ASYNC $evt ($handler) [$logic]") + val currentInterpreterHolder = _currentInterpreter.get() + val previousInterpreter = currentInterpreterHolder(0) + currentInterpreterHolder(0) = this + try { + activeStage = logic + try handler(evt) + catch { + case NonFatal(ex) ⇒ logic.failStage(ex) + } + afterStageHasRun(logic) + } finally currentInterpreterHolder(0) = previousInterpreter + } + // Decodes and processes a single event for the given connection private def processEvent(connection: Int): Unit = { def safeLogics(id: Int) = @@ -638,11 +659,9 @@ private[stream] final class GraphInterpreter( } } - // Call only for keep-alive stages - def closeKeptAliveStageIfNeeded(stageId: Int): Unit = - if (stageId != Boundary && shutdownCounter(stageId) == 1) { - shutdownCounter(stageId) = 0 - } + private[stream] def setKeepGoing(logic: GraphStageLogic, enabled: Boolean): Unit = + if (enabled) shutdownCounter(logic.stageId) |= KeepGoingFlag + else shutdownCounter(logic.stageId) &= KeepGoingMask private def finalizeStage(logic: GraphStageLogic): Unit = { try { @@ -675,7 +694,7 @@ private[stream] final class GraphInterpreter( val currentState = portStates(connection) if (Debug) println(s"$Name complete($connection) [$currentState]") portStates(connection) = currentState | OutClosed - if ((currentState & (InClosed | Pushing | Pulling)) == 0) enqueue(connection) + if ((currentState & (InClosed | Pushing | Pulling | OutClosed)) == 0) enqueue(connection) if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } @@ -699,9 +718,50 @@ private[stream] final class GraphInterpreter( portStates(connection) = currentState | InClosed if ((currentState & OutClosed) == 0) { connectionSlots(connection) = Empty - if ((currentState & (Pulling | Pushing)) == 0) enqueue(connection) + if ((currentState & (Pulling | Pushing | InClosed)) == 0) enqueue(connection) } if ((currentState & InClosed) == 0) completeConnection(assembly.inOwners(connection)) } + /** + * Debug utility to dump the "waits-on" relationships in DOT format to the console for analysis of deadlocks. + * + * Only invoke this after the interpreter completely settled, otherwise the results might be off. This is a very + * simplistic tool, make sure you are understanding what you are doing and then it will serve you well. + */ + def dumpWaits(): Unit = { + println("digraph waits {") + + for (i ← assembly.stages.indices) { + println(s"""N$i [label="${assembly.stages(i)}"]""") + } + + def nameIn(port: Int): String = { + val owner = assembly.inOwners(port) + if (owner == Boundary) "Out" + port + else "N" + owner + } + + def nameOut(port: Int): String = { + val owner = assembly.outOwners(port) + if (owner == Boundary) "In" + port + else "N" + owner + } + + for (i ← portStates.indices) { + portStates(i) match { + case InReady ⇒ + println(s""" ${nameIn(i)} -> ${nameOut(i)} [label="shouldPull"; color=blue]; """) + case OutReady ⇒ + println(s""" ${nameOut(i)} -> ${nameIn(i)} [label="shouldPush"; color=red]; """) + case x if (x | InClosed | OutClosed) == (InClosed | OutClosed) ⇒ + println(s""" ${nameIn(i)} -> ${nameOut(i)} [style=dotted; label="closed" dir=both]; """) + case _ ⇒ + } + + } + + println("}") + println(s"// $queueStatus (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala index 0143445314..6d23909c6a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -23,7 +23,7 @@ private[akka] object IteratorInterpreter { setHandler(out, new OutHandler { override def onPull(): Unit = { - if (!hasNext) completeStage() + if (!hasNext) complete(out) else { val elem = input.next() hasNext = input.hasNext @@ -34,7 +34,7 @@ private[akka] object IteratorInterpreter { } } - override def onDownstreamFinish(): Unit = completeStage() + override def onDownstreamFinish(): Unit = () }) override def toString = "IteratorUpstream" @@ -57,13 +57,11 @@ private[akka] object IteratorInterpreter { override def onUpstreamFinish(): Unit = { done = true - completeStage() } override def onUpstreamFailure(cause: Throwable): Unit = { done = true lastFailure = cause - completeStage() } }) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index a52a59ef6d..e9d79b8def 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -8,7 +8,7 @@ import akka.event.{ LogSource, Logging, LoggingAdapter } import akka.stream.Attributes.{ InputBuffer, LogLevels } import akka.stream.DelayOverflowStrategy.EmitEarly import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage -import akka.stream.impl.{ FixedSizeBuffer, ReactiveStreamsCompliance } +import akka.stream.impl.{ FixedSizeBuffer, BoundedBuffer, ReactiveStreamsCompliance } import akka.stream.stage._ import akka.stream.{ Supervision, _ } import scala.annotation.tailrec @@ -523,6 +523,7 @@ private[akka] final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapol * INTERNAL API */ private[akka] object MapAsync { + final class Holder[T](var elem: T) val NotYetThere = Failure(new Exception) } @@ -547,38 +548,38 @@ private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Fut inheritedAttributes.getAttribute(classOf[SupervisionStrategy]) .map(_.decider).getOrElse(Supervision.stoppingDecider) - val buffer = FixedSizeBuffer[Try[Out]](parallelism) + val buffer = new BoundedBuffer[Holder[Try[Out]]](parallelism) def todo = buffer.used @tailrec private def pushOne(): Unit = if (buffer.isEmpty) { if (isClosed(in)) completeStage() else if (!hasBeenPulled(in)) pull(in) - } else if (buffer.peek == NotYetThere) { + } else if (buffer.peek.elem == NotYetThere) { if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) - } else buffer.dequeue() match { + } else buffer.dequeue().elem match { case Failure(ex) ⇒ pushOne() case Success(elem) ⇒ push(out, elem) if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) } - def failOrPull(idx: Int, f: Failure[Out]) = + def failOrPull(holder: Holder[Try[Out]], f: Failure[Out]) = if (decider(f.exception) == Supervision.Stop) failStage(f.exception) else { - buffer.put(idx, f) + holder.elem = f if (isAvailable(out)) pushOne() } val futureCB = - getAsyncCallback[(Int, Try[Out])]({ - case (idx, f: Failure[_]) ⇒ failOrPull(idx, f) - case (idx, s @ Success(elem)) ⇒ + getAsyncCallback[(Holder[Try[Out]], Try[Out])]({ + case (holder, f: Failure[_]) ⇒ failOrPull(holder, f) + case (holder, s @ Success(elem)) ⇒ if (elem == null) { val ex = ReactiveStreamsCompliance.elementMustNotBeNullException - failOrPull(idx, Failure(ex)) + failOrPull(holder, Failure(ex)) } else { - buffer.put(idx, s) + holder.elem = s if (isAvailable(out)) pushOne() } }) @@ -587,8 +588,9 @@ private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Fut override def onPush(): Unit = { try { val future = f(grab(in)) - val idx = buffer.enqueue(NotYetThere) - future.onComplete(result ⇒ futureCB.invoke(idx -> result))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) + val holder = new Holder[Try[Out]](NotYetThere) + buffer.enqueue(holder) + future.onComplete(result ⇒ futureCB.invoke(holder -> result))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) } catch { case NonFatal(ex) ⇒ if (decider(ex) == Supervision.Stop) failStage(ex) @@ -626,7 +628,7 @@ private[akka] final case class MapAsyncUnordered[In, Out](parallelism: Int, f: I .map(_.decider).getOrElse(Supervision.stoppingDecider) var inFlight = 0 - val buffer = FixedSizeBuffer[Out](parallelism) + val buffer = new BoundedBuffer[Out](parallelism) def todo = inFlight + buffer.used def failOrPull(ex: Throwable) = diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala index e5783b899e..923f3b1a31 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala @@ -4,7 +4,6 @@ package akka.stream.impl.fusing import java.util.concurrent.atomic.AtomicReference - import akka.stream._ import akka.stream.impl.SubscriptionTimeoutException import akka.stream.stage._ @@ -17,6 +16,12 @@ import java.{ util ⇒ ju } import scala.collection.immutable import scala.concurrent._ import scala.concurrent.duration.FiniteDuration +import scala.util.control.NonFatal +import akka.stream.impl.MultiStreamOutputProcessor.SubstreamSubscriptionTimeout +import scala.annotation.tailrec +import akka.stream.impl.PublisherSource +import akka.stream.impl.CancellingSubscriber +import akka.stream.impl.BoundedBuffer /** * INTERNAL API @@ -30,58 +35,15 @@ final class FlattenMerge[T, M](breadth: Int) extends GraphStage[FlowShape[Graph[ override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { - import StreamOfStreams.{ LocalSink, LocalSource } - - var sources = Set.empty[LocalSource[T]] + var sources = Set.empty[SubSinkInlet[T]] def activeSources = sources.size - private sealed trait Queue { - def hasData: Boolean - def enqueue(src: LocalSource[T]): Unit - def dequeue(): LocalSource[T] - } - - private final class FixedQueue extends Queue { - final val Size = 16 - final val Mask = 15 - - private val queue = new Array[LocalSource[T]](Size) - private var head = 0 - private var tail = 0 - - def hasData = tail != head - def enqueue(src: LocalSource[T]): Unit = - if (tail - head == Size) { - val queue = new DynamicQueue - while (hasData) { - queue.add(dequeue()) - } - queue.add(src) - q = queue - } else { - queue(tail & Mask) = src - tail += 1 - } - def dequeue(): LocalSource[T] = { - val ret = queue(head & Mask) - head += 1 - ret - } - } - - private final class DynamicQueue extends ju.LinkedList[LocalSource[T]] with Queue { - def hasData = !isEmpty() - def enqueue(src: LocalSource[T]): Unit = add(src) - def dequeue(): LocalSource[T] = remove() - } - - private var q: Queue = new FixedQueue + val q = new BoundedBuffer[SubSinkInlet[T]](breadth) def pushOut(): Unit = { val src = q.dequeue() - push(out, src.elem) - src.elem = null.asInstanceOf[T] - if (src.isActive) src.pull() + push(out, src.grab()) + if (!src.isClosed) src.pull() else removeSource(src) } @@ -103,33 +65,30 @@ final class FlattenMerge[T, M](breadth: Int) extends GraphStage[FlowShape[Graph[ val outHandler = new OutHandler { // could be unavailable due to async input having been executed before this notification - override def onPull(): Unit = if (q.hasData && isAvailable(out)) pushOut() + override def onPull(): Unit = if (q.nonEmpty && isAvailable(out)) pushOut() } def addSource(source: Graph[SourceShape[T], M]): Unit = { - val localSource = new LocalSource[T]() - sources += localSource - val subF = Source.fromGraph(source) - .runWith(new LocalSink(getAsyncCallback[ActorSubscriberMessage] { - case OnNext(elem) ⇒ - val elemT = elem.asInstanceOf[T] - if (isAvailable(out)) { - push(out, elemT) - localSource.pull() - } else { - localSource.elem = elemT - q.enqueue(localSource) - } - case OnComplete ⇒ - localSource.deactivate() - if (localSource.elem == null) removeSource(localSource) - case OnError(ex) ⇒ - failStage(ex) - }.invoke))(interpreter.subFusingMaterializer) - localSource.activate(subF) + val sinkIn = new SubSinkInlet[T]("FlattenMergeSink") + sinkIn.setHandler(new InHandler { + override def onPush(): Unit = { + if (isAvailable(out)) { + push(out, sinkIn.grab()) + sinkIn.pull() + } else { + q.enqueue(sinkIn) + } + } + override def onUpstreamFinish(): Unit = { + if (!sinkIn.isAvailable) removeSource(sinkIn) + } + }) + sinkIn.pull() + sources += sinkIn + Source.fromGraph(source).runWith(sinkIn.sink)(interpreter.subFusingMaterializer) } - def removeSource(src: LocalSource[T]): Unit = { + def removeSource(src: SubSinkInlet[T]): Unit = { val pullSuppressed = activeSources == breadth sources -= src if (pullSuppressed) tryPull(in) @@ -144,158 +103,6 @@ final class FlattenMerge[T, M](breadth: Int) extends GraphStage[FlowShape[Graph[ override def toString: String = s"FlattenMerge($breadth)" } -/** - * INTERNAL API - */ -private[fusing] object StreamOfStreams { - import akka.dispatch.ExecutionContexts.sameThreadExecutionContext - private val RequestOne = Request(1) // No need to frivolously allocate these - private type LocalSinkSubscription = ActorPublisherMessage ⇒ Unit - /** - * INTERNAL API - */ - private[fusing] final class LocalSource[T] { - private var subF: Future[LocalSinkSubscription] = _ - private var sub: LocalSinkSubscription = _ - - var elem: T = null.asInstanceOf[T] - - def isActive: Boolean = sub ne null - - def deactivate(): Unit = { - sub = null - subF = null - } - - def activate(f: Future[LocalSinkSubscription]): Unit = { - subF = f - /* - * The subscription is communicated to the FlattenMerge stage by way of completing - * the future. Encoding it like this means that the `sub` field will be written - * either by us (if the future has already been completed) or by the LocalSink (when - * it eventually completes the future in its `preStart`). The important part is that - * either way the `sub` field is populated before we get the first `OnNext` message - * and the value is safely published in either case as well (since AsyncCallback is - * based on an Actor message send). - */ - f.foreach(s ⇒ sub = s)(sameThreadExecutionContext) - } - - def pull(): Unit = { - if (sub ne null) sub(RequestOne) - else if (subF eq null) throw new IllegalStateException("not yet initialized, subscription future not set") - else throw new IllegalStateException("not yet initialized, subscription future has " + subF.value) - } - - def cancel(): Unit = - if (subF ne null) - subF.foreach(_(Cancel))(sameThreadExecutionContext) - } - - /** - * INTERNAL API - */ - private[fusing] final class LocalSink[T](notifier: ActorSubscriberMessage ⇒ Unit) - extends GraphStageWithMaterializedValue[SinkShape[T], Future[LocalSinkSubscription]] { - - private val in = Inlet[T]("LocalSink.in") - - override def initialAttributes = Attributes.name("LocalSink") - override val shape = SinkShape(in) - - override def createLogicAndMaterializedValue(attr: Attributes): (GraphStageLogic, Future[LocalSinkSubscription]) = { - val sub = Promise[LocalSinkSubscription] - val logic = new GraphStageLogic(shape) { - setHandler(in, new InHandler { - override def onPush(): Unit = notifier(OnNext(grab(in))) - - override def onUpstreamFinish(): Unit = notifier(OnComplete) - - override def onUpstreamFailure(ex: Throwable): Unit = notifier(OnError(ex)) - }) - - override def preStart(): Unit = { - pull(in) - sub.success( - getAsyncCallback[ActorPublisherMessage] { - case RequestOne ⇒ tryPull(in) - case Cancel ⇒ completeStage() - case _ ⇒ throw new IllegalStateException("Bug") - }.invoke) - } - } - logic -> sub.future - } - } -} - -/** - * INTERNAL API - */ -object PrefixAndTail { - - sealed trait MaterializationState - case object NotMaterialized extends MaterializationState - case object AlreadyMaterialized extends MaterializationState - case object TimedOut extends MaterializationState - - case object NormalCompletion extends MaterializationState - case class FailureCompletion(ex: Throwable) extends MaterializationState - - trait TailInterface[T] { - def pushSubstream(elem: T): Unit - def completeSubstream(): Unit - def failSubstream(ex: Throwable) - } - - final class TailSource[T]( - timeout: FiniteDuration, - register: TailInterface[T] ⇒ Unit, - pullParent: Unit ⇒ Unit, - cancelParent: Unit ⇒ Unit) extends GraphStage[SourceShape[T]] { - val out: Outlet[T] = Outlet("Tail.out") - val materializationState = new AtomicReference[MaterializationState](NotMaterialized) - override val shape: SourceShape[T] = SourceShape(out) - - private final class TailSourceLogic(_shape: Shape) extends GraphStageLogic(_shape) with OutHandler with TailInterface[T] { - setHandler(out, this) - - override def preStart(): Unit = { - materializationState.getAndSet(AlreadyMaterialized) match { - case AlreadyMaterialized ⇒ - failStage(new IllegalStateException("Tail Source cannot be materialized more than once.")) - case TimedOut ⇒ - // Already detached from parent - failStage(new SubscriptionTimeoutException(s"Tail Source has not been materialized in $timeout.")) - case NormalCompletion ⇒ - // Already detached from parent - completeStage() - case FailureCompletion(ex) ⇒ - // Already detached from parent - failStage(ex) - case NotMaterialized ⇒ - register(this) - } - - } - - private val onParentPush = getAsyncCallback[T](push(out, _)) - private val onParentFinish = getAsyncCallback[Unit](_ ⇒ completeStage()) - private val onParentFailure = getAsyncCallback[Throwable](failStage) - - override def pushSubstream(elem: T): Unit = onParentPush.invoke(elem) - override def completeSubstream(): Unit = onParentFinish.invoke(()) - override def failSubstream(ex: Throwable): Unit = onParentFailure.invoke(ex) - - override def onPull(): Unit = pullParent(()) - override def onDownstreamFinish(): Unit = cancelParent(()) - } - - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TailSourceLogic(shape) - } - -} - /** * INTERNAL API */ @@ -307,53 +114,47 @@ final class PrefixAndTail[T](n: Int) extends GraphStage[FlowShape[T, (immutable. override def initialAttributes = Attributes.name("PrefixAndTail") private final class PrefixAndTailLogic(_shape: Shape) extends TimerGraphStageLogic(_shape) with OutHandler with InHandler { - import PrefixAndTail._ private var left = if (n < 0) 0 else n private var builder = Vector.newBuilder[T] - private var tailSource: TailSource[T] = null - private var tail: TailInterface[T] = null builder.sizeHint(left) - private var pendingCompletion: MaterializationState = null + + private var tailSource: SubSourceOutlet[T] = null private val SubscriptionTimer = "SubstreamSubscriptionTimer" - private val onSubstreamPull = getAsyncCallback[Unit](_ ⇒ pull(in)) - private val onSubstreamFinish = getAsyncCallback[Unit](_ ⇒ completeStage()) - private val onSubstreamRegister = getAsyncCallback[TailInterface[T]] { tailIf ⇒ - tail = tailIf - cancelTimer(SubscriptionTimer) - pendingCompletion match { - case NormalCompletion ⇒ - tail.completeSubstream() - completeStage() - case FailureCompletion(ex) ⇒ - tail.failSubstream(ex) - completeStage() - case _ ⇒ + override protected def onTimer(timerKey: Any): Unit = { + val timeout = ActorMaterializer.downcast(interpreter.materializer).settings.subscriptionTimeoutSettings.timeout + tailSource.timeout(timeout) + if (tailSource.isClosed) completeStage() + } + + private def prefixComplete = builder eq null + + private def subHandler = new OutHandler { + override def onPull(): Unit = { + setKeepGoing(false) + cancelTimer(SubscriptionTimer) + pull(in) + tailSource.setHandler(new OutHandler { + override def onPull(): Unit = pull(in) + }) } } - override protected def onTimer(timerKey: Any): Unit = - if (tailSource.materializationState.compareAndSet(NotMaterialized, TimedOut)) completeStage() - - private def prefixComplete = builder eq null - private def waitingSubstreamRegistration = tail eq null - private def openSubstream(): Source[T, Unit] = { val timeout = ActorMaterializer.downcast(interpreter.materializer).settings.subscriptionTimeoutSettings.timeout - tailSource = new TailSource[T](timeout, onSubstreamRegister.invoke, onSubstreamPull.invoke, onSubstreamFinish.invoke) + tailSource = new SubSourceOutlet[T]("TailSource") + tailSource.setHandler(subHandler) + setKeepGoing(true) scheduleOnce(SubscriptionTimer, timeout) builder = null - Source.fromGraph(tailSource) + Source.fromGraph(tailSource.source) } - // Needs to keep alive if upstream completes but substream has been not yet materialized - override def keepGoingAfterAllPortsClosed: Boolean = true - override def onPush(): Unit = { if (prefixComplete) { - tail.pushSubstream(grab(in)) + tailSource.push(grab(in)) } else { builder += grab(in) left -= 1 @@ -375,33 +176,15 @@ final class PrefixAndTail[T](n: Int) extends GraphStage[FlowShape[T, (immutable. // This handles the unpulled out case as well emit(out, (builder.result, Source.empty), () ⇒ completeStage()) } else { - if (waitingSubstreamRegistration) { - // Detach if possible. - // This allows this stage to complete without waiting for the substream to be materialized, since that - // is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be - // able to signal completion normally soon. - if (tailSource.materializationState.compareAndSet(NotMaterialized, NormalCompletion)) completeStage() - else pendingCompletion = NormalCompletion - } else { - tail.completeSubstream() - completeStage() - } + if (!tailSource.isClosed) tailSource.complete() + completeStage() } } override def onUpstreamFailure(ex: Throwable): Unit = { if (prefixComplete) { - if (waitingSubstreamRegistration) { - // Detach if possible. - // This allows this stage to complete without waiting for the substream to be materialized, since that - // is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be - // able to signal completion normally soon. - if (tailSource.materializationState.compareAndSet(NotMaterialized, FailureCompletion(ex))) failStage(ex) - else pendingCompletion = FailureCompletion(ex) - } else { - tail.failSubstream(ex) - completeStage() - } + if (!tailSource.isClosed) tailSource.fail(ex) + completeStage() } else failStage(ex) } @@ -418,3 +201,327 @@ final class PrefixAndTail[T](n: Int) extends GraphStage[FlowShape[T, (immutable. override def toString: String = s"PrefixAndTail($n)" } + +/** + * INERNAL API + */ +object Split { + sealed abstract class SplitDecision + + /** Splits before the current element. The current element will be the first element in the new substream. */ + case object SplitBefore extends SplitDecision + + /** Splits after the current element. The current element will be the last element in the current substream. */ + case object SplitAfter extends SplitDecision + + def when[T](p: T ⇒ Boolean): Graph[FlowShape[T, Source[T, Unit]], Unit] = new Split(Split.SplitBefore, p) + def after[T](p: T ⇒ Boolean): Graph[FlowShape[T, Source[T, Unit]], Unit] = new Split(Split.SplitAfter, p) +} + +/** + * INERNAL API + */ +final class Split[T](decision: Split.SplitDecision, p: T ⇒ Boolean) extends GraphStage[FlowShape[T, Source[T, Unit]]] { + val in: Inlet[T] = Inlet("Split.in") + val out: Outlet[Source[T, Unit]] = Outlet("Split.out") + override val shape: FlowShape[T, Source[T, Unit]] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { + import Split._ + + private val SubscriptionTimer = "SubstreamSubscriptionTimer" + + private var timeout: FiniteDuration = _ + private var substreamSource: SubSourceOutlet[T] = null + private var substreamPushed = false + private var substreamCancelled = false + + override def preStart(): Unit = { + timeout = ActorMaterializer.downcast(interpreter.materializer).settings.subscriptionTimeoutSettings.timeout + } + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (substreamSource eq null) pull(in) + else if (!substreamPushed) { + push(out, Source.fromGraph(substreamSource.source)) + scheduleOnce(SubscriptionTimer, timeout) + substreamPushed = true + } + } + + override def onDownstreamFinish(): Unit = { + // If the substream is already cancelled or it has not been handed out, we can go away + if (!substreamPushed || substreamCancelled) completeStage() + } + }) + + // initial input handler + setHandler(in, new InHandler { + override def onPush(): Unit = { + val handler = new SubstreamHandler + val elem = grab(in) + + decision match { + case SplitAfter if p(elem) ⇒ + push(out, Source.single(elem)) + // Next pull will come from the next substream that we will open + case _ ⇒ + handler.firstElem = elem + } + + handOver(handler) + } + override def onUpstreamFinish(): Unit = completeStage() + }) + + private def handOver(handler: SubstreamHandler): Unit = { + if (isClosed(out)) completeStage() + else { + substreamSource = new SubSourceOutlet[T]("SplitSource") + substreamSource.setHandler(handler) + substreamCancelled = false + setHandler(in, handler) + setKeepGoing(enabled = handler.hasInitialElement) + + if (isAvailable(out)) { + push(out, Source.fromGraph(substreamSource.source)) + scheduleOnce(SubscriptionTimer, timeout) + substreamPushed = true + } else substreamPushed = false + } + } + + override protected def onTimer(timerKey: Any): Unit = substreamSource.timeout(timeout) + + private class SubstreamHandler extends InHandler with OutHandler { + + var firstElem: T = null.asInstanceOf[T] + + def hasInitialElement: Boolean = firstElem.asInstanceOf[AnyRef] ne null + private var willCompleteAfterInitialElement = false + + // Substreams are always assumed to be pushable position when we enter this method + private def closeThis(handler: SubstreamHandler, currentElem: T): Unit = { + decision match { + case SplitAfter ⇒ + if (!substreamCancelled) { + substreamSource.push(currentElem) + substreamSource.complete() + } + case SplitBefore ⇒ + handler.firstElem = currentElem + if (!substreamCancelled) substreamSource.complete() + } + } + + override def onPull(): Unit = { + if (hasInitialElement) { + substreamSource.push(firstElem) + firstElem = null.asInstanceOf[T] + setKeepGoing(false) + if (willCompleteAfterInitialElement) { + substreamSource.complete() + completeStage() + } + } else pull(in) + } + + override def onDownstreamFinish(): Unit = { + substreamCancelled = true + if (isClosed(in)) completeStage() + else { + // Start draining + if (!hasBeenPulled(in)) pull(in) + } + } + + override def onPush(): Unit = { + val elem = grab(in) + try { + if (p(elem)) { + val handler = new SubstreamHandler + closeThis(handler, elem) + handOver(handler) + } else { + // Drain into the void + if (substreamCancelled) pull(in) + else substreamSource.push(elem) + } + } catch { + case NonFatal(ex) ⇒ onUpstreamFailure(ex) + } + } + + override def onUpstreamFinish(): Unit = + if (hasInitialElement) willCompleteAfterInitialElement = true + else { + substreamSource.complete() + completeStage() + } + + override def onUpstreamFailure(ex: Throwable): Unit = { + substreamSource.fail(ex) + failStage(ex) + } + + } + } +} + +/** + * INTERNAL API + */ +object SubSink { + val RequestOne = Request(1) // No need to frivolously allocate these +} + +/** + * INTERNAL API + */ +final class SubSink[T](name: String, externalCallback: ActorSubscriberMessage ⇒ Unit) + extends GraphStage[SinkShape[T]] { + import SubSink._ + + private val in = Inlet[T]("SubSink.in") + + override def initialAttributes = Attributes.name(s"SubSink($name)") + override val shape = SinkShape(in) + + val status = new AtomicReference[AnyRef] + + def pullSubstream(): Unit = status.get match { + case f: AsyncCallback[Any] @unchecked ⇒ f.invoke(RequestOne) + case null ⇒ + if (!status.compareAndSet(null, RequestOne)) + status.get.asInstanceOf[ActorPublisherMessage ⇒ Unit](RequestOne) + } + + def cancelSubstream(): Unit = status.get match { + case f: AsyncCallback[Any] @unchecked ⇒ f.invoke(Cancel) + case x ⇒ // a potential RequestOne is overwritten + if (!status.compareAndSet(x, Cancel)) + status.get.asInstanceOf[ActorPublisherMessage ⇒ Unit](Cancel) + } + + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler { + setHandler(in, this) + + override def onPush(): Unit = externalCallback(OnNext(grab(in))) + override def onUpstreamFinish(): Unit = externalCallback(OnComplete) + override def onUpstreamFailure(ex: Throwable): Unit = externalCallback(OnError(ex)) + + @tailrec private def setCB(cb: AsyncCallback[ActorPublisherMessage]): Unit = { + status.get match { + case null ⇒ + if (!status.compareAndSet(null, cb)) setCB(cb) + case RequestOne ⇒ + pull(in) + if (!status.compareAndSet(RequestOne, cb)) setCB(cb) + case Cancel ⇒ + completeStage() + if (!status.compareAndSet(Cancel, cb)) setCB(cb) + case _: AsyncCallback[_] ⇒ + failStage(new IllegalStateException("Substream Source cannot be materialized more than once")) + } + } + + override def preStart(): Unit = { + val ourOwnCallback = getAsyncCallback[ActorPublisherMessage] { + case RequestOne ⇒ tryPull(in) + case Cancel ⇒ completeStage() + case _ ⇒ throw new IllegalStateException("Bug") + } + setCB(ourOwnCallback) + } + } + + override def toString: String = name +} + +object SubSource { + /** + * INTERNAL API + * + * HERE ACTUALLY ARE DRAGONS, YOU HAVE BEEN WARNED! + * + * FIXME #19240 + */ + private[akka] def kill[T, M](s: Source[T, M]): Unit = { + s.module match { + case GraphStageModule(_, _, stage: SubSource[_]) ⇒ + stage.externalCallback.invoke(Cancel) + case pub: PublisherSource[_] ⇒ + pub.create(null)._1.subscribe(new CancellingSubscriber) + case m ⇒ + GraphInterpreter.currentInterpreterOrNull match { + case null ⇒ throw new UnsupportedOperationException(s"cannot drop Source of type ${m.getClass.getName}") + case intp ⇒ s.runWith(Sink.ignore)(intp.subFusingMaterializer) + } + } + } +} + +/** + * INTERNAL API + */ +final class SubSource[T](name: String, private[fusing] val externalCallback: AsyncCallback[ActorPublisherMessage]) + extends GraphStage[SourceShape[T]] { + import SubSink._ + + val out: Outlet[T] = Outlet("SubSource.out") + override def initialAttributes = Attributes.name(s"SubSource($name)") + override val shape: SourceShape[T] = SourceShape(out) + + val status = new AtomicReference[AnyRef] + + def pushSubstream(elem: T): Unit = status.get match { + case f: AsyncCallback[Any] @unchecked ⇒ f.invoke(OnNext(elem)) + case _ ⇒ throw new IllegalStateException("cannot push to uninitialized substream") + } + + def completeSubstream(): Unit = status.get match { + case f: AsyncCallback[Any] @unchecked ⇒ f.invoke(OnComplete) + case null ⇒ + if (!status.compareAndSet(null, OnComplete)) + status.get.asInstanceOf[AsyncCallback[Any]].invoke(OnComplete) + } + + def failSubstream(ex: Throwable): Unit = status.get match { + case f: AsyncCallback[Any] @unchecked ⇒ f.invoke(OnError(ex)) + case null ⇒ + val failure = OnError(ex) + if (!status.compareAndSet(null, failure)) + status.get.asInstanceOf[AsyncCallback[Any]].invoke(failure) + } + + def timeout(d: FiniteDuration): Boolean = + status.compareAndSet(null, OnError(new SubscriptionTimeoutException(s"Substream Source has not been materialized in $d"))) + + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with OutHandler { + setHandler(out, this) + + @tailrec private def setCB(cb: AsyncCallback[ActorSubscriberMessage]): Unit = { + status.get match { + case null ⇒ if (!status.compareAndSet(null, cb)) setCB(cb) + case OnComplete ⇒ completeStage() + case OnError(ex) ⇒ failStage(ex) + case _: AsyncCallback[_] ⇒ failStage(new IllegalStateException("Substream Source cannot be materialized more than once")) + } + } + + override def preStart(): Unit = { + val ourOwnCallback = getAsyncCallback[ActorSubscriberMessage] { + case OnComplete ⇒ completeStage() + case OnError(ex) ⇒ failStage(ex) + case OnNext(elem) ⇒ push(out, elem.asInstanceOf[T]) + } + setCB(ourOwnCallback) + } + + override def onPull(): Unit = externalCallback.invoke(RequestOne) + override def onDownstreamFinish(): Unit = externalCallback.invoke(Cancel) + } + + override def toString: String = name +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index c346720adb..48c7de3a34 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -93,9 +93,6 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, override def onDownstreamFinish(): Unit = tryUnbind() }) - // because when we tryUnbind, we must wait for the Ubound signal before terminating - override def keepGoingAfterAllPortsClosed = true - private def connectionFor(connected: Connected, connection: ActorRef): StreamTcp.IncomingConnection = { connectionFlowsAwaitingInitialization.incrementAndGet() @@ -122,6 +119,7 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, private def tryUnbind(): Unit = { if (listener ne null) { self.unwatch(listener) + setKeepGoing(true) listener ! Unbind } } @@ -164,7 +162,7 @@ private[stream] object TcpConnectionStage { case class Inbound(connection: ActorRef, halfClose: Boolean) extends TcpRole /* - * This is a *non-deatched* design, i.e. this does not prefetch itself any of the inputs. It relies on downstream + * This is a *non-detached* design, i.e. this does not prefetch itself any of the inputs. It relies on downstream * stages to provide the necessary prefetch on `bytesOut` and the framework to do the proper prefetch in the buffer * backing `bytesIn`. If prefetch on `bytesOut` is required (i.e. user stages cannot be trusted) then it is better * to attach an extra, fused buffer to the end of this flow. Keeping this stage non-detached makes it much simpler and @@ -182,18 +180,21 @@ private[stream] object TcpConnectionStage { override def onPull(): Unit = () }) - override def preStart(): Unit = role match { - case Inbound(conn, _) ⇒ - setHandler(bytesOut, readHandler) - self = getStageActorRef(connected) - connection = conn - self.watch(connection) - connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) - pull(bytesIn) - case ob @ Outbound(manager, cmd, _, _) ⇒ - self = getStageActorRef(connecting(ob)) - self.watch(manager) - manager ! cmd + override def preStart(): Unit = { + setKeepGoing(true) + role match { + case Inbound(conn, _) ⇒ + setHandler(bytesOut, readHandler) + self = getStageActorRef(connected) + connection = conn + self.watch(connection) + connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) + pull(bytesIn) + case ob @ Outbound(manager, cmd, _, _) ⇒ + self = getStageActorRef(connecting(ob)) + self.watch(manager) + manager ! cmd + } } private def connecting(ob: Outbound)(evt: (ActorRef, Any)): Unit = { @@ -270,8 +271,6 @@ private[stream] object TcpConnectionStage { } }) - override def keepGoingAfterAllPortsClosed: Boolean = true - override def postStop(): Unit = role match { case Outbound(_, _, localAddressPromise, _) ⇒ // Fail if has not been completed with an address eariler diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala index 8042526b55..fd12c42a8f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala @@ -31,7 +31,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu * value of the current flow (ignoring the other BidiFlow’s value), use * [[BidiFlow#atopMat atopMat]] if a different strategy is needed. */ - def atop[OO1, II2, Mat2](bidi: BidiFlow[O1, OO1, II2, I2, Mat2]): BidiFlow[I1, OO1, II2, O2, Mat] = atopMat(bidi)(Keep.left) + def atop[OO1, II2, Mat2](bidi: Graph[BidiShape[O1, OO1, II2, I2], Mat2]): BidiFlow[I1, OO1, II2, O2, Mat] = atopMat(bidi)(Keep.left) /** * Add the given BidiFlow as the next step in a bidirectional transformation @@ -51,7 +51,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu * The `combine` function is used to compose the materialized values of this flow and that * flow into the materialized value of the resulting BidiFlow. */ - def atopMat[OO1, II2, Mat2, M](bidi: BidiFlow[O1, OO1, II2, I2, Mat2])(combine: (Mat, Mat2) ⇒ M): BidiFlow[I1, OO1, II2, O2, M] = { + def atopMat[OO1, II2, Mat2, M](bidi: Graph[BidiShape[O1, OO1, II2, I2], Mat2])(combine: (Mat, Mat2) ⇒ M): BidiFlow[I1, OO1, II2, O2, M] = { val copy = bidi.module.carbonCopy val ins = copy.shape.inlets val outs = copy.shape.outlets @@ -81,7 +81,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu * value of the current flow (ignoring the other Flow’s value), use * [[BidiFlow#joinMat joinMat]] if a different strategy is needed. */ - def join[Mat2](flow: Flow[O1, I2, Mat2]): Flow[I1, O2, Mat] = joinMat(flow)(Keep.left) + def join[Mat2](flow: Graph[FlowShape[O1, I2], Mat2]): Flow[I1, O2, Mat] = joinMat(flow)(Keep.left) /** * Add the given Flow as the final step in a bidirectional transformation @@ -101,7 +101,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu * The `combine` function is used to compose the materialized values of this flow and that * flow into the materialized value of the resulting [[Flow]]. */ - def joinMat[Mat2, M](flow: Flow[O1, I2, Mat2])(combine: (Mat, Mat2) ⇒ M): Flow[I1, O2, M] = { + def joinMat[Mat2, M](flow: Graph[FlowShape[O1, I2], Mat2])(combine: (Mat, Mat2) ⇒ M): Flow[I1, O2, M] = { val copy = flow.module.carbonCopy val in = copy.shape.inlets.head val out = copy.shape.outlets.head diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index d29f35d658..ed1c966b65 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -6,7 +6,7 @@ package akka.stream.scaladsl import akka.event.LoggingAdapter import akka.stream.Attributes._ import akka.stream._ -import akka.stream.impl.Stages.{ DirectProcessor, StageModule, SymbolicGraphStage } +import akka.stream.impl.Stages.{ DirectProcessor, StageModule } import akka.stream.impl.StreamLayout.{ EmptyModule, Module } import akka.stream.impl._ import akka.stream.impl.fusing._ @@ -1076,12 +1076,12 @@ trait FlowOps[+Out, +Mat] { def splitWhen(p: Out ⇒ Boolean): SubFlow[Out, Mat, Repr, Closed] = { val merge = new SubFlowImpl.MergeBack[Out, Repr] { override def apply[T](flow: Flow[Out, T, Unit], breadth: Int): Repr[T] = - deprecatedAndThen[Source[Out, Unit]](Split.when(p.asInstanceOf[Any ⇒ Boolean])) + via(Split.when(p)) .map(_.via(flow)) .via(new FlattenMerge(breadth)) } val finish: (Sink[Out, Unit]) ⇒ Closed = s ⇒ - deprecatedAndThen[Source[Out, Unit]](Split.when(p.asInstanceOf[Any ⇒ Boolean])) + via(Split.when(p)) .to(Sink.foreach(_.runWith(s)(GraphInterpreter.currentInterpreter.materializer))) new SubFlowImpl(Flow[Out], merge, finish) } @@ -1133,12 +1133,12 @@ trait FlowOps[+Out, +Mat] { def splitAfter(p: Out ⇒ Boolean): SubFlow[Out, Mat, Repr, Closed] = { val merge = new SubFlowImpl.MergeBack[Out, Repr] { override def apply[T](flow: Flow[Out, T, Unit], breadth: Int): Repr[T] = - deprecatedAndThen[Source[Out, Unit]](Split.after(p.asInstanceOf[Any ⇒ Boolean])) + via(Split.after(p)) .map(_.via(flow)) .via(new FlattenMerge(breadth)) } val finish: (Sink[Out, Unit]) ⇒ Closed = s ⇒ - deprecatedAndThen[Source[Out, Unit]](Split.after(p.asInstanceOf[Any ⇒ Boolean])) + via(Split.after(p)) .to(Sink.foreach(_.runWith(s)(GraphInterpreter.currentInterpreter.materializer))) new SubFlowImpl(Flow[Out], merge, finish) } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index d1e541ed87..a4f328ca79 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -3,21 +3,24 @@ */ package akka.stream.stage -import java.util.concurrent.atomic.{ AtomicReference } - +import java.util +import java.util.concurrent.atomic.{ AtomicReferenceFieldUpdater, AtomicReference } import akka.actor._ import akka.dispatch.sysmsg.{ DeathWatchNotification, SystemMessage, Unwatch, Watch } import akka.event.LoggingAdapter import akka.japi.function.{ Effect, Procedure } import akka.stream._ import akka.stream.impl.StreamLayout.Module -import akka.stream.impl.fusing.{ GraphInterpreter, GraphStageModule } +import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly +import akka.stream.impl.fusing.{ GraphInterpreter, GraphModule, GraphStageModule, SubSource, SubSink } import akka.stream.impl.{ ReactiveStreamsCompliance, SeqActorName } - import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration +import akka.stream.impl.SubscriptionTimeoutException +import akka.stream.actor.ActorSubscriberMessage +import akka.stream.actor.ActorPublisherMessage abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, M] { @@ -507,6 +510,16 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } } + /** + * Controls whether this stage shall shut down when all its ports are closed, which + * is the default. In order to have it keep going past that point this method needs + * to be called with a `true` argument before all ports are closed, and afterwards + * it will not be closed until this method is called with a `false` argument or the + * stage is terminated via `completeStage()` or `failStage()`. + */ + final protected def setKeepGoing(enabled: Boolean): Unit = + interpreter.setKeepGoing(this, enabled) + /** * Signals that there will be no more elements emitted on the given port. */ @@ -536,7 +549,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } i += 1 } - if (keepGoingAfterAllPortsClosed) interpreter.closeKeptAliveStageIfNeeded(stageId) + setKeepGoing(false) } /** @@ -560,7 +573,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: interpreter.fail(portToConn(i), ex, isInternal) i += 1 } - if (keepGoingAfterAllPortsClosed) interpreter.closeKeptAliveStageIfNeeded(stageId) + setKeepGoing(false) } /** @@ -918,7 +931,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * * This object can be cached and reused within the same [[GraphStageLogic]]. */ - final protected def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { + final def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { new AsyncCallback[T] { override def invoke(event: T): Unit = interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit]) @@ -999,10 +1012,156 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: def postStop(): Unit = () /** - * If this method returns true when all ports had been closed then the stage is not stopped until - * completeStage() or failStage() are explicitly called + * INTERNAL API + * + * This allows the dynamic creation of an Inlet for a GraphStage which is + * connected to a Sink that is available for materialization (e.g. using + * the `subFusingMaterializer`). Care needs to be taken to cancel this Inlet + * when the stage shuts down lest the corresponding Sink be left hanging. */ - def keepGoingAfterAllPortsClosed: Boolean = false + class SubSinkInlet[T](name: String) { + import ActorSubscriberMessage._ + + private var handler: InHandler = _ + private var elem: T = null.asInstanceOf[T] + private var closed = false + private var pulled = false + + private val _sink = new SubSink[T](name, getAsyncCallback[ActorSubscriberMessage] { msg ⇒ + if (!closed) msg match { + case OnNext(e) ⇒ + elem = e.asInstanceOf[T] + pulled = false + handler.onPush() + case OnComplete ⇒ + closed = true + handler.onUpstreamFinish() + case OnError(ex) ⇒ + closed = true + handler.onUpstreamFailure(ex) + } + }.invoke _) + + def sink: Graph[SinkShape[T], Unit] = _sink + + def setHandler(handler: InHandler): Unit = this.handler = handler + + def isAvailable: Boolean = elem != null + + def isClosed: Boolean = closed + + def hasBeenPulled: Boolean = pulled && !isClosed + + def grab(): T = { + require(elem != null, "cannot grab element from port when data have not yet arrived") + val ret = elem + elem = null.asInstanceOf[T] + ret + } + + def pull(): Unit = { + require(!pulled, "cannot pull port twice") + require(!closed, "cannot pull closed port") + pulled = true + _sink.pullSubstream() + } + + def cancel(): Unit = { + closed = true + _sink.cancelSubstream() + } + } + + /** + * INTERNAL API + * + * This allows the dynamic creation of an Outlet for a GraphStage which is + * connected to a Source that is available for materialization (e.g. using + * the `subFusingMaterializer`). Care needs to be taken to complete this + * Outlet when the stage shuts down lest the corresponding Sink be left + * hanging. It is good practice to use the `timeout` method to cancel this + * Outlet in case the corresponding Source is not materialized within a + * given time limit, see e.g. ActorMaterializerSettings. + */ + class SubSourceOutlet[T](name: String) { + + private var handler: OutHandler = null + private var available = false + private var closed = false + + private val callback = getAsyncCallback[ActorPublisherMessage] { + case SubSink.RequestOne ⇒ + if (!closed) { + available = true + handler.onPull() + } + case ActorPublisherMessage.Cancel ⇒ + if (!closed) { + available = false + closed = true + handler.onDownstreamFinish() + } + } + + private val _source = new SubSource[T](name, callback) + + /** + * Set the source into timed-out mode if it has not yet been materialized. + */ + def timeout(d: FiniteDuration): Unit = + if (_source.timeout(d)) closed = true + + /** + * Get the Source for this dynamic output port. + */ + def source: Graph[SourceShape[T], Unit] = _source + + /** + * Set OutHandler for this dynamic output port; this needs to be done before + * the first substream callback can arrive. + */ + def setHandler(handler: OutHandler): Unit = this.handler = handler + + /** + * Returns `true` if this output port can be pushed. + */ + def isAvailable: Boolean = available + + /** + * Returns `true` if this output port is closed, but caution + * THIS WORKS DIFFERENTLY THAN THE NORMAL isClosed(out). + * Due to possibly asynchronous shutdown it may not return + * `true` immediately after `complete()` or `fail()` have returned. + */ + def isClosed: Boolean = closed + + /** + * Push to this output port. + */ + def push(elem: T): Unit = { + available = false + _source.pushSubstream(elem) + } + + /** + * Complete this output port. + */ + def complete(): Unit = { + available = false + closed = true + _source.completeSubstream() + } + + /** + * Fail this output port. + */ + def fail(ex: Throwable): Unit = { + available = false + closed = true + _source.failSubstream(ex) + } + } + } /** @@ -1034,10 +1193,11 @@ abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shap private def onInternalTimer(scheduled: Scheduled): Unit = { val Id = scheduled.timerId - keyToTimers.get(scheduled.timerKey) match { + val timerKey = scheduled.timerKey + keyToTimers.get(timerKey) match { case Some(Timer(Id, _)) ⇒ - if (!scheduled.repeating) keyToTimers -= scheduled.timerKey - onTimer(scheduled.timerKey) + if (!scheduled.repeating) keyToTimers -= timerKey + onTimer(timerKey) case _ ⇒ } }