=htc #19352 HttpServerBluePrint: get rid of costly flatMapConcat on the rendering-side

This commit is contained in:
Johannes Rudolph 2016-01-05 12:17:18 +01:00
parent ced5aa7ddc
commit 1f00b0c1c0
4 changed files with 195 additions and 153 deletions

View file

@ -6,6 +6,7 @@ package akka.http.impl.engine.rendering
import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader }
import akka.http.scaladsl.model.ws.Message
import akka.stream.{ Outlet, Inlet, Attributes, FlowShape }
import scala.annotation.tailrec
import akka.event.LoggingAdapter
@ -52,19 +53,49 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
// split out so we can stabilize by overriding in tests
protected def currentTimeMillis(): Long = System.currentTimeMillis()
def newRenderer: HttpResponseRenderer = new HttpResponseRenderer
def renderer: Flow[ResponseRenderingContext, ResponseRenderingOutput, Unit] = Flow.fromGraph(HttpResponseRenderer)
final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ResponseRenderingOutput, Any]] {
object HttpResponseRenderer extends GraphStage[FlowShape[ResponseRenderingContext, ResponseRenderingOutput]] {
val in = Inlet[ResponseRenderingContext]("in")
val out = Outlet[ResponseRenderingOutput]("out")
val shape: FlowShape[ResponseRenderingContext, ResponseRenderingOutput] = FlowShape(in, out)
def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) {
private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response
private[this] def close: Boolean = closeMode != DontClose
private[this] def closeIf(cond: Boolean): Unit =
if (cond) closeMode = CloseConnection
// need this for testing
private[http] def isComplete = close
setHandler(in, new InHandler {
def onPush(): Unit =
render(grab(in)) match {
case Strict(outElement)
push(out, outElement)
if (close) completeStage()
case Streamed(outStream) transfer(outStream)
}
override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ResponseRenderingOutput, Any]]): SyncDirective = {
override def onUpstreamFinish(): Unit = closeMode = CloseConnection
})
val waitForDemandHandler = new OutHandler {
def onPull(): Unit = if (close) completeStage() else pull(in)
}
setHandler(out, waitForDemandHandler)
def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = {
val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink")
sinkIn.setHandler(new InHandler {
def onPush(): Unit = push(out, sinkIn.grab())
override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler)
})
setHandler(out, new OutHandler {
def onPull(): Unit = sinkIn.pull()
})
sinkIn.pull()
Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer)
}
def render(ctx: ResponseRenderingContext): StrictOrStreamed = {
val r = new ByteStringRendering(responseHeaderSizeHint)
import ctx.response._
@ -172,7 +203,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
def byteStrings(entityBytes: Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] =
renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_))
def completeResponseRendering(entity: ResponseEntity): Source[ResponseRenderingOutput, Any] =
def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed =
entity match {
case HttpEntity.Strict(_, data)
renderHeaders(headers.toList)
@ -181,7 +212,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
if (!noEntity) r ~~ data
Source.single {
Strict {
closeMode match {
case SwitchToWebsocket(handler) ResponseRenderingOutput.SwitchToWebsocket(r.get, handler)
case _ ResponseRenderingOutput.HttpData(r.get)
@ -192,12 +223,12 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
renderHeaders(headers.toList)
renderEntityContentType(r, entity)
renderContentLengthHeader(contentLength) ~~ CrLf
byteStrings(data.via(CheckContentLengthTransformer.flow(contentLength)))
Streamed(byteStrings(data.via(CheckContentLengthTransformer.flow(contentLength))))
case HttpEntity.CloseDelimited(_, data)
renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD)
renderEntityContentType(r, entity) ~~ CrLf
byteStrings(data)
Streamed(byteStrings(data))
case HttpEntity.Chunked(contentType, chunks)
if (ctx.requestProtocol == `HTTP/1.0`)
@ -205,19 +236,20 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
else {
renderHeaders(headers.toList)
renderEntityContentType(r, entity) ~~ CrLf
byteStrings(chunks.via(ChunkTransformer.flow))
Streamed(byteStrings(chunks.via(ChunkTransformer.flow)))
}
}
renderStatusLine()
val result = completeResponseRendering(entity)
if (close)
opCtx.pushAndFinish(result)
else
opCtx.push(result)
completeResponseRendering(entity)
}
}
sealed trait StrictOrStreamed
case class Strict(bytes: ResponseRenderingOutput) extends StrictOrStreamed
case class Streamed(source: Source[ResponseRenderingOutput, Any]) extends StrictOrStreamed
}
sealed trait CloseMode
case object DontClose extends CloseMode
case object CloseConnection extends CloseMode

View file

@ -166,8 +166,7 @@ private[http] object HttpServerBluePrint {
}
Flow[ResponseRenderingContext]
.via(Flow[ResponseRenderingContext].transform(() responseRendererFactory.newRenderer).named("renderer"))
.flatMapConcat(ConstantFun.scalaIdentityFunction)
.via(responseRendererFactory.renderer.named("renderer"))
.via(Flow[ResponseRenderingOutput].transform(() errorHandling(errorHandler)).named("errorLogger"))
}

View file

@ -19,6 +19,8 @@ import akka.stream.scaladsl._
import akka.stream.ActorMaterializer
import HttpEntity._
import scala.util.control.NonFatal
class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val testConf: Config = ConfigFactory.parseString("""
akka.event-handlers = ["akka.testkit.TestEventListener"]
@ -583,17 +585,26 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] =
equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx
val renderer = newRenderer
val rendererOutputSource = Await.result(Source.single(ctx)
.transform(() renderer).named("renderer")
.runWith(Sink.head), 1.second)
val future =
rendererOutputSource.grouped(1000).map(
_.map {
val (wasCompletedFuture, resultFuture) =
(Source.single(ctx) ++ Source.maybe[ResponseRenderingContext]) // never send upstream completion
.via(renderer.named("renderer"))
.map {
case ResponseRenderingOutput.HttpData(bytes) bytes
case _: ResponseRenderingOutput.SwitchToWebsocket throw new IllegalStateException("Didn't expect websocket response")
}).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
Await.result(future, 250.millis) -> renderer.isComplete
}
.groupedWithin(1000, 100.millis)
.viaMat(StreamUtils.identityFinishReporter[Seq[ByteString]])(Keep.right)
.toMat(Sink.head)(Keep.both).run()
// we try to find out if the renderer has already flagged completion even without the upstream being completed
val wasCompleted =
try {
Await.ready(wasCompletedFuture, 100.millis)
true
} catch {
case NonFatal(_) false
}
Await.result(resultFuture, 250.millis).reduceLeft(_ ++ _).utf8String -> wasCompleted
}
override def currentTimeMillis() = DateTime(2011, 8, 25, 9, 10, 29).clicks // provide a stable date for testing

View file

@ -315,7 +315,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
/**
* INTERNAL API
*/
private[stream] def interpreter: GraphInterpreter =
private[akka] def interpreter: GraphInterpreter =
if (_interpreter == null)
throw new IllegalStateException("not yet initialized: only setHandler is allowed in GraphStageLogic constructor")
else _interpreter