diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index d6fd66c3d1..448f0944b2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -6,6 +6,7 @@ package akka.http.impl.engine.rendering import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader } import akka.http.scaladsl.model.ws.Message +import akka.stream.{ Outlet, Inlet, Attributes, FlowShape } import scala.annotation.tailrec import akka.event.LoggingAdapter @@ -52,170 +53,201 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser // split out so we can stabilize by overriding in tests protected def currentTimeMillis(): Long = System.currentTimeMillis() - def newRenderer: HttpResponseRenderer = new HttpResponseRenderer + def renderer: Flow[ResponseRenderingContext, ResponseRenderingOutput, Unit] = Flow.fromGraph(HttpResponseRenderer) - final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ResponseRenderingOutput, Any]] { + object HttpResponseRenderer extends GraphStage[FlowShape[ResponseRenderingContext, ResponseRenderingOutput]] { + val in = Inlet[ResponseRenderingContext]("in") + val out = Outlet[ResponseRenderingOutput]("out") + val shape: FlowShape[ResponseRenderingContext, ResponseRenderingOutput] = FlowShape(in, out) - private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response - private[this] def close: Boolean = closeMode != DontClose - private[this] def closeIf(cond: Boolean): Unit = - if (cond) closeMode = CloseConnection + def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response + private[this] def close: Boolean = closeMode != DontClose + private[this] def closeIf(cond: Boolean): Unit = + if (cond) closeMode = CloseConnection - // need this for testing - private[http] def isComplete = close + setHandler(in, new InHandler { + def onPush(): Unit = + render(grab(in)) match { + case Strict(outElement) ⇒ + push(out, outElement) + if (close) completeStage() + case Streamed(outStream) ⇒ transfer(outStream) + } - override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ResponseRenderingOutput, Any]]): SyncDirective = { - val r = new ByteStringRendering(responseHeaderSizeHint) - - import ctx.response._ - val noEntity = entity.isKnownEmpty || ctx.requestMethod == HttpMethods.HEAD - - def renderStatusLine(): Unit = - protocol match { - case `HTTP/1.1` ⇒ if (status eq StatusCodes.OK) r ~~ DefaultStatusLineBytes else r ~~ StatusLineStartBytes ~~ status ~~ CrLf - case `HTTP/1.0` ⇒ r ~~ protocol ~~ ' ' ~~ status ~~ CrLf + override def onUpstreamFinish(): Unit = closeMode = CloseConnection + }) + val waitForDemandHandler = new OutHandler { + def onPull(): Unit = if (close) completeStage() else pull(in) + } + setHandler(out, waitForDemandHandler) + def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = { + val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink") + sinkIn.setHandler(new InHandler { + def onPush(): Unit = push(out, sinkIn.grab()) + override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler) + }) + setHandler(out, new OutHandler { + def onPull(): Unit = sinkIn.pull() + }) + sinkIn.pull() + Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer) } - def render(h: HttpHeader) = r ~~ h ~~ CrLf + def render(ctx: ResponseRenderingContext): StrictOrStreamed = { + val r = new ByteStringRendering(responseHeaderSizeHint) - def mustRenderTransferEncodingChunkedHeader = - entity.isChunked && (!entity.isKnownEmpty || ctx.requestMethod == HttpMethods.HEAD) && (ctx.requestProtocol == `HTTP/1.1`) + import ctx.response._ + val noEntity = entity.isKnownEmpty || ctx.requestMethod == HttpMethods.HEAD - @tailrec def renderHeaders(remaining: List[HttpHeader], alwaysClose: Boolean = false, - connHeader: Connection = null, serverSeen: Boolean = false, - transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = - remaining match { - case head :: tail ⇒ head match { - case x: `Content-Length` ⇒ - suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + def renderStatusLine(): Unit = + protocol match { + case `HTTP/1.1` ⇒ if (status eq StatusCodes.OK) r ~~ DefaultStatusLineBytes else r ~~ StatusLineStartBytes ~~ status ~~ CrLf + case `HTTP/1.0` ⇒ r ~~ protocol ~~ ' ' ~~ status ~~ CrLf + } - case x: `Content-Type` ⇒ - suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpResponse.entity.contentType` instead.") - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + def render(h: HttpHeader) = r ~~ h ~~ CrLf - case x: Date ⇒ - render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true) + def mustRenderTransferEncodingChunkedHeader = + entity.isChunked && (!entity.isKnownEmpty || ctx.requestMethod == HttpMethods.HEAD) && (ctx.requestProtocol == `HTTP/1.1`) - case x: `Transfer-Encoding` ⇒ - x.withChunkedPeeled match { - case None ⇒ - suppressionWarning(log, head) + @tailrec def renderHeaders(remaining: List[HttpHeader], alwaysClose: Boolean = false, + connHeader: Connection = null, serverSeen: Boolean = false, + transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = + remaining match { + case head :: tail ⇒ head match { + case x: `Content-Length` ⇒ + suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + + case x: `Content-Type` ⇒ + suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpResponse.entity.contentType` instead.") + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + + case x: Date ⇒ + render(x) + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true) + + case x: `Transfer-Encoding` ⇒ + x.withChunkedPeeled match { + case None ⇒ + suppressionWarning(log, head) + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + case Some(te) ⇒ + // if the user applied some custom transfer-encoding we need to keep the header + render(if (mustRenderTransferEncodingChunkedHeader) te.withChunked else te) + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen = true, dateSeen) + } + + case x: Connection ⇒ + val connectionHeader = if (connHeader eq null) x else Connection(x.tokens ++ connHeader.tokens) + renderHeaders(tail, alwaysClose, connectionHeader, serverSeen, transferEncodingSeen, dateSeen) + + case x: Server ⇒ + render(x) + renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen) + + case x: CustomHeader ⇒ + if (!x.suppressRendering) render(x) + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + + case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") || + (x is "date") || (x is "server") || (x is "connection") ⇒ + suppressionWarning(log, x, "illegal RawHeader") + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + + case x ⇒ + render(x) renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) - case Some(te) ⇒ - // if the user applied some custom transfer-encoding we need to keep the header - render(if (mustRenderTransferEncodingChunkedHeader) te.withChunked else te) - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen = true, dateSeen) } - case x: Connection ⇒ - val connectionHeader = if (connHeader eq null) x else Connection(x.tokens ++ connHeader.tokens) - renderHeaders(tail, alwaysClose, connectionHeader, serverSeen, transferEncodingSeen, dateSeen) + case Nil ⇒ + if (!serverSeen) renderDefaultServerHeader(r) + if (!dateSeen) r ~~ dateHeader - case x: Server ⇒ - render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen) + // Do we close the connection after this response? + closeIf { + // if we are prohibited to keep-alive by the spec + alwaysClose || + // if the client wants to close and we don't override + (ctx.closeRequested && ((connHeader eq null) || !connHeader.hasKeepAlive)) || + // if the application wants to close explicitly + (protocol match { + case `HTTP/1.1` ⇒ (connHeader ne null) && connHeader.hasClose + case `HTTP/1.0` ⇒ if (connHeader eq null) ctx.requestProtocol == `HTTP/1.1` else !connHeader.hasKeepAlive + }) + } - case x: CustomHeader ⇒ - if (!x.suppressRendering) render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) + // Do we render an explicit Connection header? + val renderConnectionHeader = + protocol == `HTTP/1.0` && !close || protocol == `HTTP/1.1` && close || // if we don't follow the default behavior + close != ctx.closeRequested || // if we override the client's closing request + protocol != ctx.requestProtocol // if we reply with a mismatching protocol (let's be very explicit in this case) - case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") || - (x is "date") || (x is "server") || (x is "connection") ⇒ - suppressionWarning(log, x, "illegal RawHeader") - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) - - case x ⇒ - render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) - } - - case Nil ⇒ - if (!serverSeen) renderDefaultServerHeader(r) - if (!dateSeen) r ~~ dateHeader - - // Do we close the connection after this response? - closeIf { - // if we are prohibited to keep-alive by the spec - alwaysClose || - // if the client wants to close and we don't override - (ctx.closeRequested && ((connHeader eq null) || !connHeader.hasKeepAlive)) || - // if the application wants to close explicitly - (protocol match { - case `HTTP/1.1` ⇒ (connHeader ne null) && connHeader.hasClose - case `HTTP/1.0` ⇒ if (connHeader eq null) ctx.requestProtocol == `HTTP/1.1` else !connHeader.hasKeepAlive - }) + if (renderConnectionHeader) + r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf + else if (connHeader != null && connHeader.hasUpgrade) { + r ~~ connHeader ~~ CrLf + headers + .collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } + .foreach { header ⇒ closeMode = SwitchToWebsocket(header.handler) } + } + if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) + r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } - // Do we render an explicit Connection header? - val renderConnectionHeader = - protocol == `HTTP/1.0` && !close || protocol == `HTTP/1.1` && close || // if we don't follow the default behavior - close != ctx.closeRequested || // if we override the client's closing request - protocol != ctx.requestProtocol // if we reply with a mismatching protocol (let's be very explicit in this case) + def renderContentLengthHeader(contentLength: Long) = + if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r - if (renderConnectionHeader) - r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf - else if (connHeader != null && connHeader.hasUpgrade) { - r ~~ connHeader ~~ CrLf - headers - .collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } - .foreach { header ⇒ closeMode = SwitchToWebsocket(header.handler) } + def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] = + renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_)) + + def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed = + entity match { + case HttpEntity.Strict(_, data) ⇒ + renderHeaders(headers.toList) + renderEntityContentType(r, entity) + renderContentLengthHeader(data.length) ~~ CrLf + + if (!noEntity) r ~~ data + + Strict { + closeMode match { + case SwitchToWebsocket(handler) ⇒ ResponseRenderingOutput.SwitchToWebsocket(r.get, handler) + case _ ⇒ ResponseRenderingOutput.HttpData(r.get) + } + } + + case HttpEntity.Default(_, contentLength, data) ⇒ + renderHeaders(headers.toList) + renderEntityContentType(r, entity) + renderContentLengthHeader(contentLength) ~~ CrLf + Streamed(byteStrings(data.via(CheckContentLengthTransformer.flow(contentLength)))) + + case HttpEntity.CloseDelimited(_, data) ⇒ + renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD) + renderEntityContentType(r, entity) ~~ CrLf + Streamed(byteStrings(data)) + + case HttpEntity.Chunked(contentType, chunks) ⇒ + if (ctx.requestProtocol == `HTTP/1.0`) + completeResponseRendering(HttpEntity.CloseDelimited(contentType, chunks.map(_.data))) + else { + renderHeaders(headers.toList) + renderEntityContentType(r, entity) ~~ CrLf + Streamed(byteStrings(chunks.via(ChunkTransformer.flow))) + } } - if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) - r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf + + renderStatusLine() + completeResponseRendering(entity) } + } - def renderContentLengthHeader(contentLength: Long) = - if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r - - def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] = - renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_)) - - def completeResponseRendering(entity: ResponseEntity): Source[ResponseRenderingOutput, Any] = - entity match { - case HttpEntity.Strict(_, data) ⇒ - renderHeaders(headers.toList) - renderEntityContentType(r, entity) - renderContentLengthHeader(data.length) ~~ CrLf - - if (!noEntity) r ~~ data - - Source.single { - closeMode match { - case SwitchToWebsocket(handler) ⇒ ResponseRenderingOutput.SwitchToWebsocket(r.get, handler) - case _ ⇒ ResponseRenderingOutput.HttpData(r.get) - } - } - - case HttpEntity.Default(_, contentLength, data) ⇒ - renderHeaders(headers.toList) - renderEntityContentType(r, entity) - renderContentLengthHeader(contentLength) ~~ CrLf - byteStrings(data.via(CheckContentLengthTransformer.flow(contentLength))) - - case HttpEntity.CloseDelimited(_, data) ⇒ - renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD) - renderEntityContentType(r, entity) ~~ CrLf - byteStrings(data) - - case HttpEntity.Chunked(contentType, chunks) ⇒ - if (ctx.requestProtocol == `HTTP/1.0`) - completeResponseRendering(HttpEntity.CloseDelimited(contentType, chunks.map(_.data))) - else { - renderHeaders(headers.toList) - renderEntityContentType(r, entity) ~~ CrLf - byteStrings(chunks.via(ChunkTransformer.flow)) - } - } - - renderStatusLine() - val result = completeResponseRendering(entity) - if (close) - opCtx.pushAndFinish(result) - else - opCtx.push(result) - } + sealed trait StrictOrStreamed + case class Strict(bytes: ResponseRenderingOutput) extends StrictOrStreamed + case class Streamed(source: Source[ResponseRenderingOutput, Any]) extends StrictOrStreamed } sealed trait CloseMode diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index c4c8fb1bae..5fd3e3cf64 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -166,8 +166,7 @@ private[http] object HttpServerBluePrint { } Flow[ResponseRenderingContext] - .via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer")) - .flatMapConcat(ConstantFun.scalaIdentityFunction) + .via(responseRendererFactory.renderer.named("renderer")) .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorHandling(errorHandler)).named("errorLogger")) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala index e2b01621cf..9ba309c58e 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala @@ -19,6 +19,8 @@ import akka.stream.scaladsl._ import akka.stream.ActorMaterializer import HttpEntity._ +import scala.util.control.NonFatal + class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" akka.event-handlers = ["akka.testkit.TestEventListener"] @@ -583,17 +585,26 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] = equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx ⇒ - val renderer = newRenderer - val rendererOutputSource = Await.result(Source.single(ctx) - .transform(() ⇒ renderer).named("renderer") - .runWith(Sink.head), 1.second) - val future = - rendererOutputSource.grouped(1000).map( - _.map { + val (wasCompletedFuture, resultFuture) = + (Source.single(ctx) ++ Source.maybe[ResponseRenderingContext]) // never send upstream completion + .via(renderer.named("renderer")) + .map { case ResponseRenderingOutput.HttpData(bytes) ⇒ bytes case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ throw new IllegalStateException("Didn't expect websocket response") - }).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) - Await.result(future, 250.millis) -> renderer.isComplete + } + .groupedWithin(1000, 100.millis) + .viaMat(StreamUtils.identityFinishReporter[Seq[ByteString]])(Keep.right) + .toMat(Sink.head)(Keep.both).run() + + // we try to find out if the renderer has already flagged completion even without the upstream being completed + val wasCompleted = + try { + Await.ready(wasCompletedFuture, 100.millis) + true + } catch { + case NonFatal(_) ⇒ false + } + Await.result(resultFuture, 250.millis).reduceLeft(_ ++ _).utf8String -> wasCompleted } override def currentTimeMillis() = DateTime(2011, 8, 25, 9, 10, 29).clicks // provide a stable date for testing diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index a4f328ca79..fb30e1a411 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -315,7 +315,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * INTERNAL API */ - private[stream] def interpreter: GraphInterpreter = + private[akka] def interpreter: GraphInterpreter = if (_interpreter == null) throw new IllegalStateException("not yet initialized: only setHandler is allowed in GraphStageLogic constructor") else _interpreter