From 9dc474a10ac1887c9d280904974e91f4f9208299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Fri, 22 Jul 2016 14:07:41 +0200 Subject: [PATCH] Pre-fuse http server layer (#20990) * Ported the first pre-fuse part endre did in pr #1972 * Allow the same HttpServerBluePrint to materialize multiple times HttpRequestParser now behave like a proper GraphStage (with regards to materialization) HttpResponseParser is kept "weird" to limit scope of commit. * TestClient method to dump with http client and curl in parallel for comparison * Cleanup * tightening down what can be overriden * tightening down access modifiers * updates according to review * Better defaults for the test server * Ups. Don't listen to public interfaces in test server by default. --- .../test/scala/akka/util/ByteStringSpec.scala | 2 +- .../engine/parsing/HttpMessageParser.scala | 155 ++++----- .../engine/parsing/HttpRequestParser.scala | 297 ++++++++++-------- .../engine/parsing/HttpResponseParser.scala | 80 ++++- .../engine/server/HttpServerBluePrint.scala | 6 +- .../main/scala/akka/http/scaladsl/Http.scala | 86 +++-- .../engine/parsing/RequestParserSpec.scala | 4 +- .../scala/akka/http/scaladsl/TestClient.scala | 57 +++- .../scala/akka/http/scaladsl/TestServer.scala | 11 +- 9 files changed, 423 insertions(+), 275 deletions(-) diff --git a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala index 0a2894ccc0..b87de033f2 100644 --- a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala @@ -414,7 +414,7 @@ class ByteStringSpec extends WordSpec with Matchers with Checkers { def excerciseRecombining(xs: ByteString, from: Int, until: Int) = { val (tmp, c) = xs.splitAt(until) val (a, b) = tmp.splitAt(from) - (a ++ b ++ c) should ===(xs) + (a ++ b ++ c) should ===(xs) } "recombining - edge cases" in { excerciseRecombining(ByteStrings(Vector(ByteString1(Array[Byte](1)), ByteString1(Array[Byte](2)))), -2147483648, 112121212) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala index 46c997d238..bb6cd112ef 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala @@ -23,52 +23,35 @@ import akka.stream.{ Attributes, FlowShape, Inlet, Outlet } /** * INTERNAL API + * + * Common logic for http request and response message parsing */ -private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput]( - val settings: ParserSettings, - val headerParser: HttpHeaderParser) { self ⇒ - import HttpMessageParser._ - import settings._ +private[http] trait HttpMessageParser[Output >: MessageOutput <: ParserOutput] { - private[this] val result = new ListBuffer[Output] + import HttpMessageParser._ + + protected final val result = new ListBuffer[Output] private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0) private[this] var protocol: HttpProtocol = `HTTP/1.1` - private[this] var completionHandling: CompletionHandling = CompletionOk - private[this] var terminated = false + protected var completionHandling: CompletionHandling = CompletionOk + protected var terminated = false private[this] var lastSession: SSLSession = null // used to prevent having to recreate header on each message private[this] var tlsSessionInfoHeader: `Tls-Session-Info` = null - def initialHeaderBuffer: ListBuffer[HttpHeader] = + + protected def settings: ParserSettings + protected def headerParser: HttpHeaderParser + /** invoked if the specified protocol is unknown */ + protected def onBadProtocol(): Nothing + protected def parseMessage(input: ByteString, offset: Int): HttpMessageParser.StateResult + protected def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, + clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): HttpMessageParser.StateResult + + protected final def initialHeaderBuffer: ListBuffer[HttpHeader] = if (settings.includeTlsSessionInfoHeader && tlsSessionInfoHeader != null) ListBuffer(tlsSessionInfoHeader) else ListBuffer() - // Note that this GraphStage mutates the HttpMessageParser instance, use with caution. - val stage = new GraphStage[FlowShape[SessionBytes, Output]] { - val in: Inlet[SessionBytes] = Inlet("HttpMessageParser.in") - val out: Outlet[Output] = Outlet("HttpMessageParser.out") - override val shape: FlowShape[SessionBytes, Output] = FlowShape(in, out) - - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = - new GraphStageLogic(shape) with InHandler with OutHandler { - override def onPush(): Unit = handleParserOutput(self.parseSessionBytes(grab(in))) - override def onPull(): Unit = handleParserOutput(self.onPull()) - - override def onUpstreamFinish(): Unit = - if (self.onUpstreamFinish()) completeStage() - else if (isAvailable(out)) handleParserOutput(self.onPull()) - - private def handleParserOutput(output: Output): Unit = { - output match { - case StreamEnd ⇒ completeStage() - case NeedMoreData ⇒ pull(in) - case x ⇒ push(out, x) - } - } - - setHandlers(in, out, this) - } - } - final def parseSessionBytes(input: SessionBytes): Output = { if (input.session ne lastSession) { lastSession = input.session @@ -93,17 +76,17 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`") run(state) - onPull() + doPull() } - final def onPull(): Output = + protected final def doPull(): Output = if (result.nonEmpty) { val head = result.head result.remove(0) // faster than `ListBuffer::drop` head } else if (terminated) StreamEnd else NeedMoreData - final def onUpstreamFinish(): Boolean = { + protected final def shouldComplete(): Boolean = { completionHandling() match { case Some(x) ⇒ emit(x) case None ⇒ // nothing to do @@ -118,28 +101,24 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser catch { case NotEnoughDataException ⇒ continue(input, offset)(startNewMessage) } } - protected def parseMessage(input: ByteString, offset: Int): StateResult - - def parseProtocol(input: ByteString, cursor: Int): Int = { + protected final def parseProtocol(input: ByteString, cursor: Int): Int = { def c(ix: Int) = byteChar(input, cursor + ix) if (c(0) == 'H' && c(1) == 'T' && c(2) == 'T' && c(3) == 'P' && c(4) == '/' && c(5) == '1' && c(6) == '.') { protocol = c(7) match { case '0' ⇒ `HTTP/1.0` case '1' ⇒ `HTTP/1.1` - case _ ⇒ badProtocol + case _ ⇒ onBadProtocol } cursor + 8 - } else badProtocol + } else onBadProtocol } - def badProtocol: Nothing - - @tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = initialHeaderBuffer, - headerCount: Int = 0, ch: Option[Connection] = None, - clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None, - teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false, - hh: Boolean = false): StateResult = - if (headerCount < maxHeaderCount) { + @tailrec protected final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = initialHeaderBuffer, + headerCount: Int = 0, ch: Option[Connection] = None, + clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None, + teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false, + hh: Boolean = false): StateResult = + if (headerCount < settings.maxHeaderCount) { var lineEnd = 0 val resultHeader = try { @@ -182,19 +161,15 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser case h ⇒ parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, ch, clh, cth, teh, e100c, hh) } - } else failMessageStart(s"HTTP message contains more than the configured limit of $maxHeaderCount headers") + } else failMessageStart(s"HTTP message contains more than the configured limit of ${settings.maxHeaderCount} headers") // work-around for compiler complaining about non-tail-recursion if we inline this method - def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int, ch: Option[Connection], - clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - e100c: Boolean, hh: Boolean)(input: ByteString, lineStart: Int): StateResult = + private def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int, ch: Option[Connection], + clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], + e100c: Boolean, hh: Boolean)(input: ByteString, lineStart: Int): StateResult = parseHeaderLines(input, lineStart, headers, headerCount, ch, clh, cth, teh, e100c, hh) - def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, - clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult - - def parseFixedLengthBody( + protected final def parseFixedLengthBody( remainingBodyBytes: Long, isLastMessage: Boolean)(input: ByteString, bodyStart: Int): StateResult = { val remainingInputBytes = input.length - bodyStart @@ -213,7 +188,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser } else continue(input, bodyStart)(parseFixedLengthBody(remainingBodyBytes, isLastMessage)) } - def parseChunk(input: ByteString, offset: Int, isLastMessage: Boolean, totalBytesRead: Long): StateResult = { + protected final def parseChunk(input: ByteString, offset: Int, isLastMessage: Boolean, totalBytesRead: Long): StateResult = { @tailrec def parseTrailer(extension: String, lineStart: Int, headers: List[HttpHeader] = Nil, headerCount: Int = 0): StateResult = { var errorInfo: ErrorInfo = null @@ -230,9 +205,9 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser setCompletionHandling(CompletionOk) if (isLastMessage) terminate() else startNewMessage(input, lineEnd) - case header if headerCount < maxHeaderCount ⇒ + case header if headerCount < settings.maxHeaderCount ⇒ parseTrailer(extension, lineEnd, header :: headers, headerCount + 1) - case _ ⇒ failEntityStream(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers") + case _ ⇒ failEntityStream(s"Chunk trailer contains more than the configured limit of ${settings.maxHeaderCount} headers") } } else failEntityStream(errorInfo) } @@ -252,24 +227,24 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser } else parseTrailer(extension, cursor) @tailrec def parseChunkExtensions(chunkSize: Int, cursor: Int)(startIx: Int = cursor): StateResult = - if (cursor - startIx <= maxChunkExtLength) { + if (cursor - startIx <= settings.maxChunkExtLength) { def extension = asciiString(input, startIx, cursor) byteChar(input, cursor) match { case '\r' if byteChar(input, cursor + 1) == '\n' ⇒ parseChunkBody(chunkSize, extension, cursor + 2) case '\n' ⇒ parseChunkBody(chunkSize, extension, cursor + 1) case _ ⇒ parseChunkExtensions(chunkSize, cursor + 1)(startIx) } - } else failEntityStream(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters") + } else failEntityStream(s"HTTP chunk extension length exceeds configured limit of ${settings.maxChunkExtLength} characters") @tailrec def parseSize(cursor: Int, size: Long): StateResult = - if (size <= maxChunkSize) { + if (size <= settings.maxChunkSize) { byteChar(input, cursor) match { case c if CharacterClasses.HEXDIG(c) ⇒ parseSize(cursor + 1, size * 16 + CharUtils.hexValue(c)) case ';' if cursor > offset ⇒ parseChunkExtensions(size.toInt, cursor + 1)() case '\r' if cursor > offset && byteChar(input, cursor + 1) == '\n' ⇒ parseChunkBody(size.toInt, "", cursor + 2) case c ⇒ failEntityStream(s"Illegal character '${escape(c)}' in chunk start") } - } else failEntityStream(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes") + } else failEntityStream(s"HTTP chunk size exceeds the configured limit of ${settings.maxChunkSize} bytes") try parseSize(offset, 0) catch { @@ -277,9 +252,9 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser } } - def emit(output: Output): Unit = result += output + protected def emit(output: Output): Unit = result += output - def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = { + protected final def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = { state = math.signum(offset - input.length) match { case -1 ⇒ @@ -291,30 +266,30 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser done() } - def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = { + protected final def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = { state = next(_, 0) done() } - def failMessageStart(summary: String): StateResult = failMessageStart(summary, "") - def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail) - def failMessageStart(status: StatusCode): StateResult = failMessageStart(status, status.defaultMessage) - def failMessageStart(status: StatusCode, summary: String, detail: String = ""): StateResult = failMessageStart(status, ErrorInfo(summary, detail)) - def failMessageStart(status: StatusCode, info: ErrorInfo): StateResult = { + protected final def failMessageStart(summary: String): StateResult = failMessageStart(summary, "") + protected final def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail) + protected final def failMessageStart(status: StatusCode): StateResult = failMessageStart(status, status.defaultMessage) + protected final def failMessageStart(status: StatusCode, summary: String, detail: String = ""): StateResult = failMessageStart(status, ErrorInfo(summary, detail)) + protected final def failMessageStart(status: StatusCode, info: ErrorInfo): StateResult = { emit(MessageStartError(status, info)) setCompletionHandling(CompletionOk) terminate() } - def failEntityStream(summary: String): StateResult = failEntityStream(summary, "") - def failEntityStream(summary: String, detail: String): StateResult = failEntityStream(ErrorInfo(summary, detail)) - def failEntityStream(info: ErrorInfo): StateResult = { + protected final def failEntityStream(summary: String): StateResult = failEntityStream(summary, "") + protected final def failEntityStream(summary: String, detail: String): StateResult = failEntityStream(ErrorInfo(summary, detail)) + protected final def failEntityStream(info: ErrorInfo): StateResult = { emit(EntityStreamError(info)) setCompletionHandling(CompletionOk) terminate() } - def terminate(): StateResult = { + protected final def terminate(): StateResult = { terminated = true done() } @@ -325,19 +300,19 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser */ private def done(): StateResult = null // StateResult is a phantom type - def contentType(cth: Option[`Content-Type`]) = cth match { + protected final def contentType(cth: Option[`Content-Type`]) = cth match { case Some(x) ⇒ x.contentType case None ⇒ ContentTypes.`application/octet-stream` } - def emptyEntity(cth: Option[`Content-Type`]) = + protected final def emptyEntity(cth: Option[`Content-Type`]) = StrictEntityCreator(if (cth.isDefined) HttpEntity.empty(cth.get.contentType) else HttpEntity.Empty) - def strictEntity(cth: Option[`Content-Type`], input: ByteString, bodyStart: Int, - contentLength: Int) = + protected final def strictEntity(cth: Option[`Content-Type`], input: ByteString, bodyStart: Int, + contentLength: Int) = StrictEntityCreator(HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength))) - def defaultEntity[A <: ParserOutput](cth: Option[`Content-Type`], contentLength: Long) = + protected final def defaultEntity[A <: ParserOutput](cth: Option[`Content-Type`], contentLength: Long) = StreamedEntityCreator[A, UniversalEntity] { entityParts ⇒ val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes @@ -346,7 +321,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser HttpEntity.Default(contentType(cth), contentLength, HttpEntity.limitableByteSource(data)) } - def chunkedEntity[A <: ParserOutput](cth: Option[`Content-Type`]) = + protected final def chunkedEntity[A <: ParserOutput](cth: Option[`Content-Type`]) = StreamedEntityCreator[A, RequestEntity] { entityChunks ⇒ val chunks = entityChunks.collect { case EntityChunk(chunk) ⇒ chunk @@ -355,16 +330,20 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser HttpEntity.Chunked(contentType(cth), HttpEntity.limitableChunkSource(chunks)) } - def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = + protected final def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = teh.withChunkedPeeled match { case Some(x) ⇒ x :: headers case None ⇒ headers } - def setCompletionHandling(completionHandling: CompletionHandling): Unit = + protected final def setCompletionHandling(completionHandling: CompletionHandling): Unit = this.completionHandling = completionHandling + } +/** + * INTERNAL API + */ private[http] object HttpMessageParser { sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index cfb40519b5..227bc1bb53 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -5,7 +5,8 @@ package akka.http.impl.engine.parsing import java.lang.{ StringBuilder ⇒ JStringBuilder } -import scala.annotation.tailrec + +import scala.annotation.{ switch, tailrec } import akka.http.scaladsl.settings.ParserSettings import akka.util.ByteString import akka.http.impl.engine.ws.Handshake @@ -14,160 +15,196 @@ import akka.http.scaladsl.model._ import headers._ import StatusCodes._ import ParserOutput._ +import akka.stream.{ Attributes, FlowShape, Inlet, Outlet } +import akka.stream.TLSProtocol.SessionBytes +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } /** * INTERNAL API */ -private[http] class HttpRequestParser( - _settings: ParserSettings, +private[http] final class HttpRequestParser( + settings: ParserSettings, rawRequestUriHeader: Boolean, - _headerParser: HttpHeaderParser) - extends HttpMessageParser[RequestOutput](_settings, _headerParser) { + headerParser: HttpHeaderParser) + extends GraphStage[FlowShape[SessionBytes, RequestOutput]] { self ⇒ + import HttpMessageParser._ import settings._ - private[this] var method: HttpMethod = _ - private[this] var uri: Uri = _ - private[this] var uriBytes: Array[Byte] = _ + val in = Inlet[SessionBytes]("HttpRequestParser.in") + val out = Outlet[RequestOutput]("HttpRequestParser.out") - def createShallowCopy(): HttpRequestParser = - new HttpRequestParser(settings, rawRequestUriHeader, headerParser.createShallowCopy()) + val shape = FlowShape.of(in, out) - def parseMessage(input: ByteString, offset: Int): StateResult = { - var cursor = parseMethod(input, offset) - cursor = parseRequestTarget(input, cursor) - cursor = parseProtocol(input, cursor) - if (byteChar(input, cursor) == '\r' && byteChar(input, cursor + 1) == '\n') - parseHeaderLines(input, cursor + 2) - else badProtocol - } + override protected def initialAttributes: Attributes = Attributes.name("HttpRequestParser") - def parseMethod(input: ByteString, cursor: Int): Int = { - @tailrec def parseCustomMethod(ix: Int = 0, sb: JStringBuilder = new JStringBuilder(16)): Int = - if (ix < maxMethodLength) { - byteChar(input, cursor + ix) match { - case ' ' ⇒ - customMethods(sb.toString) match { - case Some(m) ⇒ - method = m - cursor + ix + 1 - case None ⇒ throw new ParsingException(NotImplemented, ErrorInfo("Unsupported HTTP method", sb.toString)) - } - case c ⇒ parseCustomMethod(ix + 1, sb.append(c)) + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with HttpMessageParser[RequestOutput] with InHandler with OutHandler { + + import HttpMessageParser._ + + override val settings = self.settings + override val headerParser = self.headerParser.createShallowCopy() + + private[this] var method: HttpMethod = _ + private[this] var uri: Uri = _ + private[this] var uriBytes: Array[Byte] = _ + + override def onPush(): Unit = handleParserOutput(parseSessionBytes(grab(in))) + override def onPull(): Unit = handleParserOutput(doPull()) + + override def onUpstreamFinish(): Unit = + if (super.shouldComplete()) completeStage() + else if (isAvailable(out)) handleParserOutput(doPull()) + + setHandlers(in, out, this) + + private def handleParserOutput(output: RequestOutput): Unit = { + output match { + case StreamEnd ⇒ completeStage() + case NeedMoreData ⇒ pull(in) + case x ⇒ push(out, x) + } + } + + override def parseMessage(input: ByteString, offset: Int): StateResult = { + var cursor = parseMethod(input, offset) + cursor = parseRequestTarget(input, cursor) + cursor = parseProtocol(input, cursor) + if (byteChar(input, cursor) == '\r' && byteChar(input, cursor + 1) == '\n') + parseHeaderLines(input, cursor + 2) + else onBadProtocol + } + + def parseMethod(input: ByteString, cursor: Int): Int = { + @tailrec def parseCustomMethod(ix: Int = 0, sb: JStringBuilder = new JStringBuilder(16)): Int = + if (ix < maxMethodLength) { + byteChar(input, cursor + ix) match { + case ' ' ⇒ + customMethods(sb.toString) match { + case Some(m) ⇒ + method = m + cursor + ix + 1 + case None ⇒ throw new ParsingException(NotImplemented, ErrorInfo("Unsupported HTTP method", sb.toString)) + } + case c ⇒ parseCustomMethod(ix + 1, sb.append(c)) + } + } else throw new ParsingException( + BadRequest, + ErrorInfo("Unsupported HTTP method", s"HTTP method too long (started with '${sb.toString}'). " + + "Increase `akka.http.server.parsing.max-method-length` to support HTTP methods with more characters.")) + + @tailrec def parseMethod(meth: HttpMethod, ix: Int = 1): Int = + if (ix == meth.value.length) + if (byteChar(input, cursor + ix) == ' ') { + method = meth + cursor + ix + 1 + } else parseCustomMethod() + else if (byteChar(input, cursor + ix) == meth.value.charAt(ix)) parseMethod(meth, ix + 1) + else parseCustomMethod() + + import HttpMethods._ + (byteChar(input, cursor): @switch) match { + case 'G' ⇒ parseMethod(GET) + case 'P' ⇒ byteChar(input, cursor + 1) match { + case 'O' ⇒ parseMethod(POST, 2) + case 'U' ⇒ parseMethod(PUT, 2) + case 'A' ⇒ parseMethod(PATCH, 2) + case _ ⇒ parseCustomMethod() } - } else throw new ParsingException( - BadRequest, - ErrorInfo("Unsupported HTTP method", s"HTTP method too long (started with '${sb.toString}'). " + - "Increase `akka.http.server.parsing.max-method-length` to support HTTP methods with more characters.")) - - @tailrec def parseMethod(meth: HttpMethod, ix: Int = 1): Int = - if (ix == meth.value.length) - if (byteChar(input, cursor + ix) == ' ') { - method = meth - cursor + ix + 1 - } else parseCustomMethod() - else if (byteChar(input, cursor + ix) == meth.value.charAt(ix)) parseMethod(meth, ix + 1) - else parseCustomMethod() - - import HttpMethods._ - byteChar(input, cursor) match { - case 'G' ⇒ parseMethod(GET) - case 'P' ⇒ byteChar(input, cursor + 1) match { - case 'O' ⇒ parseMethod(POST, 2) - case 'U' ⇒ parseMethod(PUT, 2) - case 'A' ⇒ parseMethod(PATCH, 2) + case 'D' ⇒ parseMethod(DELETE) + case 'H' ⇒ parseMethod(HEAD) + case 'O' ⇒ parseMethod(OPTIONS) + case 'T' ⇒ parseMethod(TRACE) + case 'C' ⇒ parseMethod(CONNECT) case _ ⇒ parseCustomMethod() } - case 'D' ⇒ parseMethod(DELETE) - case 'H' ⇒ parseMethod(HEAD) - case 'O' ⇒ parseMethod(OPTIONS) - case 'T' ⇒ parseMethod(TRACE) - case 'C' ⇒ parseMethod(CONNECT) - case _ ⇒ parseCustomMethod() } - } - def parseRequestTarget(input: ByteString, cursor: Int): Int = { - val uriStart = cursor - val uriEndLimit = cursor + maxUriLength + def parseRequestTarget(input: ByteString, cursor: Int): Int = { + val uriStart = cursor + val uriEndLimit = cursor + maxUriLength - @tailrec def findUriEnd(ix: Int = cursor): Int = - if (ix == input.length) throw NotEnoughDataException - else if (CharacterClasses.WSPCRLF(input(ix).toChar)) ix - else if (ix < uriEndLimit) findUriEnd(ix + 1) - else throw new ParsingException( - RequestUriTooLong, - s"URI length exceeds the configured limit of $maxUriLength characters") + @tailrec def findUriEnd(ix: Int = cursor): Int = + if (ix == input.length) throw NotEnoughDataException + else if (CharacterClasses.WSPCRLF(input(ix).toChar)) ix + else if (ix < uriEndLimit) findUriEnd(ix + 1) + else throw new ParsingException( + RequestUriTooLong, + s"URI length exceeds the configured limit of $maxUriLength characters") - val uriEnd = findUriEnd() - try { - uriBytes = input.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here? - uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode) // TODO ByteStringParserInput? - } catch { - case IllegalUriException(info) ⇒ throw new ParsingException(BadRequest, info) + val uriEnd = findUriEnd() + try { + uriBytes = input.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here? + uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode) // TODO ByteStringParserInput? + } catch { + case IllegalUriException(info) ⇒ throw new ParsingException(BadRequest, info) + } + uriEnd + 1 } - uriEnd + 1 - } - def badProtocol = throw new ParsingException(HTTPVersionNotSupported) + override def onBadProtocol() = throw new ParsingException(HTTPVersionNotSupported) - // http://tools.ietf.org/html/rfc7230#section-3.3 - def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, - clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = - if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { - def emitRequestStart( - createEntity: EntityCreator[RequestOutput, RequestEntity], - headers: List[HttpHeader] = headers) = { - val allHeaders0 = - if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers - else headers + // http://tools.ietf.org/html/rfc7230#section-3.3 + override def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, + clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = + if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { + def emitRequestStart( + createEntity: EntityCreator[RequestOutput, RequestEntity], + headers: List[HttpHeader] = headers) = { + val allHeaders0 = + if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers + else headers - val allHeaders = - if (method == HttpMethods.GET) { - Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match { - case Some(upgrade) ⇒ upgrade :: allHeaders0 - case None ⇒ allHeaders0 + val allHeaders = + if (method == HttpMethods.GET) { + Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match { + case Some(upgrade) ⇒ upgrade :: allHeaders0 + case None ⇒ allHeaders0 + } + } else allHeaders0 + + emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion)) + } + + teh match { + case None ⇒ + val contentLength = clh match { + case Some(`Content-Length`(len)) ⇒ len + case None ⇒ 0 + } + if (contentLength == 0) { + emitRequestStart(emptyEntity(cth)) + setCompletionHandling(HttpMessageParser.CompletionOk) + startNewMessage(input, bodyStart) + } else if (!method.isEntityAccepted) { + failMessageStart(UnprocessableEntity, s"${method.name} requests must not have an entity") + } else if (contentLength <= input.size - bodyStart) { + val cl = contentLength.toInt + emitRequestStart(strictEntity(cth, input, bodyStart, cl)) + setCompletionHandling(HttpMessageParser.CompletionOk) + startNewMessage(input, bodyStart + cl) + } else { + emitRequestStart(defaultEntity(cth, contentLength)) + parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) } - } else allHeaders0 - emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion)) - } - - teh match { - case None ⇒ - val contentLength = clh match { - case Some(`Content-Length`(len)) ⇒ len - case None ⇒ 0 - } - if (contentLength == 0) { - emitRequestStart(emptyEntity(cth)) - setCompletionHandling(HttpMessageParser.CompletionOk) - startNewMessage(input, bodyStart) - } else if (!method.isEntityAccepted) { + case Some(_) if !method.isEntityAccepted ⇒ failMessageStart(UnprocessableEntity, s"${method.name} requests must not have an entity") - } else if (contentLength <= input.size - bodyStart) { - val cl = contentLength.toInt - emitRequestStart(strictEntity(cth, input, bodyStart, cl)) - setCompletionHandling(HttpMessageParser.CompletionOk) - startNewMessage(input, bodyStart + cl) - } else { - emitRequestStart(defaultEntity(cth, contentLength)) - parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) - } - case Some(_) if !method.isEntityAccepted ⇒ - failMessageStart(UnprocessableEntity, s"${method.name} requests must not have an entity") + case Some(te) ⇒ + val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te) + if (te.isChunked) { + if (clh.isEmpty) { + emitRequestStart(chunkedEntity(cth), completedHeaders) + parseChunk(input, bodyStart, closeAfterResponseCompletion, totalBytesRead = 0L) + } else failMessageStart("A chunked request must not contain a Content-Length header.") + } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, + expect100continue, hostHeaderPresent, closeAfterResponseCompletion) + } + } else failMessageStart("Request is missing required `Host` header") - case Some(te) ⇒ - val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te) - if (te.isChunked) { - if (clh.isEmpty) { - emitRequestStart(chunkedEntity(cth), completedHeaders) - parseChunk(input, bodyStart, closeAfterResponseCompletion, totalBytesRead = 0L) - } else failMessageStart("A chunked request must not contain a Content-Length header.") - } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, - expect100continue, hostHeaderPresent, closeAfterResponseCompletion) - } - } else failMessageStart("Request is missing required `Host` header") + } + + override def toString: String = "HttpRequestParser" } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala index 2d65fe6c2d..ca8683f2b2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala @@ -13,12 +13,15 @@ import akka.util.ByteString import akka.http.scaladsl.model._ import headers._ import ParserOutput._ +import akka.stream.{ Attributes, FlowShape, Inlet, Outlet } +import akka.stream.TLSProtocol.SessionBytes +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } /** * INTERNAL API */ -private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: HttpHeaderParser) - extends HttpMessageParser[ResponseOutput](_settings, _headerParser) { +private[http] class HttpResponseParser(protected val settings: ParserSettings, protected val headerParser: HttpHeaderParser) + extends HttpMessageParser[ResponseOutput] { self ⇒ import HttpResponseParser._ import HttpMessageParser._ import settings._ @@ -26,31 +29,74 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: private[this] var contextForCurrentResponse: Option[ResponseContext] = None private[this] var statusCode: StatusCode = StatusCodes.OK - def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy()) + // Note that this GraphStage mutates the HttpMessageParser instance, use with caution. + final val stage = new GraphStage[FlowShape[SessionBytes, ResponseOutput]] { + val in: Inlet[SessionBytes] = Inlet("HttpResponseParser.in") + val out: Outlet[ResponseOutput] = Outlet("HttpResponseParser.out") + override val shape: FlowShape[SessionBytes, ResponseOutput] = FlowShape(in, out) - def setContextForNextResponse(responseContext: ResponseContext): Unit = + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPush(): Unit = handleParserOutput(self.parseSessionBytes(grab(in))) + override def onPull(): Unit = handleParserOutput(self.onPull()) + + override def onUpstreamFinish(): Unit = + if (self.onUpstreamFinish()) completeStage() + else if (isAvailable(out)) handleParserOutput(self.onPull()) + + private def handleParserOutput(output: ResponseOutput): Unit = { + output match { + case StreamEnd ⇒ completeStage() + case NeedMoreData ⇒ pull(in) + case x ⇒ push(out, x) + } + } + + setHandlers(in, out, this) + } + } + + final def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy()) + + final def setContextForNextResponse(responseContext: ResponseContext): Unit = if (contextForCurrentResponse.isEmpty) contextForCurrentResponse = Some(responseContext) - protected def parseMessage(input: ByteString, offset: Int): StateResult = + final def onPull(): ResponseOutput = + if (result.nonEmpty) { + val head = result.head + result.remove(0) // faster than `ListBuffer::drop` + head + } else if (terminated) StreamEnd else NeedMoreData + + final def onUpstreamFinish(): Boolean = { + completionHandling() match { + case Some(x) ⇒ emit(x) + case None ⇒ // nothing to do + } + terminated = true + result.isEmpty + } + + override final def emit(output: ResponseOutput): Unit = { + if (output == MessageEnd) contextForCurrentResponse = None + super.emit(output) + } + + override protected def parseMessage(input: ByteString, offset: Int): StateResult = if (contextForCurrentResponse.isDefined) { var cursor = parseProtocol(input, offset) if (byteChar(input, cursor) == ' ') { cursor = parseStatus(input, cursor + 1) parseHeaderLines(input, cursor) - } else badProtocol + } else onBadProtocol() } else { emit(NeedNextRequestMethod) continue(input, offset)(startNewMessage) } - override def emit(output: ResponseOutput): Unit = { - if (output == MessageEnd) contextForCurrentResponse = None - super.emit(output) - } + override final def onBadProtocol() = throw new ParsingException("The server-side HTTP version is not supported") - def badProtocol = throw new ParsingException("The server-side HTTP version is not supported") - - def parseStatus(input: ByteString, cursor: Int): Int = { + private def parseStatus(input: ByteString, cursor: Int): Int = { def badStatusCode = throw new ParsingException("Illegal response status code") def parseStatusCode() = { def intValue(offset: Int): Int = { @@ -84,9 +130,9 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: def handleInformationalResponses: Boolean = true // http://tools.ietf.org/html/rfc7230#section-3.3 - def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, - clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { + protected final def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, + clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { def emitResponseStart( createEntity: EntityCreator[ResponseOutput, ResponseEntity], @@ -161,7 +207,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: } else finishEmptyResponse() } - def parseToCloseBody(input: ByteString, bodyStart: Int, totalBytesRead: Long): StateResult = { + private def parseToCloseBody(input: ByteString, bodyStart: Int, totalBytesRead: Long): StateResult = { val newTotalBytes = totalBytesRead + math.max(0, input.length - bodyStart) if (input.length > bodyStart) emit(EntityPart(input.drop(bodyStart).compact)) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 9e6f78fd4a..bcc28bc828 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -228,11 +228,7 @@ private[http] object HttpServerBluePrint { case x ⇒ x } - Flow[SessionBytes].via( - // each connection uses a single (private) request parser instance for all its requests - // which builds a cache of all header instances seen on that connection - rootParser.createShallowCopy().stage).named("rootParser") - .map(establishAbsoluteUri) + Flow[SessionBytes].via(rootParser).map(establishAbsoluteUri) } def rendering(settings: ServerSettings, log: LoggingAdapter): Flow[ResponseRenderingContext, ResponseRenderingOutput, NotUsed] = { diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala index 248f70eb2d..f3b014a2dd 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala @@ -9,6 +9,7 @@ import java.util.concurrent.CompletionStage import javax.net.ssl._ import akka.actor._ +import akka.dispatch.ExecutionContexts import akka.event.{ Logging, LoggingAdapter } import akka.http.impl.engine.HttpConnectionTimeoutException import akka.http.impl.engine.client.PoolMasterActor.{ PoolSize, ShutdownAll } @@ -26,6 +27,7 @@ import akka.{ Done, NotUsed } import akka.stream._ import akka.stream.TLSProtocol._ import akka.stream.scaladsl._ +import akka.util.ByteString import com.typesafe.config.Config import com.typesafe.sslconfig.akka._ import com.typesafe.sslconfig.akka.util.AkkaLoggerFactory @@ -55,6 +57,27 @@ class HttpExt(private val config: Config)(implicit val system: ActorSystem) exte private[this] final val DefaultPortForProtocol = -1 // any negative value + private def fuseServerLayer(settings: ServerSettings, connectionContext: ConnectionContext, log: LoggingAdapter)(implicit mat: Materializer): BidiFlow[HttpResponse, ByteString, ByteString, HttpRequest, NotUsed] = { + val httpLayer = serverLayer(settings, None, log) + val tlsStage = sslTlsStage(connectionContext, Server) + BidiFlow.fromGraph(Fusing.aggressive(GraphDSL.create() { implicit b ⇒ + import GraphDSL.Implicits._ + val http = b.add(httpLayer) + val tls = b.add(tlsStage) + + val timeouts = b.add(Flow[ByteString].recover { + case t: TimeoutException ⇒ throw new HttpConnectionTimeoutException(t.getMessage) + }) + + tls.out2 ~> http.in2 + tls.in1 <~ http.out1 + + tls.out1 ~> timeouts.in + + BidiShape(http.in1, timeouts.out, tls.in2, http.out2) + })) + } + /** * Creates a [[akka.stream.scaladsl.Source]] of [[akka.http.scaladsl.Http.IncomingConnection]] instances which represents a prospective HTTP server binding * on the given `endpoint`. @@ -81,14 +104,14 @@ class HttpExt(private val config: Config)(implicit val system: ActorSystem) exte settings: ServerSettings = ServerSettings(system), log: LoggingAdapter = system.log)(implicit fm: Materializer): Source[IncomingConnection, Future[ServerBinding]] = { val effectivePort = if (port >= 0) port else connectionContext.defaultPort - val tlsStage = sslTlsStage(connectionContext, Server) + + val fullLayer = fuseServerLayer(settings, connectionContext, log) + val connections: Source[Tcp.IncomingConnection, Future[Tcp.ServerBinding]] = Tcp().bind(interface, effectivePort, settings.backlog, settings.socketOptions, halfClose = false, settings.timeouts.idleTimeout) connections.map { case Tcp.IncomingConnection(localAddress, remoteAddress, flow) ⇒ - val layer = serverLayer(settings, Some(remoteAddress), log) - val flowWithTimeoutRecovered = flow.via(MapError { case t: TimeoutException ⇒ new HttpConnectionTimeoutException(t.getMessage) }) - IncomingConnection(localAddress, remoteAddress, layer atop tlsStage join flowWithTimeoutRecovered) + IncomingConnection(localAddress, remoteAddress, fullLayer join flow) }.mapMaterializedValue { _.map(tcpBinding ⇒ ServerBinding(tcpBinding.localAddress)(() ⇒ tcpBinding.unbind()))(fm.executionContext) } @@ -110,30 +133,39 @@ class HttpExt(private val config: Config)(implicit val system: ActorSystem) exte connectionContext: ConnectionContext = defaultServerHttpContext, settings: ServerSettings = ServerSettings(system), log: LoggingAdapter = system.log)(implicit fm: Materializer): Future[ServerBinding] = { - def handleOneConnection(incomingConnection: IncomingConnection): Future[Done] = - try - incomingConnection.flow - .watchTermination()(Keep.right) - .joinMat(handler)(Keep.left) - .run() - catch { - case NonFatal(e) ⇒ - log.error(e, "Could not materialize handling flow for {}", incomingConnection) - throw e - } + val effectivePort = if (port >= 0) port else connectionContext.defaultPort + + val fullLayer: Flow[ByteString, ByteString, Future[Done]] = Flow.fromGraph(Fusing.aggressive( + Flow[HttpRequest] + .watchTermination()(Keep.right) + .viaMat(handler)(Keep.left) + .joinMat(fuseServerLayer(settings, connectionContext, log))(Keep.left))) + + val connections: Source[Tcp.IncomingConnection, Future[Tcp.ServerBinding]] = + Tcp().bind(interface, effectivePort, settings.backlog, settings.socketOptions, halfClose = false, settings.timeouts.idleTimeout) + + connections.mapAsyncUnordered(settings.maxConnections) { + case incoming: Tcp.IncomingConnection ⇒ + try { + fullLayer.addAttributes(HttpAttributes.remoteAddress(Some(incoming.remoteAddress))) + .joinMat(incoming.flow)(Keep.left) + .run().recover { + // Ignore incoming errors from the connection as they will cancel the binding. + // As far as it is known currently, these errors can only happen if a TCP error bubbles up + // from the TCP layer through the HTTP layer to the Http.IncomingConnection.flow. + // See https://github.com/akka/akka/issues/17992 + case NonFatal(ex) ⇒ + Done + }(ExecutionContexts.sameThreadExecutionContext) + } catch { + case NonFatal(e) ⇒ + log.error(e, "Could not materialize handling flow for {}", incoming) + throw e + } + }.mapMaterializedValue { + _.map(tcpBinding ⇒ ServerBinding(tcpBinding.localAddress)(() ⇒ tcpBinding.unbind()))(fm.executionContext) + }.to(Sink.ignore).run() - bind(interface, port, connectionContext, settings, log) - .mapAsyncUnordered(settings.maxConnections) { connection ⇒ - handleOneConnection(connection).recoverWith { - // Ignore incoming errors from the connection as they will cancel the binding. - // As far as it is known currently, these errors can only happen if a TCP error bubbles up - // from the TCP layer through the HTTP layer to the Http.IncomingConnection.flow. - // See https://github.com/akka/akka/issues/17992 - case NonFatal(_) ⇒ Future.successful(()) - }(fm.executionContext) - } - .to(Sink.ignore) - .run() } /** diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala index 17dc6ae6bc..c53f6cbdbf 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala @@ -300,7 +300,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "support `rawRequestUriHeader` setting" in new Test { override protected def newParser: HttpRequestParser = - new HttpRequestParser(parserSettings, rawRequestUriHeader = true, _headerParser = HttpHeaderParser(parserSettings)()) + new HttpRequestParser(parserSettings, rawRequestUriHeader = true, headerParser = HttpHeaderParser(parserSettings)()) """GET /f%6f%6fbar?q=b%61z HTTP/1.1 |Host: ping @@ -557,7 +557,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] = Source(input.toList) .map(bytes ⇒ SessionBytes(TLSPlacebo.dummySession, ByteString(bytes))) - .via(parser.stage).named("parser") + .via(parser).named("parser") .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .prefixAndTail(1) .collect { diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/TestClient.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/TestClient.scala index 771280ae72..7f593ef718 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/TestClient.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/TestClient.scala @@ -4,13 +4,22 @@ package akka.http.scaladsl +import java.io.File +import java.nio.file.spi.FileSystemProvider +import java.nio.file.{ FileSystem, Path } + import com.typesafe.config.{ Config, ConfigFactory } + import scala.util.{ Failure, Success } -import akka.actor.{ UnhandledMessage, ActorSystem } -import akka.stream.ActorMaterializer -import akka.stream.scaladsl.{ Sink, Source } +import akka.actor.{ ActorSystem, UnhandledMessage } +import akka.stream.{ ActorMaterializer, IOResult } +import akka.stream.scaladsl.{ FileIO, Sink, Source } import akka.http.scaladsl.model._ import akka.http.impl.util._ +import akka.util.ByteString + +import scala.concurrent.{ Await, Future } +import scala.concurrent.duration._ object TestClient extends App { val testConf: Config = ConfigFactory.parseString(""" @@ -62,5 +71,47 @@ object TestClient extends App { } } + // for gathering dumps of entity and headers from akka http client + // and curl in parallel to compare + def fetchAndStoreABunchOfUrlsWithHttpAndCurl(urls: Seq[String]): Unit = { + assert(urls.nonEmpty) + assert(new File("/tmp/client-dumps/").exists(), "you need to create /tmp/client-dumps/ before running") + + val testConf: Config = ConfigFactory.parseString(""" + akka.loglevel = DEBUG + akka.log-dead-letters = off + akka.io.tcp.trace-logging = off""") + implicit val system = ActorSystem("ServerTest", testConf) + implicit val fm = ActorMaterializer() + import system.dispatcher + + try { + val done = Future.traverse(urls.zipWithIndex) { + case (url, index) ⇒ + Http().singleRequest(HttpRequest(uri = url)).map { response ⇒ + + val path = new File(s"/tmp/client-dumps/akka-body-$index.dump").toPath + val headersPath = new File(s"/tmp/client-dumps/akka-headers-$index.dump").toPath + + import scala.sys.process._ + (s"""curl -D /tmp/client-dumps/curl-headers-$index.dump $url""" #> new File(s"/tmp/client-dumps/curl-body-$index.dump")).! + + val headers = Source(response.headers).map(header ⇒ ByteString(header.name + ": " + header.value + "\n")) + .runWith(FileIO.toPath(headersPath)) + + val body = response.entity.dataBytes + .runWith(FileIO.toPath(path)) + .map(res ⇒ (url, path, res)): Future[(String, Path, IOResult)] + + headers.flatMap(_ ⇒ body) + } + } + + println("Fetched urls: " + Await.result(done, 10.minutes)) + } finally { + Http().shutdownAllConnectionPools().flatMap(_ ⇒ system.terminate()) + } + } + def shutdown(): Unit = system.terminate() } \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/TestServer.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/TestServer.scala index 0176913e0d..071b46cd5b 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/TestServer.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/TestServer.scala @@ -13,7 +13,7 @@ import scala.concurrent.Await import akka.actor.ActorSystem import akka.http.scaladsl.model._ import akka.http.scaladsl.model.ws._ -import akka.stream.ActorMaterializer +import akka.stream._ import akka.stream.scaladsl.{ Source, Flow } import com.typesafe.config.{ ConfigFactory, Config } import HttpMethods._ @@ -23,10 +23,17 @@ object TestServer extends App { akka.loglevel = INFO akka.log-dead-letters = off akka.stream.materializer.debug.fuzzing-mode = off + akka.actor.serialize-creators = off + akka.actor.serialize-messages = off + akka.actor.default-dispatcher.throughput = 1000 """) implicit val system = ActorSystem("ServerTest", testConf) - implicit val fm = ActorMaterializer() + val settings = ActorMaterializerSettings(system) + .withFuzzing(false) + // .withSyncProcessingLimit(Int.MaxValue) + .withInputBuffer(128, 128) + implicit val fm = ActorMaterializer(settings) try { val binding = Http().bindAndHandleSync({ case req @ HttpRequest(GET, Uri.Path("/"), _, _, _) if req.header[UpgradeToWebSocket].isDefined ⇒