diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index ad924c435f..07c1c7f9ef 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -13,6 +13,7 @@ import akka.event.LoggingAdapter import akka.stream._ import akka.stream.scaladsl._ import akka.http.ClientConnectionSettings +import akka.http.scaladsl.Http import akka.http.scaladsl.model.headers.Host import akka.http.scaladsl.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse } import akka.http.impl.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory } @@ -23,9 +24,6 @@ import akka.http.impl.util._ * INTERNAL API */ private[http] object OutgoingConnectionBlueprint { - - type ClientShape = BidiShape[HttpRequest, SslTlsOutbound, SslTlsInbound, HttpResponse] - /* Stream Setup ============ @@ -45,7 +43,7 @@ private[http] object OutgoingConnectionBlueprint { */ def apply(hostHeader: Host, settings: ClientConnectionSettings, - log: LoggingAdapter): Graph[ClientShape, Unit] = { + log: LoggingAdapter): Http.ClientLayer = { import settings._ // the initial header parser we initially use for every connection, @@ -59,7 +57,7 @@ private[http] object OutgoingConnectionBlueprint { val requestRendering: Flow[HttpRequest, ByteString, Unit] = Flow[HttpRequest] .map(RequestRenderingContext(_, hostHeader)) - .via(Flow[RequestRenderingContext].transform(() ⇒ requestRendererFactory.newRenderer).named("renderer")) + .via(Flow[RequestRenderingContext].map(requestRendererFactory.renderToSource).named("renderer")) .flatten(FlattenStrategy.concat) val methodBypass = Flow[HttpRequest].map(_.method) @@ -76,7 +74,7 @@ private[http] object OutgoingConnectionBlueprint { case (MessageStartError(_, info), _) ⇒ throw IllegalResponseException(info) } - FlowGraph.partial() { implicit b ⇒ + BidiFlow() { implicit b ⇒ import FlowGraph.Implicits._ val methodBypassFanout = b.add(Broadcast[HttpRequest](2, eagerCancel = true)) val responseParsingMerge = b.add(new ResponseParsingMerge(rootParser)) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpRequestRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpRequestRendererFactory.scala index 7458cfde6e..812a5e6383 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpRequestRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpRequestRendererFactory.scala @@ -4,6 +4,8 @@ package akka.http.impl.engine.rendering +import akka.http.ClientConnectionSettings + import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.util.ByteString @@ -20,110 +22,133 @@ import headers._ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.`User-Agent`], requestHeaderSizeHint: Int, log: LoggingAdapter) { + import HttpRequestRendererFactory.RequestRenderingOutput - def newRenderer: HttpRequestRenderer = new HttpRequestRenderer + def renderToSource(ctx: RequestRenderingContext): Source[ByteString, Any] = render(ctx).byteStream - final class HttpRequestRenderer extends PushStage[RequestRenderingContext, Source[ByteString, Any]] { + def render(ctx: RequestRenderingContext): RequestRenderingOutput = { + val r = new ByteStringRendering(requestHeaderSizeHint) + import ctx.request._ - override def onPush(ctx: RequestRenderingContext, opCtx: Context[Source[ByteString, Any]]): SyncDirective = { - val r = new ByteStringRendering(requestHeaderSizeHint) - import ctx.request._ + def renderRequestLine(): Unit = { + r ~~ method ~~ ' ' + val rawRequestUriRendered = headers.exists { + case `Raw-Request-URI`(rawUri) ⇒ + r ~~ rawUri; true + case _ ⇒ false + } + if (!rawRequestUriRendered) UriRendering.renderUriWithoutFragment(r, uri, UTF8) + r ~~ ' ' ~~ protocol ~~ CrLf + } - def renderRequestLine(): Unit = { - r ~~ method ~~ ' ' - val rawRequestUriRendered = headers.exists { - case `Raw-Request-URI`(rawUri) ⇒ - r ~~ rawUri; true - case _ ⇒ false + def render(h: HttpHeader) = r ~~ h ~~ CrLf + + @tailrec def renderHeaders(remaining: List[HttpHeader], hostHeaderSeen: Boolean = false, + userAgentSeen: Boolean = false, transferEncodingSeen: Boolean = false): Unit = + remaining match { + case head :: tail ⇒ head match { + case x: `Content-Length` ⇒ + suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + + case x: `Content-Type` ⇒ + suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpRequest.entity.contentType` instead.") + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + + case x: `Transfer-Encoding` ⇒ + x.withChunkedPeeled match { + case None ⇒ + suppressionWarning(log, head) + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + case Some(te) ⇒ + // if the user applied some custom transfer-encoding we need to keep the header + render(if (entity.isChunked && !entity.isKnownEmpty) te.withChunked else te) + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen = true) + } + + case x: `Host` ⇒ + render(x) + renderHeaders(tail, hostHeaderSeen = true, userAgentSeen, transferEncodingSeen) + + case x: `User-Agent` ⇒ + render(x) + renderHeaders(tail, hostHeaderSeen, userAgentSeen = true, transferEncodingSeen) + + case x: `Raw-Request-URI` ⇒ // we never render this header + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + + case x: CustomHeader ⇒ + if (!x.suppressRendering) render(x) + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + + case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") || + (x is "host") || (x is "user-agent") ⇒ + suppressionWarning(log, x, "illegal RawHeader") + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + + case x ⇒ + render(x) + renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) } - if (!rawRequestUriRendered) UriRendering.renderUriWithoutFragment(r, uri, UTF8) - r ~~ ' ' ~~ protocol ~~ CrLf + + case Nil ⇒ + if (!hostHeaderSeen) r ~~ ctx.hostHeader ~~ CrLf + if (!userAgentSeen && userAgentHeader.isDefined) r ~~ userAgentHeader.get ~~ CrLf + if (entity.isChunked && !entity.isKnownEmpty && !transferEncodingSeen) + r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } - def render(h: HttpHeader) = r ~~ h ~~ CrLf + def renderContentLength(contentLength: Long) = + if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r - @tailrec def renderHeaders(remaining: List[HttpHeader], hostHeaderSeen: Boolean = false, - userAgentSeen: Boolean = false, transferEncodingSeen: Boolean = false): Unit = - remaining match { - case head :: tail ⇒ head match { - case x: `Content-Length` ⇒ - suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + def renderStreamed(body: Source[ByteString, Any]): RequestRenderingOutput = + RequestRenderingOutput.Streamed(renderByteStrings(r, body)) - case x: `Content-Type` ⇒ - suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpRequest.entity.contentType` instead.") - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + def completeRequestRendering(): RequestRenderingOutput = + entity match { + case x if x.isKnownEmpty ⇒ + renderContentLength(0) ~~ CrLf + RequestRenderingOutput.Strict(r.get) - case x: `Transfer-Encoding` ⇒ - x.withChunkedPeeled match { - case None ⇒ - suppressionWarning(log, head) - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) - case Some(te) ⇒ - // if the user applied some custom transfer-encoding we need to keep the header - render(if (entity.isChunked && !entity.isKnownEmpty) te.withChunked else te) - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen = true) - } + case HttpEntity.Strict(_, data) ⇒ + renderContentLength(data.length) ~~ CrLf + RequestRenderingOutput.Strict(r.get ++ data) - case x: `Host` ⇒ - render(x) - renderHeaders(tail, hostHeaderSeen = true, userAgentSeen, transferEncodingSeen) + case HttpEntity.Default(_, contentLength, data) ⇒ + renderContentLength(contentLength) ~~ CrLf + renderStreamed(data.via(CheckContentLengthTransformer.flow(contentLength))) - case x: `User-Agent` ⇒ - render(x) - renderHeaders(tail, hostHeaderSeen, userAgentSeen = true, transferEncodingSeen) + case HttpEntity.Chunked(_, chunks) ⇒ + r ~~ CrLf + renderStreamed(chunks.via(ChunkTransformer.flow)) + } - case x: `Raw-Request-URI` ⇒ // we never render this header - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) + renderRequestLine() + renderHeaders(headers.toList) + renderEntityContentType(r, entity) + completeRequestRendering() + } - case x: CustomHeader ⇒ - if (!x.suppressRendering) render(x) - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) - - case x: RawHeader if (x is "content-type") || (x is "content-length") || (x is "transfer-encoding") || - (x is "host") || (x is "user-agent") ⇒ - suppressionWarning(log, x, "illegal RawHeader") - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) - - case x ⇒ - render(x) - renderHeaders(tail, hostHeaderSeen, userAgentSeen, transferEncodingSeen) - } - - case Nil ⇒ - if (!hostHeaderSeen) r ~~ ctx.hostHeader ~~ CrLf - if (!userAgentSeen && userAgentHeader.isDefined) r ~~ userAgentHeader.get ~~ CrLf - if (entity.isChunked && !entity.isKnownEmpty && !transferEncodingSeen) - r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf - } - - def renderContentLength(contentLength: Long) = - if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r - - def completeRequestRendering(): Source[ByteString, Any] = - entity match { - case x if x.isKnownEmpty ⇒ - renderContentLength(0) ~~ CrLf - Source.single(r.get) - - case HttpEntity.Strict(_, data) ⇒ - renderContentLength(data.length) ~~ CrLf - Source.single(r.get ++ data) - - case HttpEntity.Default(_, contentLength, data) ⇒ - renderContentLength(contentLength) ~~ CrLf - renderByteStrings(r, data.via(CheckContentLengthTransformer.flow(contentLength))) - - case HttpEntity.Chunked(_, chunks) ⇒ - r ~~ CrLf - renderByteStrings(r, chunks.via(ChunkTransformer.flow)) - } - - renderRequestLine() - renderHeaders(headers.toList) - renderEntityContentType(r, entity) - opCtx.push(completeRequestRendering()) + def renderStrict(ctx: RequestRenderingContext): ByteString = + render(ctx) match { + case RequestRenderingOutput.Strict(bytes) ⇒ bytes + case _: RequestRenderingOutput.Streamed ⇒ + throw new IllegalArgumentException(s"Request entity was not Strict but ${ctx.request.entity.getClass.getSimpleName}") } +} + +private[http] object HttpRequestRendererFactory { + def renderStrict(ctx: RequestRenderingContext, settings: ClientConnectionSettings, log: LoggingAdapter): ByteString = + new HttpRequestRendererFactory(settings.userAgentHeader, settings.requestHeaderSizeHint, log).renderStrict(ctx) + + sealed trait RequestRenderingOutput { + def byteStream: Source[ByteString, Any] + } + object RequestRenderingOutput { + case class Strict(bytes: ByteString) extends RequestRenderingOutput { + def byteStream: Source[ByteString, Any] = Source.single(bytes) + } + case class Streamed(byteStream: Source[ByteString, Any]) extends RequestRenderingOutput } } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala index b66fa9ff8f..8479b457bd 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala @@ -210,7 +210,7 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E def clientLayer(hostHeader: Host, settings: ClientConnectionSettings, log: LoggingAdapter = system.log): ClientLayer = - BidiFlow.wrap(OutgoingConnectionBlueprint(hostHeader, settings, log)) + OutgoingConnectionBlueprint(hostHeader, settings, log) /** * Starts a new connection pool to the given host and configuration and returns a [[Flow]] which dispatches @@ -497,6 +497,7 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { //#client-layer /** + * The type of the client-side HTTP layer as a stand-alone BidiFlow * that can be put atop the TCP layer to form an HTTP client. * * {{{ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/RequestRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/RequestRendererSpec.scala index f020f78b4d..81c5e4cfc2 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/RequestRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/RequestRendererSpec.scala @@ -311,10 +311,7 @@ class RequestRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll def renderTo(expected: String): Matcher[HttpRequest] = equal(expected.stripMarginWithNewline("\r\n")).matcher[String] compose { request ⇒ - val renderer = newRenderer - val byteStringSource = Await.result(Source.single(RequestRenderingContext(request, Host(serverAddress))) - .transform(() ⇒ renderer).named("renderer") - .runWith(Sink.head), 1.second) + val byteStringSource = renderToSource(RequestRenderingContext(request, Host(serverAddress))) val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) Await.result(future, 250.millis) }