diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index a7351479bf..90ff251ab4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -31,22 +31,22 @@ import akka.util.ByteString * INTERNAL API * * - * HTTP pipeline setup: + * HTTP pipeline setup (without the underlying SSL/TLS (un)wrapping and the websocket switch): * - * +-------------+ +-------------+ +-----------+ - * HttpRequest | request- | Request- | | Request- | request- | ByteString - * | <-----------------+ Preparation <-----------------+ <-------------------+ Parsing <--------------- - * | | | Output | | Output | | - * | +-------------+ | | +-----------+ - * | | | - * | Application- | controller- | - * | Flow | Stage | - * | | | - * | | | +-----------+ - * | HttpResponse | | Response- | renderer- | ByteString - * v --------------------------------------------------> +-------------------> Pipeline +--------------> - * | | Rendering- | | - * +-------------+ Context +-----------+ + * +----------+ +-------------+ +-------------+ +-----------+ + * HttpRequest | | Http- | request- | Request- | | Request- | request- | ByteString + * | <------------+ <----------+ Preparation <----------+ <-------------+ Parsing <----------- + * | | | Request | | Output | | Output | | + * | | | +-------------+ | | +-----------+ + * | | | | | + * | Application- | One2One- | | controller- | + * | Flow | Bidi | | Stage | + * | | | | | + * | | | | | +-----------+ + * | HttpResponse | | HttpResponse | | Response- | renderer- | ByteString + * v -------------> +-----------------------------------> +-------------> Pipeline +----------> + * | | | | Rendering- | | + * +----------+ +-------------+ Context +-----------+ */ private[http] object HttpServerBluePrint { def apply(settings: ServerSettings, remoteAddress: Option[InetSocketAddress], log: LoggingAdapter)(implicit mat: Materializer): Http.ServerLayer = { @@ -81,7 +81,7 @@ private[http] object HttpServerBluePrint { rootParser.createShallowCopy().stage).named("rootParser") .map(establishAbsoluteUri) - val requestPreparation = + val requestPreparationFlow = Flow[RequestOutput] .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .via(headAndTailFlow) @@ -105,58 +105,58 @@ private[http] object HttpServerBluePrint { // `buffer` will ensure demand and therefore make sure that completion is reported eagerly. .buffer(1, OverflowStrategy.backpressure) - val rendererPipeline = + val responseRenderingFlow = Flow[ResponseRenderingContext] .via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer")) .flatMapConcat(ConstantFun.scalaIdentityFunction) .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger")) - BidiFlow.fromGraph(FlowGraph.create(requestParsingFlow, rendererPipeline)(Keep.none) { implicit b ⇒ - (requestParsing, renderer) ⇒ - import FlowGraph.Implicits._ + BidiFlow.fromGraph(FlowGraph.create() { implicit b ⇒ + import FlowGraph.Implicits._ - // HTTP - val requestPrep = b.add(requestPreparation) - val controllerStage = b.add(new ControllerStage(settings, log)) - val csRequestParsingIn = controllerStage.in1 - val csRequestPrepOut = controllerStage.out1 - val csHttpResponseIn = controllerStage.in2 - val csResponseCtxOut = controllerStage.out2 - requestParsing.outlet ~> csRequestParsingIn - csResponseCtxOut ~> renderer.inlet - csRequestPrepOut ~> requestPrep + // HTTP + val requestParsing = b.add(requestParsingFlow) + val requestPreparation = b.add(requestPreparationFlow) + val responseRendering = b.add(responseRenderingFlow) + val controllerStage = b.add(new ControllerStage(settings, log)) + val csRequestParsingIn = controllerStage.in1 + val csRequestPrepOut = controllerStage.out1 + val csHttpResponseIn = controllerStage.in2 + val csResponseCtxOut = controllerStage.out2 + requestParsing.outlet ~> csRequestParsingIn + csResponseCtxOut ~> responseRendering.inlet + csRequestPrepOut ~> requestPreparation - // One2OneBidi - val one2one = b.add(new One2OneBidi[HttpRequest, HttpResponse](settings.pipeliningLimit)) - requestPrep.outlet ~> one2one.in1 - one2one.out2 ~> csHttpResponseIn + // One2OneBidi + val one2one = b.add(new One2OneBidi[HttpRequest, HttpResponse](settings.pipeliningLimit)) + requestPreparation.outlet ~> one2one.in1 + one2one.out2 ~> csHttpResponseIn - // Websocket - val http = FlowShape(requestParsing.inlet, renderer.outlet) - val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2)) - val switchToWebsocket = b.add(Flow[ResponseRenderingOutput] - .collect { case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken }) - val websocket = b.add(ws.websocketFlow) - val protocolRouter = b.add(WebsocketSwitchRouter) - val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory, log)) - val wsSwitchTokenMerge = b.add(WsSwitchTokenMerge) - switchTokenBroadcast ~> switchToWebsocket - protocolRouter.out0 ~> http ~> switchTokenBroadcast ~> protocolMerge.in0 - protocolRouter.out1 ~> websocket ~> protocolMerge.in1 - switchToWebsocket ~> wsSwitchTokenMerge.in1 - wsSwitchTokenMerge.out ~> protocolRouter.in + // Websocket + val http = FlowShape(requestParsing.inlet, responseRendering.outlet) + val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2)) + val switchToWebsocket = b.add(Flow[ResponseRenderingOutput] + .collect { case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken }) + val websocket = b.add(ws.websocketFlow) + val protocolRouter = b.add(WebsocketSwitchRouter) + val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory, log)) + val wsSwitchTokenMerge = b.add(WsSwitchTokenMerge) + switchTokenBroadcast ~> switchToWebsocket ~> wsSwitchTokenMerge.in1 + protocolRouter.out0 ~> http ~> switchTokenBroadcast ~> protocolMerge.in0 + protocolRouter.out1 ~> websocket ~> protocolMerge.in1 + wsSwitchTokenMerge.out ~> protocolRouter.in - // SSL/TLS - val unwrapTls = b.add(Flow[SslTlsInbound] collect { case x: SessionBytes ⇒ x.bytes }) - val wrapTls = b.add(Flow[ByteString].map[SslTlsOutbound](SendBytes)) - unwrapTls ~> wsSwitchTokenMerge.in0 - protocolMerge.out ~> wrapTls + // SSL/TLS + val unwrapTls = b.add(Flow[SslTlsInbound] collect { case x: SessionBytes ⇒ x.bytes }) + val wrapTls = b.add(Flow[ByteString].map[SslTlsOutbound](SendBytes)) + unwrapTls ~> wsSwitchTokenMerge.in0 + protocolMerge.out ~> wrapTls - BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest]( - one2one.in2, - wrapTls.outlet, - unwrapTls.inlet, - one2one.out1) + BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest]( + one2one.in2, + wrapTls.outlet, + unwrapTls.inlet, + one2one.out1) }) } @@ -169,6 +169,7 @@ private[http] object HttpServerBluePrint { val shape = new BidiShape(requestParsingIn, requestPrepOut, httpResponseIn, responseCtxOut) def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { + val pullHttpResponseIn = () ⇒ pull(httpResponseIn) var openRequests = immutable.Queue[RequestStart]() var oneHundredContinueResponsePending = false var pullSuppressed = false @@ -218,7 +219,8 @@ private[http] object HttpServerBluePrint { requestStart.expect100Continue && oneHundredContinueResponsePending || isClosed(requestParsingIn) && openRequests.isEmpty || isEarlyResponse - push(responseCtxOut, ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close)) + emit(responseCtxOut, ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close), + pullHttpResponseIn) if (close) complete(responseCtxOut) } override def onUpstreamFinish() = @@ -244,9 +246,17 @@ private[http] object HttpServerBluePrint { } }) - setHandler(responseCtxOut, new OutHandler { - def onPull(): Unit = if (!hasBeenPulled(httpResponseIn)) pull(httpResponseIn) - override def onDownstreamFinish() = cancel(httpResponseIn) + class ResponseCtxOutHandler extends OutHandler { + override def onPull() = {} + override def onDownstreamFinish() = + cancel(httpResponseIn) // we cannot fully completeState() here as the websocket pipeline would not complete properly + } + setHandler(responseCtxOut, new ResponseCtxOutHandler { + override def onPull() = { + pull(httpResponseIn) + // after the initial pull here we only ever pull after having emitted in `onPush` of `httpResponseIn` + setHandler(responseCtxOut, new ResponseCtxOutHandler) + } }) def finishWithIllegalRequestError(status: StatusCode, info: ErrorInfo): Unit = { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala index 17557683c5..44b397b365 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala @@ -46,7 +46,7 @@ class One2OneBidiFlowSpec extends AkkaSpec with ConversionCheckedTripleEquals { outOut.expectError(new One2OneBidiFlow.UnexpectedOutputException(3)) } - "drop surplus output elements" in new Test() { + "fully propagate cancellation" in new Test() { inIn.sendNext(1) inOut.requestNext() should ===(1) @@ -55,6 +55,9 @@ class One2OneBidiFlowSpec extends AkkaSpec with ConversionCheckedTripleEquals { outOut.cancel() outIn.expectCancellation() + + inOut.cancel() + inIn.expectCancellation() } "backpressure the input side if the maximum number of pending output elements has been reached" in { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index dc50ff427a..34221bdd90 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -799,7 +799,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends * '''Completes when''' upstream completes * * '''Cancels when''' downstream cancels and substreams cancel - * + * * See also [[Flow.splitAfter]]. */ def splitWhen(p: function.Predicate[Out]): javadsl.Flow[In, Source[Out, Unit], Mat] = diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala index 7859ab9aff..61cbb3df0c 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala @@ -22,8 +22,6 @@ object One2OneBidiFlow { * for every input element. * 3. Backpressures the input side if the maximum number of pending output elements has been reached, * which is given via the ``maxPending`` parameter. You can use -1 to disable this feature. - * 4. Drops surplus output elements, i.e. ones that the inner flow tries to produce after the input stream - * has signalled completion. Note that no error is triggered in this case! */ def apply[I, O](maxPending: Int): BidiFlow[I, I, O, O, Unit] = BidiFlow.fromGraph(new One2OneBidi[I, O](maxPending)) @@ -39,7 +37,7 @@ object One2OneBidiFlow { override def createLogic(effectiveAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pending = 0 - private var pullsSuppressed = 0 + private var pullSuppressed = false setHandler(inIn, new InHandler { override def onPush(): Unit = { @@ -52,7 +50,7 @@ object One2OneBidiFlow { setHandler(inOut, new OutHandler { override def onPull(): Unit = if (pending < maxPending || maxPending == -1) pull(inIn) - else pullsSuppressed += 1 + else pullSuppressed = true override def onDownstreamFinish(): Unit = cancel(inIn) }) @@ -62,8 +60,8 @@ object One2OneBidiFlow { if (pending > 0) { pending -= 1 push(outOut, element) - if (pullsSuppressed > 0) { - pullsSuppressed -= 1 + if (pullSuppressed) { + pullSuppressed = false pull(inIn) } } else throw new UnexpectedOutputException(element) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index e5cdef8e17..2a0ae4c0c0 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -373,7 +373,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Signals that there will be no more elements emitted on the given port. */ - final protected def complete[T](out: Outlet[T]): Unit = interpreter.complete(conn(out)) + final protected def complete[T](out: Outlet[T]): Unit = + getHandler(out) match { + case e: Emitting[_] ⇒ e.addFollowUp(new EmittingCompletion(e.out, e.previous)) + case _ ⇒ interpreter.complete(conn(out)) + } /** * Signals failure through the given port.