diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index 7863046e8b..693bd18931 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -62,13 +62,13 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response - private[this] def close: Boolean = closeMode != DontClose - private[this] def closeIf(cond: Boolean): Unit = - if (cond) closeMode = CloseConnection + var closeMode: CloseMode = DontClose // signals what to do after the current response + def close: Boolean = closeMode != DontClose + def closeIf(cond: Boolean): Unit = if (cond) closeMode = CloseConnection + var transferring = false setHandler(in, new InHandler { - def onPush(): Unit = + override def onPush(): Unit = render(grab(in)) match { case Strict(outElement) ⇒ push(out, outElement) @@ -76,23 +76,36 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser case Streamed(outStream) ⇒ transfer(outStream) } - override def onUpstreamFinish(): Unit = closeMode = CloseConnection + override def onUpstreamFinish(): Unit = + if (transferring) closeMode = CloseConnection + else completeStage() }) val waitForDemandHandler = new OutHandler { - def onPull(): Unit = if (close) completeStage() else pull(in) + def onPull(): Unit = pull(in) } setHandler(out, waitForDemandHandler) def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = { + transferring = true val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink") sinkIn.setHandler(new InHandler { - def onPush(): Unit = push(out, sinkIn.grab()) - override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler) + override def onPush(): Unit = push(out, sinkIn.grab()) + override def onUpstreamFinish(): Unit = + if (close) completeStage() + else { + transferring = false + setHandler(out, waitForDemandHandler) + if (isAvailable(out)) pull(in) + } }) setHandler(out, new OutHandler { - def onPull(): Unit = sinkIn.pull() + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + } }) sinkIn.pull() - Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer) + outStream.runWith(sinkIn.sink)(interpreter.subFusingMaterializer) } def render(ctx: ResponseRenderingContext): StrictOrStreamed = {