=htc convert request rendering from a stage to a plain function

This commit is contained in:
Johannes Rudolph 2015-08-07 12:48:52 +02:00
parent 5c653f0641
commit e34b86ba30
4 changed files with 121 additions and 100 deletions

View file

@ -13,6 +13,7 @@ import akka.event.LoggingAdapter
import akka.stream._
import akka.stream.scaladsl._
import akka.http.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse }
import akka.http.impl.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
@ -23,9 +24,6 @@ import akka.http.impl.util._
* INTERNAL API
*/
private[http] object OutgoingConnectionBlueprint {
type ClientShape = BidiShape[HttpRequest, SslTlsOutbound, SslTlsInbound, HttpResponse]
/*
Stream Setup
============
@ -45,7 +43,7 @@ private[http] object OutgoingConnectionBlueprint {
*/
def apply(hostHeader: Host,
settings: ClientConnectionSettings,
log: LoggingAdapter): Graph[ClientShape, Unit] = {
log: LoggingAdapter): Http.ClientLayer = {
import settings._
// the initial header parser we initially use for every connection,
@ -59,7 +57,7 @@ private[http] object OutgoingConnectionBlueprint {
val requestRendering: Flow[HttpRequest, ByteString, Unit] = Flow[HttpRequest]
.map(RequestRenderingContext(_, hostHeader))
.via(Flow[RequestRenderingContext].transform(() requestRendererFactory.newRenderer).named("renderer"))
.via(Flow[RequestRenderingContext].map(requestRendererFactory.renderToSource).named("renderer"))
.flatten(FlattenStrategy.concat)
val methodBypass = Flow[HttpRequest].map(_.method)
@ -76,7 +74,7 @@ private[http] object OutgoingConnectionBlueprint {
case (MessageStartError(_, info), _) throw IllegalResponseException(info)
}
FlowGraph.partial() { implicit b
BidiFlow() { implicit b
import FlowGraph.Implicits._
val methodBypassFanout = b.add(Broadcast[HttpRequest](2, eagerCancel = true))
val responseParsingMerge = b.add(new ResponseParsingMerge(rootParser))

View file

@ -4,6 +4,8 @@
package akka.http.impl.engine.rendering
import akka.http.ClientConnectionSettings
import scala.annotation.tailrec
import akka.event.LoggingAdapter
import akka.util.ByteString
@ -20,110 +22,133 @@ import headers._
private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.`User-Agent`],
requestHeaderSizeHint: Int,
log: LoggingAdapter) {
import HttpRequestRendererFactory.RequestRenderingOutput
def newRenderer: HttpRequestRenderer = new HttpRequestRenderer
def renderToSource(ctx: RequestRenderingContext): Source[ByteString, Any] = render(ctx).byteStream
final class HttpRequestRenderer extends PushStage[RequestRenderingContext, Source[ByteString, Any]] {
def render(ctx: RequestRenderingContext): RequestRenderingOutput = {
val r = new ByteStringRendering(requestHeaderSizeHint)
import ctx.request._
override def onPush(ctx: RequestRenderingContext, opCtx: Context[Source[ByteString, Any]]): SyncDirective = {
val r = new ByteStringRendering(requestHeaderSizeHint)
import ctx.request._
def renderRequestLine(): Unit = {
r ~~ method ~~ ' '
val rawRequestUriRendered = headers.exists {
case `Raw-Request-URI`(rawUri)
r ~~ rawUri; true
case _ false
}
if (!rawRequestUriRendered) UriRendering.renderUriWithoutFragment(r, uri, UTF8)
r ~~ ' ' ~~ protocol ~~ CrLf
}
def renderRequestLine(): Unit = {
r ~~ method ~~ ' '
val rawRequestUriRendered = headers.exists {
case `Raw-Request-URI`(rawUri)
r ~~ rawUri; true
case _ false
def render(h: HttpHeader) = r ~~ h ~~ CrLf
@tailrec def renderHeaders(remaining: List[HttpHeader], hostHeaderSeen: Boolean = false,
userAgentSeen: Boolean = false, transferEncodingSeen: Boolean = false): Unit =
remaining match {
case head :: tail head match {
case x: `Content-Length`
suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x: `Content-Type`
suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpRequest.entity.contentType` instead.")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x: `Transfer-Encoding`
x.withChunkedPeeled match {
case None
suppressionWarning(log, head)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case Some(te)
// if the user applied some custom transfer-encoding we need to keep the header
render(if (entity.isChunked && !entity.isKnownEmpty) te.withChunked else te)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen = true)
}
case x: `Host`
render(x)
renderHeaders(tail, hostHeaderSeen = true, userAgentSeen, transferEncodingSeen)
case x: `User-Agent`
render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen = true, transferEncodingSeen)
case x: `Raw-Request-URI` // we never render this header
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x: CustomHeader
if (!x.suppressRendering) render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") ||
(x is "host") || (x is "user-agent")
suppressionWarning(log, x, "illegal RawHeader")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x
render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
}
if (!rawRequestUriRendered) UriRendering.renderUriWithoutFragment(r, uri, UTF8)
r ~~ ' ' ~~ protocol ~~ CrLf
case Nil
if (!hostHeaderSeen) r ~~ ctx.hostHeader ~~ CrLf
if (!userAgentSeen && userAgentHeader.isDefined) r ~~ userAgentHeader.get ~~ CrLf
if (entity.isChunked && !entity.isKnownEmpty && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
}
def render(h: HttpHeader) = r ~~ h ~~ CrLf
def renderContentLength(contentLength: Long) =
if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
@tailrec def renderHeaders(remaining: List[HttpHeader], hostHeaderSeen: Boolean = false,
userAgentSeen: Boolean = false, transferEncodingSeen: Boolean = false): Unit =
remaining match {
case head :: tail head match {
case x: `Content-Length`
suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
def renderStreamed(body: Source[ByteString, Any]): RequestRenderingOutput =
RequestRenderingOutput.Streamed(renderByteStrings(r, body))
case x: `Content-Type`
suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpRequest.entity.contentType` instead.")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
def completeRequestRendering(): RequestRenderingOutput =
entity match {
case x if x.isKnownEmpty
renderContentLength(0) ~~ CrLf
RequestRenderingOutput.Strict(r.get)
case x: `Transfer-Encoding`
x.withChunkedPeeled match {
case None
suppressionWarning(log, head)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case Some(te)
// if the user applied some custom transfer-encoding we need to keep the header
render(if (entity.isChunked && !entity.isKnownEmpty) te.withChunked else te)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen = true)
}
case HttpEntity.Strict(_, data)
renderContentLength(data.length) ~~ CrLf
RequestRenderingOutput.Strict(r.get ++ data)
case x: `Host`
render(x)
renderHeaders(tail, hostHeaderSeen = true, userAgentSeen, transferEncodingSeen)
case HttpEntity.Default(_, contentLength, data)
renderContentLength(contentLength) ~~ CrLf
renderStreamed(data.via(CheckContentLengthTransformer.flow(contentLength)))
case x: `User-Agent`
render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen = true, transferEncodingSeen)
case HttpEntity.Chunked(_, chunks)
r ~~ CrLf
renderStreamed(chunks.via(ChunkTransformer.flow))
}
case x: `Raw-Request-URI` // we never render this header
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
renderRequestLine()
renderHeaders(headers.toList)
renderEntityContentType(r, entity)
completeRequestRendering()
}
case x: CustomHeader
if (!x.suppressRendering) render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") ||
(x is "host") || (x is "user-agent")
suppressionWarning(log, x, "illegal RawHeader")
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
case x
render(x)
renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen)
}
case Nil
if (!hostHeaderSeen) r ~~ ctx.hostHeader ~~ CrLf
if (!userAgentSeen && userAgentHeader.isDefined) r ~~ userAgentHeader.get ~~ CrLf
if (entity.isChunked && !entity.isKnownEmpty && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
}
def renderContentLength(contentLength: Long) =
if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
def completeRequestRendering(): Source[ByteString, Any] =
entity match {
case x if x.isKnownEmpty
renderContentLength(0) ~~ CrLf
Source.single(r.get)
case HttpEntity.Strict(_, data)
renderContentLength(data.length) ~~ CrLf
Source.single(r.get ++ data)
case HttpEntity.Default(_, contentLength, data)
renderContentLength(contentLength) ~~ CrLf
renderByteStrings(r, data.via(CheckContentLengthTransformer.flow(contentLength)))
case HttpEntity.Chunked(_, chunks)
r ~~ CrLf
renderByteStrings(r, chunks.via(ChunkTransformer.flow))
}
renderRequestLine()
renderHeaders(headers.toList)
renderEntityContentType(r, entity)
opCtx.push(completeRequestRendering())
def renderStrict(ctx: RequestRenderingContext): ByteString =
render(ctx) match {
case RequestRenderingOutput.Strict(bytes) bytes
case _: RequestRenderingOutput.Streamed
throw new IllegalArgumentException(s"Request entity was not Strict but ${ctx.request.entity.getClass.getSimpleName}")
}
}
private[http] object HttpRequestRendererFactory {
def renderStrict(ctx: RequestRenderingContext, settings: ClientConnectionSettings, log: LoggingAdapter): ByteString =
new HttpRequestRendererFactory(settings.userAgentHeader, settings.requestHeaderSizeHint, log).renderStrict(ctx)
sealed trait RequestRenderingOutput {
def byteStream: Source[ByteString, Any]
}
object RequestRenderingOutput {
case class Strict(bytes: ByteString) extends RequestRenderingOutput {
def byteStream: Source[ByteString, Any] = Source.single(bytes)
}
case class Streamed(byteStream: Source[ByteString, Any]) extends RequestRenderingOutput
}
}

View file

@ -210,7 +210,7 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E
def clientLayer(hostHeader: Host,
settings: ClientConnectionSettings,
log: LoggingAdapter = system.log): ClientLayer =
BidiFlow.wrap(OutgoingConnectionBlueprint(hostHeader, settings, log))
OutgoingConnectionBlueprint(hostHeader, settings, log)
/**
* Starts a new connection pool to the given host and configuration and returns a [[Flow]] which dispatches
@ -497,6 +497,7 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
//#client-layer
/**
* The type of the client-side HTTP layer as a stand-alone BidiFlow
* that can be put atop the TCP layer to form an HTTP client.
*
* {{{

View file

@ -311,10 +311,7 @@ class RequestRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
def renderTo(expected: String): Matcher[HttpRequest] =
equal(expected.stripMarginWithNewline("\r\n")).matcher[String] compose { request
val renderer = newRenderer
val byteStringSource = Await.result(Source.single(RequestRenderingContext(request, Host(serverAddress)))
.transform(() renderer).named("renderer")
.runWith(Sink.head), 1.second)
val byteStringSource = renderToSource(RequestRenderingContext(request, Host(serverAddress)))
val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
Await.result(future, 250.millis)
}