diff --git a/akka-http-core/src/main/scala/akka/http/Http.scala b/akka-http-core/src/main/scala/akka/http/Http.scala index 2100c3e9ac..66d1be6f27 100644 --- a/akka-http-core/src/main/scala/akka/http/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/Http.scala @@ -11,9 +11,9 @@ import org.reactivestreams.{ Publisher, Subscriber } import scala.collection.immutable import akka.io.Inet import akka.stream.MaterializerSettings -import akka.http.engine.client.{ HttpClientProcessor, ClientConnectionSettings } +import akka.http.engine.client.ClientConnectionSettings import akka.http.engine.server.ServerSettings -import akka.http.model.{ HttpResponse, HttpRequest, japi } +import akka.http.model.{ ErrorInfo, HttpResponse, HttpRequest, japi } import akka.http.util._ import akka.actor._ @@ -141,6 +141,8 @@ object Http extends ExtensionKey[HttpExt] { class ConnectionAttemptFailedException(val endpoint: InetSocketAddress) extends ConnectionException(s"Connection attempt to $endpoint failed") class RequestTimeoutException(val request: HttpRequest, message: String) extends ConnectionException(message) + + class StreamException(val info: ErrorInfo) extends RuntimeException(info.summary) } class HttpExt(system: ExtendedActorSystem) extends akka.io.IO.Extension { diff --git a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala index d4edaabab6..d36f4c07ee 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala @@ -5,18 +5,17 @@ package akka.http.engine.client import java.net.InetSocketAddress -import akka.util.ByteString - import scala.collection.immutable.Queue import akka.stream.scaladsl._ import akka.event.LoggingAdapter import akka.stream.FlowMaterializer import akka.stream.FlattenStrategy import akka.stream.io.StreamTcp +import akka.util.ByteString import akka.http.Http import akka.http.model.{ HttpMethod, HttpRequest, ErrorInfo, HttpResponse } import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory } -import akka.http.engine.parsing.HttpResponseParser +import akka.http.engine.parsing.{ HttpRequestParser, HttpHeaderParser, HttpResponseParser } import akka.http.engine.parsing.ParserOutput._ import akka.http.util._ @@ -29,10 +28,13 @@ private[http] class HttpClientPipeline(effectiveSettings: ClientConnectionSettin import effectiveSettings._ - val rootParser = new HttpResponseParser(parserSettings)() - val warnOnIllegalHeader: ErrorInfo ⇒ Unit = errorInfo ⇒ - if (parserSettings.illegalHeaderWarnings) - log.warning(errorInfo.withSummaryPrepended("Illegal response header").formatPretty) + // the initial header parser we initially use for every connection, + // will not be mutated, all "shared copy" parsers copy on first-write into the header cache + val rootParser = new HttpResponseParser( + parserSettings, + HttpHeaderParser(parserSettings) { errorInfo ⇒ + if (parserSettings.illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal response header").formatPretty) + }) val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log) @@ -60,7 +62,10 @@ private[http] class HttpClientPipeline(effectiveSettings: ClientConnectionSettin val responsePipeline = Flow[ByteString] - .transform("rootParser", () ⇒ rootParser.copyWith(warnOnIllegalHeader, requestMethodByPass)) + .transform("rootParser", () ⇒ + // each connection uses a single (private) response parser instance for all its responses + // which builds a cache of all header instances seen on that connection + rootParser.createShallowCopy(requestMethodByPass)) .splitWhen(_.isInstanceOf[MessageStart]) .headAndTail .collect { diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala index d7d33197af..03c2259b10 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala @@ -50,7 +50,11 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, // see: http://www.cgjennings.ca/fjs/ and http://ijes.info/4/1/42544103.pdf private[this] val boyerMoore = new BoyerMoore(needle) - private[this] val headerParser = HttpHeaderParser(settings, warnOnIllegalHeader) // TODO: prevent re-priming header parser from scratch + // TODO: prevent re-priming header parser from scratch + private[this] val headerParser = HttpHeaderParser(settings) { errorInfo ⇒ + if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty) + } + private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs private[this] var resultIterator: Iterator[Output] = Iterator.empty private[this] var state: ByteString ⇒ StateResult = tryParseInitialBoundary diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpHeaderParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpHeaderParser.scala index 5ab431b75b..f3dd992535 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpHeaderParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpHeaderParser.scala @@ -57,7 +57,7 @@ import akka.http.model.parser.CharacterClasses._ * Since we address them via the nodes MSB and zero is reserved the trie * cannot hold more then 255 items, so this array has a fixed size of 255. */ -private[parsing] final class HttpHeaderParser private ( +private[engine] final class HttpHeaderParser private ( val settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo ⇒ Unit, private[this] var nodes: Array[Char] = new Array(512), // initial size, can grow as needed @@ -83,7 +83,7 @@ private[parsing] final class HttpHeaderParser private ( /** * Returns a copy of this parser that shares the trie data with this instance. */ - def copyWith(warnOnIllegalHeader: ErrorInfo ⇒ Unit) = + def createShallowCopy(): HttpHeaderParser = new HttpHeaderParser(settings, warnOnIllegalHeader, nodes, nodeCount, branchData, branchDataCount, values, valueCount) /** @@ -402,12 +402,10 @@ private[http] object HttpHeaderParser { "Cache-Control: no-cache", "Expect: 100-continue") - private val defaultIllegalHeaderWarning: ErrorInfo ⇒ Unit = info ⇒ throw new IllegalHeaderException(info) - - def apply(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo ⇒ Unit = defaultIllegalHeaderWarning) = + def apply(settings: HttpHeaderParser.Settings)(warnOnIllegalHeader: ErrorInfo ⇒ Unit = info ⇒ throw new IllegalHeaderException(info)) = prime(unprimed(settings, warnOnIllegalHeader)) - def unprimed(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo ⇒ Unit = defaultIllegalHeaderWarning) = + def unprimed(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo ⇒ Unit) = new HttpHeaderParser(settings, warnOnIllegalHeader) def prime(parser: HttpHeaderParser): HttpHeaderParser = { diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala index 1d4feb82e1..7720e64847 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala @@ -10,48 +10,66 @@ import akka.parboiled2.CharUtils import akka.util.ByteString import akka.stream.scaladsl.Source import akka.stream.stage._ +import akka.http.Http.StreamException import akka.http.model.parser.CharacterClasses +import akka.http.util._ import akka.http.model._ import headers._ import HttpProtocols._ +import ParserOutput._ /** * INTERNAL API */ -private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOutput <: ParserOutput](val settings: ParserSettings, - val headerParser: HttpHeaderParser) - extends StatefulStage[ByteString, Output] { +private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings, + val headerParser: HttpHeaderParser) + extends PushPullStage[ByteString, Output] { + import HttpMessageParser._ import settings._ sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult - private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs + private[this] val result = new ListBuffer[Output] private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0) private[this] var protocol: HttpProtocol = `HTTP/1.1` + private[this] var completionHandling: CompletionHandling = CompletionOk private[this] var terminated = false - override def initial = new State { - override def onPush(input: ByteString, ctx: Context[Output]): Directive = { - result.clear() + override def onPush(input: ByteString, ctx: Context[Output]): Directive = { + @tailrec def run(next: ByteString ⇒ StateResult): StateResult = + (try next(input) + catch { + case e: ParsingException ⇒ failMessageStart(e.status, e.info) + case NotEnoughDataException ⇒ + // we are missing a try/catch{continue} wrapper somewhere + throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) + }) match { + case Trampoline(x) ⇒ run(x) + case x ⇒ x + } - @tailrec def run(next: ByteString ⇒ StateResult): StateResult = - (try next(input) - catch { - case e: ParsingException ⇒ fail(e.status, e.info) - case NotEnoughDataException ⇒ - // we are missing a try/catch{continue} wrapper somewhere - throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) - }) match { - case Trampoline(next) ⇒ run(next) - case x ⇒ x - } + if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`") + run(state) + pushResultHeadAndFinishOrPull(ctx) + } - run(state) - val resultIterator = result.iterator - if (terminated) emitAndFinish(resultIterator, ctx) - else emit(resultIterator, ctx) + def onPull(ctx: Context[Output]): Directive = pushResultHeadAndFinishOrPull(ctx) + + def pushResultHeadAndFinishOrPull(ctx: Context[Output]): Directive = + if (result.nonEmpty) { + val head = result.head + result.remove(0) // faster than `ListBuffer::drop` + ctx.push(head) + } else if (terminated) ctx.finish() else ctx.pull() + + override def onUpstreamFinish(ctx: Context[Output]) = { + completionHandling() match { + case Some(x) ⇒ emit(x.asInstanceOf[Output]) + case None ⇒ // nothing to do } + terminated = true + if (result.isEmpty) ctx.finish() else ctx.absorbTermination() } def startNewMessage(input: ByteString, offset: Int): StateResult = { @@ -59,6 +77,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut try parseMessage(input, offset) catch { case NotEnoughDataException ⇒ continue(input, offset)(_startNewMessage) } + if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError) _startNewMessage(input, offset) } @@ -81,65 +100,75 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut @tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: List[HttpHeader] = Nil, headerCount: Int = 0, ch: Option[Connection] = None, clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None, - teh: Option[`Transfer-Encoding`] = None, hh: Boolean = false): StateResult = { - var lineEnd = 0 - val resultHeader = - try { - lineEnd = headerParser.parseHeaderLine(input, lineStart)() - headerParser.resultHeader - } catch { - case NotEnoughDataException ⇒ null + teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false, + hh: Boolean = false): StateResult = + if (headerCount < maxHeaderCount) { + var lineEnd = 0 + val resultHeader = + try { + lineEnd = headerParser.parseHeaderLine(input, lineStart)() + headerParser.resultHeader + } catch { + case NotEnoughDataException ⇒ null + } + resultHeader match { + case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, ch, clh, cth, teh, e100c, hh)) + + case HttpHeaderParser.EmptyHeader ⇒ + val close = HttpMessage.connectionCloseExpected(protocol, ch) + setCompletionHandling(CompletionIsEntityStreamError) + parseEntity(headers, protocol, input, lineEnd, clh, cth, teh, e100c, hh, close) + + case h: `Content-Length` ⇒ clh match { + case None ⇒ parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, Some(h), cth, teh, e100c, hh) + case Some(`h`) ⇒ parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, e100c, hh) + case _ ⇒ failMessageStart("HTTP message must not contain more than one Content-Length header") + } + case h: `Content-Type` ⇒ cth match { + case None ⇒ parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, Some(h), teh, e100c, hh) + case Some(`h`) ⇒ parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, e100c, hh) + case _ ⇒ failMessageStart("HTTP message must not contain more than one Content-Type header") + } + case h: `Transfer-Encoding` ⇒ teh match { + case None ⇒ parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, cth, Some(h), e100c, hh) + case Some(x) ⇒ parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, Some(x append h.encodings), e100c, hh) + } + case h: Connection ⇒ ch match { + case None ⇒ parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, Some(h), clh, cth, teh, e100c, hh) + case Some(x) ⇒ parseHeaderLines(input, lineEnd, headers, headerCount, Some(x append h.tokens), clh, cth, teh, e100c, hh) + } + case h: Host ⇒ + if (!hh) parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c, hh = true) + else failMessageStart("HTTP message must not contain more than one Host header") + + case h: Expect ⇒ parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c = true, hh) + + case h ⇒ parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c, hh) } - resultHeader match { - case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, ch, clh, cth, teh, hh)) - - case HttpHeaderParser.EmptyHeader ⇒ - val close = HttpMessage.connectionCloseExpected(protocol, ch) - parseEntity(headers, protocol, input, lineEnd, clh, cth, teh, hh, close) - - case h: Connection ⇒ - parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, Some(h), clh, cth, teh, hh) - - case h: `Content-Length` ⇒ - if (clh.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, Some(h), cth, teh, hh) - else fail("HTTP message must not contain more than one Content-Length header") - - case h: `Content-Type` ⇒ - if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, Some(h), teh, hh) - else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, hh) - else fail("HTTP message must not contain more than one Content-Type header") - - case h: `Transfer-Encoding` ⇒ - parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, cth, Some(h), hh) - - case h if headerCount < maxHeaderCount ⇒ - parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, hh || h.isInstanceOf[Host]) - - case _ ⇒ fail(s"HTTP message contains more than the configured limit of $maxHeaderCount headers") - } - } + } else failMessageStart(s"HTTP message contains more than the configured limit of $maxHeaderCount headers") // work-around for compiler complaining about non-tail-recursion if we inline this method def parseHeaderLinesAux(headers: List[HttpHeader], headerCount: Int, ch: Option[Connection], clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - hh: Boolean)(input: ByteString, lineStart: Int): StateResult = - parseHeaderLines(input, lineStart, headers, headerCount, ch, clh, cth, teh, hh) + e100c: Boolean, hh: Boolean)(input: ByteString, lineStart: Int): StateResult = + parseHeaderLines(input, lineStart, headers, headerCount, ch, clh, cth, teh, e100c, hh) def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult def parseFixedLengthBody(remainingBodyBytes: Long, isLastMessage: Boolean)(input: ByteString, bodyStart: Int): StateResult = { val remainingInputBytes = input.length - bodyStart if (remainingInputBytes > 0) { if (remainingInputBytes < remainingBodyBytes) { - emit(ParserOutput.EntityPart(input drop bodyStart)) + emit(EntityPart(input drop bodyStart)) continue(parseFixedLengthBody(remainingBodyBytes - remainingInputBytes, isLastMessage)) } else { val offset = bodyStart + remainingBodyBytes.toInt - emit(ParserOutput.EntityPart(input.slice(bodyStart, offset))) - emit(ParserOutput.MessageEnd) + emit(EntityPart(input.slice(bodyStart, offset))) + emit(MessageEnd) + setCompletionHandling(CompletionOk) if (isLastMessage) terminate() else startNewMessage(input, offset) } @@ -149,32 +178,38 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut def parseChunk(input: ByteString, offset: Int, isLastMessage: Boolean): StateResult = { @tailrec def parseTrailer(extension: String, lineStart: Int, headers: List[HttpHeader] = Nil, headerCount: Int = 0): StateResult = { - val lineEnd = headerParser.parseHeaderLine(input, lineStart)() - headerParser.resultHeader match { - case HttpHeaderParser.EmptyHeader ⇒ - val lastChunk = - if (extension.isEmpty && headers.isEmpty) HttpEntity.LastChunk else HttpEntity.LastChunk(extension, headers) - emit(ParserOutput.EntityChunk(lastChunk)) - emit(ParserOutput.MessageEnd) - if (isLastMessage) terminate() - else startNewMessage(input, lineEnd) - case header if headerCount < maxHeaderCount ⇒ - parseTrailer(extension, lineEnd, header :: headers, headerCount + 1) - case _ ⇒ fail(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers") - } + var errorInfo: ErrorInfo = null + val lineEnd = + try headerParser.parseHeaderLine(input, lineStart)() + catch { case e: ParsingException ⇒ errorInfo = e.info; 0 } + if (errorInfo eq null) { + headerParser.resultHeader match { + case HttpHeaderParser.EmptyHeader ⇒ + val lastChunk = + if (extension.isEmpty && headers.isEmpty) HttpEntity.LastChunk else HttpEntity.LastChunk(extension, headers) + emit(EntityChunk(lastChunk)) + emit(MessageEnd) + setCompletionHandling(CompletionOk) + if (isLastMessage) terminate() + else startNewMessage(input, lineEnd) + case header if headerCount < maxHeaderCount ⇒ + parseTrailer(extension, lineEnd, header :: headers, headerCount + 1) + case _ ⇒ failEntityStream(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers") + } + } else failEntityStream(errorInfo) } def parseChunkBody(chunkSize: Int, extension: String, cursor: Int): StateResult = if (chunkSize > 0) { val chunkBodyEnd = cursor + chunkSize def result(terminatorLen: Int) = { - emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) + emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) } byteChar(input, chunkBodyEnd) match { case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2) case '\n' ⇒ result(1) - case x ⇒ fail("Illegal chunk termination") + case x ⇒ failEntityStream("Illegal chunk termination") } } else parseTrailer(extension, cursor) @@ -186,7 +221,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut case '\n' ⇒ parseChunkBody(chunkSize, extension, cursor + 1) case _ ⇒ parseChunkExtensions(chunkSize, cursor + 1)(startIx) } - } else fail(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters") + } else failEntityStream(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters") @tailrec def parseSize(cursor: Int, size: Long): StateResult = if (size <= maxChunkSize) { @@ -194,9 +229,9 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut case c if CharacterClasses.HEXDIG(c) ⇒ parseSize(cursor + 1, size * 16 + CharUtils.hexValue(c)) case ';' if cursor > offset ⇒ parseChunkExtensions(size.toInt, cursor + 1)() case '\r' if cursor > offset && byteChar(input, cursor + 1) == '\n' ⇒ parseChunkBody(size.toInt, "", cursor + 2) - case c ⇒ fail(s"Illegal character '${escape(c)}' in chunk start") + case c ⇒ failEntityStream(s"Illegal character '${escape(c)}' in chunk start") } - } else fail(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes") + } else failEntityStream(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes") try parseSize(offset, 0) catch { @@ -222,12 +257,21 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut } def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next) - def fail(summary: String): StateResult = fail(summary, "") - def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail) - def fail(status: StatusCode): StateResult = fail(status, status.defaultMessage) - def fail(status: StatusCode, summary: String, detail: String = ""): StateResult = fail(status, ErrorInfo(summary, detail)) - def fail(status: StatusCode, info: ErrorInfo): StateResult = { - emit(ParserOutput.ParseError(status, info)) + def failMessageStart(summary: String): StateResult = failMessageStart(summary, "") + def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail) + def failMessageStart(status: StatusCode): StateResult = failMessageStart(status, status.defaultMessage) + def failMessageStart(status: StatusCode, summary: String, detail: String = ""): StateResult = failMessageStart(status, ErrorInfo(summary, detail)) + def failMessageStart(status: StatusCode, info: ErrorInfo): StateResult = { + emit(MessageStartError(status, info)) + setCompletionHandling(CompletionOk) + terminate() + } + + def failEntityStream(summary: String): StateResult = failEntityStream(summary, "") + def failEntityStream(summary: String, detail: String): StateResult = failEntityStream(ErrorInfo(summary, detail)) + def failEntityStream(info: ErrorInfo): StateResult = { + emit(EntityStreamError(info)) + setCompletionHandling(CompletionOk) terminate() } @@ -250,14 +294,23 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut contentLength: Int)(entityParts: Any): UniversalEntity = HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength)) - def defaultEntity(cth: Option[`Content-Type`], contentLength: Long)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = { - val data = entityParts.collect { case ParserOutput.EntityPart(bytes) ⇒ bytes } - HttpEntity.Default(contentType(cth), contentLength, data) + def defaultEntity(cth: Option[`Content-Type`], + contentLength: Long, + transformData: Source[ByteString] ⇒ Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = { + val data = entityParts.collect { + case EntityPart(bytes) ⇒ bytes + case EntityStreamError(info) ⇒ throw new StreamException(info) + } + HttpEntity.Default(contentType(cth), contentLength, transformData(data)) } - def chunkedEntity(cth: Option[`Content-Type`])(entityChunks: Source[_ <: ParserOutput]): RequestEntity with ResponseEntity = { - val chunks = entityChunks.collect { case ParserOutput.EntityChunk(chunk) ⇒ chunk } - HttpEntity.Chunked(contentType(cth), chunks) + def chunkedEntity(cth: Option[`Content-Type`], + transformChunks: Source[HttpEntity.ChunkStreamPart] ⇒ Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = { + val chunks = entityChunks.collect { + case EntityChunk(chunk) ⇒ chunk + case EntityStreamError(info) ⇒ throw new StreamException(info) + } + HttpEntity.Chunked(contentType(cth), transformChunks(chunks)) } def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = @@ -265,4 +318,16 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut case Some(x) ⇒ x :: headers case None ⇒ headers } + + def setCompletionHandling(completionHandling: CompletionHandling): Unit = + this.completionHandling = completionHandling } + +private[http] object HttpMessageParser { + type CompletionHandling = () ⇒ Option[ParserOutput] + val CompletionOk: CompletionHandling = () ⇒ None + val CompletionIsMessageStartError: CompletionHandling = + () ⇒ Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start"))) + val CompletionIsEntityStreamError: CompletionHandling = + () ⇒ Some(ParserOutput.EntityStreamError(ErrorInfo("Entity stream truncation"))) +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala index 9c25a3e950..5ffc57d77d 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala @@ -6,27 +6,34 @@ package akka.http.engine.parsing import java.lang.{ StringBuilder ⇒ JStringBuilder } import scala.annotation.tailrec -import akka.http.model.parser.CharacterClasses +import akka.actor.ActorRef +import akka.stream.stage.{ Context, PushPullStage } import akka.stream.scaladsl.Source import akka.util.ByteString +import akka.http.engine.server.OneHundredContinue +import akka.http.model.parser.CharacterClasses +import akka.http.util.identityFunc import akka.http.model._ import headers._ import StatusCodes._ +import ParserOutput._ /** * INTERNAL API */ private[http] class HttpRequestParser(_settings: ParserSettings, - rawRequestUriHeader: Boolean)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings)) - extends HttpMessageParser[ParserOutput.RequestOutput](_settings, _headerParser) { + rawRequestUriHeader: Boolean, + _headerParser: HttpHeaderParser, + oneHundredContinueRef: () ⇒ Option[ActorRef] = () ⇒ None) + extends HttpMessageParser[RequestOutput](_settings, _headerParser) { import settings._ private[this] var method: HttpMethod = _ private[this] var uri: Uri = _ private[this] var uriBytes: Array[Byte] = _ - def copyWith(warnOnIllegalHeader: ErrorInfo ⇒ Unit): HttpRequestParser = - new HttpRequestParser(settings, rawRequestUriHeader)(headerParser.copyWith(warnOnIllegalHeader)) + def createShallowCopy(oneHundredContinueRef: () ⇒ Option[ActorRef]): HttpRequestParser = + new HttpRequestParser(settings, rawRequestUriHeader, headerParser.createShallowCopy(), oneHundredContinueRef) def parseMessage(input: ByteString, offset: Int): StateResult = { var cursor = parseMethod(input, offset) @@ -107,16 +114,32 @@ private[http] class HttpRequestParser(_settings: ParserSettings, // http://tools.ietf.org/html/rfc7230#section-3.3 def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { - def emitRequestStart(createEntity: Source[ParserOutput.RequestOutput] ⇒ RequestEntity, + def emitRequestStart(createEntity: Source[RequestOutput] ⇒ RequestEntity, headers: List[HttpHeader] = headers) = { val allHeaders = if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers else headers - emit(ParserOutput.RequestStart(method, uri, protocol, allHeaders, createEntity, closeAfterResponseCompletion)) + emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion)) } + def expect100continueHandling[T]: Source[T] ⇒ Source[T] = + if (expect100continue) { + _.transform("expect100continueTrigger", () ⇒ new PushPullStage[T, T] { + private var oneHundredContinueSent = false + def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + def onPull(ctx: Context[T]) = { + if (!oneHundredContinueSent) { + val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable")) + ref ! OneHundredContinue + oneHundredContinueSent = true + } + ctx.pull() + } + }) + } else identityFunc + teh match { case None ⇒ val contentLength = clh match { @@ -124,17 +147,19 @@ private[http] class HttpRequestParser(_settings: ParserSettings, case None ⇒ 0 } if (contentLength > maxContentLength) - fail(RequestEntityTooLarge, + failMessageStart(RequestEntityTooLarge, s"Request Content-Length $contentLength exceeds the configured limit of $maxContentLength") else if (contentLength == 0) { emitRequestStart(emptyEntity(cth)) + setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart) } else if (contentLength <= input.size - bodyStart) { val cl = contentLength.toInt emitRequestStart(strictEntity(cth, input, bodyStart, cl)) + setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart + cl) } else { - emitRequestStart(defaultEntity(cth, contentLength)) + emitRequestStart(defaultEntity(cth, contentLength, expect100continueHandling)) parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) } @@ -142,11 +167,11 @@ private[http] class HttpRequestParser(_settings: ParserSettings, val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te) if (te.isChunked) { if (clh.isEmpty) { - emitRequestStart(chunkedEntity(cth), completedHeaders) + emitRequestStart(chunkedEntity(cth, expect100continueHandling), completedHeaders) parseChunk(input, bodyStart, closeAfterResponseCompletion) - } else fail("A chunked request must not contain a Content-Length header.") - } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, hostHeaderPresent, - closeAfterResponseCompletion) + } else failMessageStart("A chunked request must not contain a Content-Length header.") + } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, + expect100continue, hostHeaderPresent, closeAfterResponseCompletion) } - } else fail("Request is missing required `Host` header") + } else failMessageStart("Request is missing required `Host` header") } \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala index 09ba3a58ee..86629da8b2 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala @@ -11,20 +11,22 @@ import akka.util.ByteString import akka.http.model._ import headers._ import HttpResponseParser.NoMethod +import ParserOutput._ /** * INTERNAL API */ private[http] class HttpResponseParser(_settings: ParserSettings, - dequeueRequestMethodForNextResponse: () ⇒ HttpMethod = () ⇒ NoMethod)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings)) - extends HttpMessageParser[ParserOutput.ResponseOutput](_settings, _headerParser) { + _headerParser: HttpHeaderParser, + dequeueRequestMethodForNextResponse: () ⇒ HttpMethod = () ⇒ NoMethod) + extends HttpMessageParser[ResponseOutput](_settings, _headerParser) { import settings._ private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod private[this] var statusCode: StatusCode = StatusCodes.OK - def copyWith(warnOnIllegalHeader: ErrorInfo ⇒ Unit, dequeueRequestMethodForNextResponse: () ⇒ HttpMethod): HttpResponseParser = - new HttpResponseParser(settings, dequeueRequestMethodForNextResponse)(headerParser.copyWith(warnOnIllegalHeader)) + def createShallowCopy(dequeueRequestMethodForNextResponse: () ⇒ HttpMethod): HttpResponseParser = + new HttpResponseParser(settings, headerParser.createShallowCopy(), dequeueRequestMethodForNextResponse) override def startNewMessage(input: ByteString, offset: Int): StateResult = { requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse() @@ -39,7 +41,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, cursor = parseReason(input, cursor)() parseHeaderLines(input, cursor) } else badProtocol - } else fail("Unexpected server response", input.drop(offset).utf8String) + } else failMessageStart("Unexpected server response", input.drop(offset).utf8String) def badProtocol = throw new ParsingException("The server-side HTTP version is not supported") @@ -72,12 +74,13 @@ private[http] class HttpResponseParser(_settings: ParserSettings, // http://tools.ietf.org/html/rfc7230#section-3.3 def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], - hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { - def emitResponseStart(createEntity: Source[ParserOutput.ResponseOutput] ⇒ ResponseEntity, + expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { + def emitResponseStart(createEntity: Source[ResponseOutput] ⇒ ResponseEntity, headers: List[HttpHeader] = headers) = - emit(ParserOutput.ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion)) + emit(ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion)) def finishEmptyResponse() = { emitResponseStart(emptyEntity(cth)) + setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart) } @@ -86,11 +89,12 @@ private[http] class HttpResponseParser(_settings: ParserSettings, case None ⇒ clh match { case Some(`Content-Length`(contentLength)) ⇒ if (contentLength > maxContentLength) - fail(s"Response Content-Length $contentLength exceeds the configured limit of $maxContentLength") + failMessageStart(s"Response Content-Length $contentLength exceeds the configured limit of $maxContentLength") else if (contentLength == 0) finishEmptyResponse() else if (contentLength < input.size - bodyStart) { val cl = contentLength.toInt emitResponseStart(strictEntity(cth, input, bodyStart, cl)) + setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart + cl) } else { emitResponseStart(defaultEntity(cth, contentLength)) @@ -98,9 +102,10 @@ private[http] class HttpResponseParser(_settings: ParserSettings, } case None ⇒ emitResponseStart { entityParts ⇒ - val data = entityParts.collect { case ParserOutput.EntityPart(bytes) ⇒ bytes } + val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes } HttpEntity.CloseDelimited(contentType(cth), data) } + setCompletionHandling(HttpMessageParser.CompletionOk) parseToCloseBody(input, bodyStart) } @@ -110,9 +115,9 @@ private[http] class HttpResponseParser(_settings: ParserSettings, if (clh.isEmpty) { emitResponseStart(chunkedEntity(cth), completedHeaders) parseChunk(input, bodyStart, closeAfterResponseCompletion) - } else fail("A chunked response must not contain a Content-Length header.") - } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, hostHeaderPresent, - closeAfterResponseCompletion) + } else failMessageStart("A chunked response must not contain a Content-Length header.") + } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, + expect100continue, hostHeaderPresent, closeAfterResponseCompletion) } } else finishEmptyResponse() } @@ -120,7 +125,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, // currently we do not check for `settings.maxContentLength` overflow def parseToCloseBody(input: ByteString, bodyStart: Int): StateResult = { if (input.length > bodyStart) - emit(ParserOutput.EntityPart(input drop bodyStart)) + emit(EntityPart(input drop bodyStart)) continue(parseToCloseBody) } } diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala index aeec8b25c4..07daa875f9 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala @@ -28,6 +28,7 @@ private[http] object ParserOutput { protocol: HttpProtocol, headers: List[HttpHeader], createEntity: Source[RequestOutput] ⇒ RequestEntity, + expect100ContinueResponsePending: Boolean, closeAfterResponseCompletion: Boolean) extends MessageStart with RequestOutput final case class ResponseStart( @@ -43,5 +44,7 @@ private[http] object ParserOutput { final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput - final case class ParseError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput + final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput + + final case class EntityStreamError(info: ErrorInfo) extends MessageOutput } diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala index 1d9fe5875e..df3a23b9f3 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala @@ -98,23 +98,21 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.` r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } - def renderContentLength(contentLength: Long): Unit = { - if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf - r ~~ CrLf - } + def renderContentLength(contentLength: Long) = + if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r def completeRequestRendering(): Source[ByteString] = entity match { case x if x.isKnownEmpty ⇒ - renderContentLength(0) + renderContentLength(0) ~~ CrLf Source.singleton(r.get) case HttpEntity.Strict(_, data) ⇒ - renderContentLength(data.length) + renderContentLength(data.length) ~~ CrLf Source.singleton(r.get ++ data) case HttpEntity.Default(_, contentLength, data) ⇒ - renderContentLength(contentLength) + renderContentLength(contentLength) ~~ CrLf renderByteStrings(r, data.transform("checkContentLength", () ⇒ new CheckContentLengthTransformer(contentLength))) diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala index fa8b575b5c..bbeb1b07c5 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala @@ -136,6 +136,9 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } + def renderContentLengthHeader(contentLength: Long) = + if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r + def byteStrings(entityBytes: ⇒ Source[ByteString]): Source[ByteString] = renderByteStrings(r, entityBytes, skipEntity = noEntity) @@ -144,20 +147,19 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser case HttpEntity.Strict(_, data) ⇒ renderHeaders(headers.toList) renderEntityContentType(r, entity) - r ~~ `Content-Length` ~~ data.length ~~ CrLf ~~ CrLf + renderContentLengthHeader(data.length) ~~ CrLf val entityBytes = if (noEntity) ByteString.empty else data Source.singleton(r.get ++ entityBytes) case HttpEntity.Default(_, contentLength, data) ⇒ renderHeaders(headers.toList) renderEntityContentType(r, entity) - r ~~ `Content-Length` ~~ contentLength ~~ CrLf ~~ CrLf + renderContentLengthHeader(contentLength) ~~ CrLf byteStrings(data.transform("checkContentLength", () ⇒ new CheckContentLengthTransformer(contentLength))) case HttpEntity.CloseDelimited(_, data) ⇒ renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD) - renderEntityContentType(r, entity) - r ~~ CrLf + renderEntityContentType(r, entity) ~~ CrLf byteStrings(data) case HttpEntity.Chunked(contentType, chunks) ⇒ @@ -165,8 +167,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser completeResponseRendering(HttpEntity.CloseDelimited(contentType, chunks.map(_.data))) else { renderHeaders(headers.toList) - renderEntityContentType(r, entity) - r ~~ CrLf + renderEntityContentType(r, entity) ~~ CrLf byteStrings(chunks.transform("renderChunks", () ⇒ new ChunkTransformer)) } } diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala index 775cd3ebfa..ff9f268a0d 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala @@ -40,9 +40,9 @@ private object RenderSupport { } } - def renderEntityContentType(r: Rendering, entity: HttpEntity): Unit = - if (entity.contentType != ContentTypes.NoContentType) - r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf + def renderEntityContentType(r: Rendering, entity: HttpEntity) = + if (entity.contentType != ContentTypes.NoContentType) r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf + else r def renderByteStrings(r: ByteStringRendering, entityBytes: ⇒ Source[ByteString], skipEntity: Boolean = false): Source[ByteString] = { diff --git a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala index 0383635da9..b4c61292b0 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala @@ -4,30 +4,39 @@ package akka.http.engine.server +import akka.actor.{ Props, ActorRef } import akka.event.LoggingAdapter +import akka.stream.stage.PushPullStage +import akka.util.ByteString import akka.stream.io.StreamTcp import akka.stream.FlattenStrategy import akka.stream.FlowMaterializer import akka.stream.scaladsl._ -import akka.stream.stage._ -import akka.http.engine.parsing.HttpRequestParser +import akka.http.engine.parsing.{ HttpHeaderParser, HttpRequestParser } import akka.http.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory } -import akka.http.model.{ StatusCode, ErrorInfo, HttpRequest, HttpResponse, HttpMethods } +import akka.http.model._ import akka.http.engine.parsing.ParserOutput._ import akka.http.Http import akka.http.util._ -import akka.util.ByteString + +import scala.util.control.NonFatal /** * INTERNAL API */ -private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAdapter)(implicit fm: FlowMaterializer) +private[http] class HttpServerPipeline(settings: ServerSettings, + log: LoggingAdapter)(implicit fm: FlowMaterializer) extends (StreamTcp.IncomingTcpConnection ⇒ Http.IncomingConnection) { + import settings.parserSettings - val rootParser = new HttpRequestParser(settings.parserSettings, settings.rawRequestUriHeader)() - val warnOnIllegalHeader: ErrorInfo ⇒ Unit = errorInfo ⇒ - if (settings.parserSettings.illegalHeaderWarnings) - log.warning(errorInfo.withSummaryPrepended("Illegal request header").formatPretty) + // the initial header parser we initially use for every connection, + // will not be mutated, all "shared copy" parsers copy on first-write into the header cache + val rootParser = new HttpRequestParser( + parserSettings, + settings.rawRequestUriHeader, + HttpHeaderParser(parserSettings) { errorInfo ⇒ + if (parserSettings.illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal request header").formatPretty) + }) val responseRendererFactory = new HttpResponseRendererFactory(settings.serverHeader, settings.responseHeaderSizeHint, log) @@ -40,88 +49,125 @@ private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAda val userIn = Sink.publisher[HttpRequest] val userOut = Source.subscriber[HttpResponse] - val pipeline = FlowGraph { implicit b ⇒ - val bypassFanout = Broadcast[(RequestOutput, Source[RequestOutput])]("bypassFanout") - val bypassFanin = Merge[Any]("merge") + val oneHundredContinueSource = Source[OneHundredContinue.type](Props[OneHundredContinueSourceActor]) + @volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168 - val rootParsePipeline = - Flow[ByteString] - .transform("rootParser", () ⇒ rootParser.copyWith(warnOnIllegalHeader)) + val pipeline = FlowGraph { implicit b ⇒ + val bypassFanout = Broadcast[RequestOutput]("bypassFanout") + val bypassMerge = new BypassMerge + + val requestParsing = Flow[ByteString].transform("rootParser", () ⇒ + // each connection uses a single (private) request parser instance for all its requests + // which builds a cache of all header instances seen on that connection + rootParser.createShallowCopy(() ⇒ oneHundredContinueRef)) + + val requestPreparation = + Flow[RequestOutput] .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .headAndTail + .collect { + case (RequestStart(method, uri, protocol, headers, createEntity, _, _), entityParts) ⇒ + val effectiveUri = HttpRequest.effectiveUri(uri, headers, securedConnection = false, settings.defaultHostHeader) + val effectiveMethod = if (method == HttpMethods.HEAD && settings.transparentHeadRequests) HttpMethods.GET else method + HttpRequest(effectiveMethod, effectiveUri, headers, createEntity(entityParts), protocol) + } val rendererPipeline = - Flow[Any] - .transform("applyApplicationBypass", () ⇒ applyApplicationBypass) + Flow[ResponseRenderingContext] + .transform("recover", () ⇒ new ErrorsTo500ResponseRecovery(log)) // FIXME: simplify after #16394 is closed .transform("renderer", () ⇒ responseRendererFactory.newRenderer) .flatten(FlattenStrategy.concat) .transform("errorLogger", () ⇒ errorLogger(log, "Outgoing response stream error")) - val requestTweaking = Flow[(RequestOutput, Source[RequestOutput])].collect { - case (RequestStart(method, uri, protocol, headers, createEntity, _), entityParts) ⇒ - val effectiveUri = HttpRequest.effectiveUri(uri, headers, securedConnection = false, settings.defaultHostHeader) - val effectiveMethod = if (method == HttpMethods.HEAD && settings.transparentHeadRequests) HttpMethods.GET else method - HttpRequest(effectiveMethod, effectiveUri, headers, createEntity(entityParts), protocol) - } - - val bypass = - Flow[(RequestOutput, Source[RequestOutput])] - .collect[MessageStart with RequestOutput] { case (x: MessageStart, _) ⇒ x } - //FIXME: the graph is unnecessary after fixing #15957 - networkIn ~> rootParsePipeline ~> bypassFanout ~> requestTweaking ~> userIn - bypassFanout ~> bypass ~> bypassFanin - userOut ~> bypassFanin ~> rendererPipeline ~> networkOut + networkIn ~> requestParsing ~> bypassFanout ~> requestPreparation ~> userIn + bypassFanout ~> bypassMerge.bypassInput + userOut ~> bypassMerge.applicationInput ~> rendererPipeline ~> networkOut + oneHundredContinueSource ~> bypassMerge.oneHundredContinueInput }.run() + oneHundredContinueRef = Some(pipeline.get(oneHundredContinueSource)) Http.IncomingConnection(tcpConn.remoteAddress, pipeline.get(userIn), pipeline.get(userOut)) } - /** - * Combines the HttpResponse coming in from the application with the ParserOutput.RequestStart - * produced by the request parser into a ResponseRenderingContext. - * If the parser produced a ParserOutput.ParseError the error response is immediately dispatched to downstream. - */ - def applyApplicationBypass = - new PushStage[Any, ResponseRenderingContext] { - var applicationResponse: HttpResponse = _ - var requestStart: RequestStart = _ + class BypassMerge extends FlexiMerge[ResponseRenderingContext]("BypassMerge") { + import FlexiMerge._ + val bypassInput = createInputPort[RequestOutput]() + val oneHundredContinueInput = createInputPort[OneHundredContinue.type]() + val applicationInput = createInputPort[HttpResponse]() - override def onPush(elem: Any, ctx: Context[ResponseRenderingContext]): Directive = elem match { - case response: HttpResponse ⇒ - requestStart match { - case null ⇒ - applicationResponse = response - ctx.pull() - case x: RequestStart ⇒ - requestStart = null - ctx.push(dispatch(x, response)) - } - - case requestStart: RequestStart ⇒ - applicationResponse match { - case null ⇒ - this.requestStart = requestStart - ctx.pull() - case response ⇒ - applicationResponse = null - ctx.push(dispatch(requestStart, response)) - } - - case ParseError(status, info) ⇒ - ctx.push(errorResponse(status, info)) + def createMergeLogic() = new MergeLogic[ResponseRenderingContext] { + override def inputHandles(inputCount: Int) = { + require(inputCount == 3, s"BypassMerge must have 3 connected inputs, was $inputCount") + Vector(bypassInput, oneHundredContinueInput, applicationInput) } - def dispatch(requestStart: RequestStart, response: HttpResponse): ResponseRenderingContext = { - import requestStart._ - ResponseRenderingContext(response, method, protocol, closeAfterResponseCompletion) + override val initialState = State[Any](Read(bypassInput)) { + case (ctx, _, requestStart: RequestStart) ⇒ waitingForApplicationResponse(requestStart) + case (ctx, _, MessageStartError(status, info)) ⇒ finishWithError(ctx, "request", status, info) + case _ ⇒ SameState // drop other parser output } - def errorResponse(status: StatusCode, info: ErrorInfo): ResponseRenderingContext = { - log.warning("Illegal request, responding with status '{}': {}", status, info.formatPretty) + def waitingForApplicationResponse(requestStart: RequestStart): State[Any] = + State[Any](ReadAny(oneHundredContinueInput, applicationInput)) { + case (ctx, _, response: HttpResponse) ⇒ + // see the comment on [[OneHundredContinue]] for an explanation of the closing logic here (and more) + val close = requestStart.closeAfterResponseCompletion || requestStart.expect100ContinueResponsePending + ctx.emit(ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close)) + if (close) finish(ctx) else initialState + + case (ctx, _, OneHundredContinue) ⇒ + assert(requestStart.expect100ContinueResponsePending) + ctx.emit(ResponseRenderingContext(HttpResponse(StatusCodes.Continue))) + waitingForApplicationResponse(requestStart.copy(expect100ContinueResponsePending = false)) + } + + override def initialCompletionHandling = CompletionHandling( + onComplete = (ctx, _) ⇒ { ctx.complete(); SameState }, + onError = { + case (ctx, _, error: Http.StreamException) ⇒ + // the application has forwarded a request entity stream error to the response stream + finishWithError(ctx, "request", StatusCodes.BadRequest, error.info) + case (ctx, _, error) ⇒ + ctx.error(error) + SameState + }) + + def finishWithError(ctx: MergeLogicContext, target: String, status: StatusCode, info: ErrorInfo): State[Any] = { + log.warning("Illegal {}, responding with status '{}': {}", target, status, info.formatPretty) val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary - ResponseRenderingContext(HttpResponse(status, entity = msg), closeAfterResponseCompletion = true) + ctx.emit(ResponseRenderingContext(HttpResponse(status, entity = msg), closeAfterResponseCompletion = true)) + finish(ctx) + } + + def finish(ctx: MergeLogicContext): State[Any] = { + ctx.complete() // shouldn't this return a `State` rather than `Unit`? + SameState // it seems weird to stay in the same state after completion } } + } } + +private[server] class ErrorsTo500ResponseRecovery(log: LoggingAdapter) + extends PushPullStage[ResponseRenderingContext, ResponseRenderingContext] { + import akka.stream.stage.Context + + private[this] var errorResponse: ResponseRenderingContext = _ + + override def onPush(elem: ResponseRenderingContext, ctx: Context[ResponseRenderingContext]) = ctx.push(elem) + + override def onPull(ctx: Context[ResponseRenderingContext]) = + if (ctx.isFinishing) ctx.pushAndFinish(errorResponse) + else ctx.pull() + + override def onUpstreamFailure(error: Throwable, ctx: Context[ResponseRenderingContext]) = + error match { + case NonFatal(e) ⇒ + log.error(e, "Internal server error, sending 500 response") + errorResponse = ResponseRenderingContext(HttpResponse(StatusCodes.InternalServerError), + closeAfterResponseCompletion = true) + ctx.absorbTermination() + case _ ⇒ ctx.fail(error) + } +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/server/OneHundredContinue.scala b/akka-http-core/src/main/scala/akka/http/engine/server/OneHundredContinue.scala new file mode 100644 index 0000000000..37c58cad4f --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/server/OneHundredContinue.scala @@ -0,0 +1,66 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.engine.server + +import scala.annotation.tailrec +import akka.stream.actor.{ ActorPublisherMessage, ActorPublisher } + +/** + * The `Expect: 100-continue` header has a special status in HTTP. + * It allows the client to send an `Expect: 100-continue` header with the request and then pause request sending + * (i.e. hold back sending the request entity). The server reads the request headers, determines whether it wants to + * accept the request and responds with + * + * - `417 Expectation Failed`, if it doesn't support the `100-continue` expectation + * (or if the `Expect` header contains other, unsupported expectations). + * - a `100 Continue` response, + * if it is ready to accept the request entity and the client should go ahead with sending it + * - a final response (like a 4xx to signal some client-side error + * (e.g. if the request entity length is beyond the configured limit) or a 3xx redirect) + * + * Only if the client receives a `100 Continue` response from the server is it allowed to continue sending the request + * entity. In this case it will receive another response after having completed request sending. + * So this special feature breaks the normal "one request - one response" logic of HTTP! + * It therefore requires special handling in all HTTP stacks (client- and server-side). + * + * For us this means: + * + * - on the server-side: + * After having read a `Expect: 100-continue` header with the request we package up an `HttpRequest` instance and send + * it through to the application. Only when (and if) the application then requests data from the entity stream do we + * send out a `100 Continue` response and continue reading the request entity. + * The application can therefore determine itself whether it wants the client to send the request entity + * by deciding whether to look at the request entity data stream or not. + * If the application sends a response *without* having looked at the request entity the client receives this + * response *instead of* the `100 Continue` response and the server closes the connection afterwards. + * + * - on the client-side: + * If the user adds a `Expect: 100-continue` header to the request we need to hold back sending the entity until + * we've received a `100 Continue` response. + */ +private[engine] case object OneHundredContinue + +private[engine] class OneHundredContinueSourceActor extends ActorPublisher[OneHundredContinue.type] { + private var triggered = 0 + + def receive = { + case OneHundredContinue ⇒ + triggered += 1 + tryDispatch() + + case ActorPublisherMessage.Request(_) ⇒ + tryDispatch() + + case ActorPublisherMessage.Cancel ⇒ + context.stop(self) + } + + @tailrec private def tryDispatch(): Unit = + if (triggered > 0 && totalDemand > 0) { + onNext(OneHundredContinue) + triggered -= 1 + tryDispatch() + } +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/model/HttpMessage.scala b/akka-http-core/src/main/scala/akka/http/model/HttpMessage.scala index 8e1e8fbd17..50cd953837 100644 --- a/akka-http-core/src/main/scala/akka/http/model/HttpMessage.scala +++ b/akka-http-core/src/main/scala/akka/http/model/HttpMessage.scala @@ -5,6 +5,8 @@ package akka.http.model import java.lang.{ Iterable ⇒ JIterable } +import akka.parboiled2.CharUtils + import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ Future, ExecutionContext } import scala.collection.immutable @@ -129,9 +131,9 @@ final case class HttpRequest(method: HttpMethod = HttpMethods.GET, headers: immutable.Seq[HttpHeader] = Nil, entity: RequestEntity = HttpEntity.Empty, protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpRequest with HttpMessage { - require(!uri.isEmpty, "An HttpRequest must not have an empty Uri") + HttpRequest.verifyUri(uri) require(entity.isKnownEmpty || method.isEntityAccepted, "Requests with this method must have an empty entity") - require(protocol == HttpProtocols.`HTTP/1.1` || !entity.isInstanceOf[HttpEntity.Chunked], + require(protocol != HttpProtocols.`HTTP/1.0` || !entity.isInstanceOf[HttpEntity.Chunked], "HTTP/1.0 requests must not have a chunked entity") type Self = HttpRequest @@ -281,6 +283,22 @@ object HttpRequest { else throw new IllegalUriException(s"'Host' header value of request to `$uri` doesn't match request target authority", s"Host header: $hostHeader\nrequest target authority: ${uri.authority}") } + + /** + * Verifies that the given [[Uri]] is non-empty and has either scheme `http`, `https` or no scheme at all. + * If any of these conditions is not met the method throws an [[IllegalArgumentException]]. + */ + def verifyUri(uri: Uri): Unit = + if (uri.isEmpty) throw new IllegalArgumentException("`uri` must not be empty") + else { + def c(i: Int) = CharUtils.toLowerCase(uri.scheme charAt i) + uri.scheme.length match { + case 0 ⇒ // ok + case 4 if c(0) == 'h' && c(1) == 't' && c(2) == 't' && c(3) == 'p' ⇒ // ok + case 5 if c(0) == 'h' && c(1) == 't' && c(2) == 't' && c(3) == 'p' && c(4) == 's' ⇒ // ok + case _ ⇒ throw new IllegalArgumentException("""`uri` must have scheme "http", "https" or no scheme""") + } + } } /** @@ -290,6 +308,10 @@ final case class HttpResponse(status: StatusCode = StatusCodes.OK, headers: immutable.Seq[HttpHeader] = Nil, entity: ResponseEntity = HttpEntity.Empty, protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpResponse with HttpMessage { + require(entity.isKnownEmpty || status.allowsEntity, "Responses with this status code must have an empty entity") + require(protocol == HttpProtocols.`HTTP/1.1` || !entity.isInstanceOf[HttpEntity.Chunked], + "HTTP/1.0 responses must not have a chunked entity") + type Self = HttpResponse def self = this diff --git a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala index dc7ddff9dc..96cbc0ae58 100644 --- a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala @@ -8,8 +8,6 @@ package headers import java.lang.Iterable import java.net.InetSocketAddress import java.util -import akka.http.model.japi -import akka.http.model.japi.JavaMapping import scala.annotation.tailrec import scala.collection.immutable import akka.http.util._ @@ -54,6 +52,7 @@ final case class Connection(tokens: immutable.Seq[String]) extends ModeledHeader def renderValue[R <: Rendering](r: R): r.type = r ~~ tokens def hasClose = has("close") def hasKeepAlive = has("keep-alive") + def append(tokens: immutable.Seq[String]) = Connection(this.tokens ++ tokens) @tailrec private def has(item: String, ix: Int = 0): Boolean = if (ix < tokens.length) if (tokens(ix) equalsIgnoreCase item) true @@ -563,6 +562,7 @@ final case class `Transfer-Encoding`(encodings: immutable.Seq[TransferEncoding]) case remaining ⇒ Some(`Transfer-Encoding`(remaining)) } } else Some(this) + def append(encodings: immutable.Seq[TransferEncoding]) = `Transfer-Encoding`(this.encodings ++ encodings) def renderValue[R <: Rendering](r: R): r.type = r ~~ encodings protected def companion = `Transfer-Encoding` diff --git a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala index 9a668b0f1d..177f637e2b 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala @@ -22,7 +22,7 @@ import HttpMethods._ import HttpProtocols._ import StatusCodes._ import HttpEntity._ -import ParserOutput.ParseError +import ParserOutput._ import FastFuture._ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { @@ -48,13 +48,14 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } "with no headers and no body but remaining content" in new Test { - """GET / HTTP/1.0 + Seq("""GET / HTTP/1.0 | |POST /foo HTTP/1.0 | - |TRA""" /* beginning of TRACE request */ should parseTo( - HttpRequest(GET, "/", protocol = `HTTP/1.0`), - HttpRequest(POST, "/foo", protocol = `HTTP/1.0`)) + |TRA""") /* beginning of TRACE request */ should generalMultiParseTo( + Right(HttpRequest(GET, "/", protocol = `HTTP/1.0`)), + Right(HttpRequest(POST, "/foo", protocol = `HTTP/1.0`)), + Left(MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start")))) closeAfterResponseCompletion shouldEqual Seq(true, true) } @@ -168,7 +169,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "request start" in new Test { Seq(start, "rest") should generalMultiParseTo( Right(baseRequest.withEntity(HttpEntity.Chunked(`application/pdf`, source()))), - Left(ParseError(400: StatusCode, ErrorInfo("Illegal character 'r' in chunk start")))) + Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -182,15 +183,18 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "10;foo=", """bar |0123456789ABCDEF - |10 + |A |0123456789""", - """ABCDEF - |dead""") should generalMultiParseTo( + """ + |0 + | + |""") should generalMultiParseTo( Right(baseRequest.withEntity(Chunked(`application/pdf`, source( Chunk(ByteString("abc")), Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), - Chunk(ByteString("0123456789ABCDEF"), "")))))) + Chunk(ByteString("0123456789"), ""), + LastChunk))))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -203,14 +207,14 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { closeAfterResponseCompletion shouldEqual Seq(false) } - "message end with extension, trailer and remaining content" in new Test { + "message end with extension and trailer" in new Test { Seq(start, """000;nice=true |Foo: pip | apo |Bar: xyz | - |GE""") should generalMultiParseTo( + |""") should generalMultiParseTo( Right(baseRequest.withEntity(Chunked(`application/pdf`, source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) closeAfterResponseCompletion shouldEqual Seq(false) @@ -238,13 +242,16 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { |Content-Type: application/pdf |Host: ping | + |0 + | |""" should parseTo(HttpRequest(PATCH, "/data", List(`Transfer-Encoding`(TransferEncodings.Extension("fancy")), - Host("ping")), HttpEntity.Chunked(`application/pdf`, source()))) + Host("ping")), HttpEntity.Chunked(`application/pdf`, source(LastChunk)))) closeAfterResponseCompletion shouldEqual Seq(false) } "support `rawRequestUriHeader` setting" in new Test { - override protected def newParser: HttpRequestParser = new HttpRequestParser(parserSettings, rawRequestUriHeader = true)() + override protected def newParser: HttpRequestParser = + new HttpRequestParser(parserSettings, rawRequestUriHeader = true, _headerParser = HttpHeaderParser(parserSettings)()) """GET /f%6f%6fbar?q=b%61z HTTP/1.1 |Host: ping @@ -275,19 +282,19 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { Seq(start, """15 ; |""") should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("Illegal character ' ' in chunk start")))) + Left(EntityStreamError(ErrorInfo("Illegal character ' ' in chunk start")))) closeAfterResponseCompletion shouldEqual Seq(false) } "an illegal char in chunk size" in new Test { Seq(start, "bla") should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("Illegal character 'l' in chunk start")))) + Left(EntityStreamError(ErrorInfo("Illegal character 'l' in chunk start")))) closeAfterResponseCompletion shouldEqual Seq(false) } "too-long chunk extension" in new Test { Seq(start, "3;" + ("x" * 257)) should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters")))) + Left(EntityStreamError(ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -295,7 +302,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { Seq(start, """1a2b3c4d5e |""") should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes")))) + Left(EntityStreamError(ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -303,7 +310,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { Seq(start, """3 |abcde""") should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("Illegal chunk termination")))) + Left(EntityStreamError(ErrorInfo("Illegal chunk termination")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -311,7 +318,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { Seq(start, """0 |F@oo: pip""") should generalMultiParseTo(Right(baseRequest), - Left(ParseError(400: StatusCode, ErrorInfo("Illegal character '@' in header name")))) + Left(EntityStreamError(ErrorInfo("Illegal character '@' in header name")))) closeAfterResponseCompletion shouldEqual Seq(false) } } @@ -333,7 +340,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "two Content-Length headers" in new Test { """GET / HTTP/1.1 |Content-Length: 3 - |Content-Length: 3 + |Content-Length: 4 | |foo""" should parseToError(BadRequest, ErrorInfo("HTTP message must not contain more than one Content-Length header")) @@ -403,9 +410,8 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { override def toString = req.toString } - def strictEqualify(x: Either[ParseError, HttpRequest]): Either[ParseError, StrictEqualHttpRequest] = { + def strictEqualify[T](x: Either[T, HttpRequest]): Either[T, StrictEqualHttpRequest] = x.right.map(new StrictEqualHttpRequest(_)) - } def parseTo(expected: HttpRequest*): Matcher[String] = multiParseTo(expected: _*).compose(_ :: Nil) @@ -420,35 +426,35 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { generalRawMultiParseTo(parser, expected.map(Right(_)): _*) def parseToError(status: StatusCode, info: ErrorInfo): Matcher[String] = - generalMultiParseTo(Left(ParseError(status, info))).compose(_ :: Nil) + generalMultiParseTo(Left(MessageStartError(status, info))).compose(_ :: Nil) - def generalMultiParseTo(expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = + def generalMultiParseTo(expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] = generalRawMultiParseTo(expected: _*).compose(_ map prep) - def generalRawMultiParseTo(expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = + def generalRawMultiParseTo(expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] = generalRawMultiParseTo(newParser, expected: _*) def generalRawMultiParseTo(parser: HttpRequestParser, - expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = + expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] = equal(expected.map(strictEqualify)) - .matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser) + .matcher[Seq[Either[RequestOutput, StrictEqualHttpRequest]]] compose multiParse(parser) - def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] = + def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] = Source(input.toList) .map(ByteString.apply) .transform("parser", () ⇒ parser) - .splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) + .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .headAndTail .collect { - case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒ + case (RequestStart(method, uri, protocol, headers, createEntity, _, close), entityParts) ⇒ closeAfterResponseCompletion :+= close Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) - case (x: ParseError, _) ⇒ Left(x) + case (x @ (MessageStartError(_, _) | EntityStreamError(_)), _) ⇒ Left(x) } .map { x ⇒ Source { x match { case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x))) - case Left(error) ⇒ Future.successful(Left(error)) + case Left(error) ⇒ FastFuture.successful(Left(error)) } } } @@ -458,7 +464,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { .awaitResult(awaitAtMost) protected def parserSettings: ParserSettings = ParserSettings(system) - protected def newParser = new HttpRequestParser(parserSettings, false)() + protected def newParser = new HttpRequestParser(parserSettings, false, HttpHeaderParser(parserSettings)()) private def compactEntity(entity: RequestEntity): Future[RequestEntity] = entity match { diff --git a/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala index 76da3481ec..58569af108 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala @@ -22,7 +22,7 @@ import HttpMethods._ import HttpProtocols._ import StatusCodes._ import HttpEntity._ -import ParserOutput.ParseError +import ParserOutput._ import FastFuture._ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { @@ -43,7 +43,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "a 200 response to a HEAD request" in new Test { """HTTP/1.1 200 OK | - |HTT""" should parseTo(HEAD, HttpResponse()) + |""" should parseTo(HEAD, HttpResponse()) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -97,14 +97,15 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } "a response with 3 headers, a body and remaining content" in new Test { - """HTTP/1.1 500 Internal Server Error + Seq("""HTTP/1.1 500 Internal Server Error |User-Agent: curl/7.19.7 xyz |Connection:close |Content-Length: 17 |Content-Type: text/plain; charset=UTF-8 | - |Shake your BOODY!HTTP/1.""" should parseTo(HttpResponse(InternalServerError, List(Connection("close"), - `User-Agent`("curl/7.19.7 xyz")), "Shake your BOODY!")) + |Sh""", "ake your BOODY!HTTP/1.") should generalMultiParseTo( + Right(HttpResponse(InternalServerError, List(Connection("close"), `User-Agent`("curl/7.19.7 xyz")), + "Shake your BOODY!"))) closeAfterResponseCompletion shouldEqual Seq(true) } @@ -133,7 +134,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { "response start" in new Test { Seq(start, "rest") should generalMultiParseTo( Right(baseResponse.withEntity(Chunked(`application/pdf`, source()))), - Left("Illegal character 'r' in chunk start")) + Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -150,12 +151,14 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { |10 |0123456789""", """ABCDEF - |dead""") should generalMultiParseTo( + |0 + | + |""") should generalMultiParseTo( Right(baseResponse.withEntity(Chunked(`application/pdf`, source( Chunk(ByteString("abc")), Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), - Chunk(ByteString("0123456789ABCDEF"), "")))))) + Chunk(ByteString("0123456789ABCDEF")), LastChunk))))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -177,33 +180,38 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { | |HT""") should generalMultiParseTo( Right(baseResponse.withEntity(Chunked(`application/pdf`, - source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) + source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))), + Left(MessageStartError(400: StatusCode, ErrorInfo("Illegal HTTP message start")))) closeAfterResponseCompletion shouldEqual Seq(false) } "response with additional transfer encodings" in new Test { - """HTTP/1.1 200 OK + Seq("""HTTP/1.1 200 OK |Transfer-Encoding: fancy, chunked - |Content-Type: application/pdf + |Cont""", """ent-Type: application/pdf | - |""" should parseTo(HttpResponse(headers = List(`Transfer-Encoding`(TransferEncodings.Extension("fancy"))), - entity = HttpEntity.Chunked(`application/pdf`, source()))) + |""") should generalMultiParseTo( + Right(HttpResponse(headers = List(`Transfer-Encoding`(TransferEncodings.Extension("fancy"))), + entity = HttpEntity.Chunked(`application/pdf`, source()))), + Left(EntityStreamError(ErrorInfo("Entity stream truncation")))) closeAfterResponseCompletion shouldEqual Seq(false) } } "reject a response with" - { "HTTP version 1.2" in new Test { - Seq("HTTP/1.2 200 OK\r\n") should generalMultiParseTo(Left("The server-side HTTP version is not supported")) + Seq("HTTP/1.2 200 OK\r\n") should generalMultiParseTo(Left(MessageStartError( + 400: StatusCode, ErrorInfo("The server-side HTTP version is not supported")))) } "an illegal status code" in new Test { - Seq("HTTP/1", ".1 2000 Something") should generalMultiParseTo(Left("Illegal response status code")) + Seq("HTTP/1", ".1 2000 Something") should generalMultiParseTo(Left(MessageStartError( + 400: StatusCode, ErrorInfo("Illegal response status code")))) } "a too-long response status reason" in new Test { - Seq("HTTP/1.1 204 12345678", "90123456789012\r\n") should generalMultiParseTo( - Left("Response reason phrase exceeds the configured limit of 21 characters")) + Seq("HTTP/1.1 204 12345678", "90123456789012\r\n") should generalMultiParseTo(Left( + MessageStartError(400: StatusCode, ErrorInfo("Response reason phrase exceeds the configured limit of 21 characters")))) } } } @@ -224,9 +232,8 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { override def toString = resp.toString } - def strictEqualify(x: Either[String, HttpResponse]): Either[String, StrictEqualHttpResponse] = { + def strictEqualify[T](x: Either[T, HttpResponse]): Either[T, StrictEqualHttpResponse] = x.right.map(new StrictEqualHttpResponse(_)) - } def parseTo(expected: HttpResponse*): Matcher[String] = parseTo(GET, expected: _*) def parseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[String] = @@ -240,46 +247,44 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def rawMultiParseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[Seq[String]] = generalRawMultiParseTo(requestMethod, expected.map(Right(_)): _*) - def parseToError(error: String): Matcher[String] = generalMultiParseTo(Left(error)).compose(_ :: Nil) + def parseToError(error: ResponseOutput): Matcher[String] = generalMultiParseTo(Left(error)).compose(_ :: Nil) - def generalMultiParseTo(expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = + def generalMultiParseTo(expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] = generalRawMultiParseTo(expected: _*).compose(_ map prep) - def generalRawMultiParseTo(expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = + def generalRawMultiParseTo(expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] = generalRawMultiParseTo(GET, expected: _*) - def generalRawMultiParseTo(requestMethod: HttpMethod, expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = + def generalRawMultiParseTo(requestMethod: HttpMethod, expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] = equal(expected.map(strictEqualify)) - .matcher[Seq[Either[String, StrictEqualHttpResponse]]] compose { - input: Seq[String] ⇒ - val future = - Source(input.toList) - .map(ByteString.apply) - .transform("parser", () ⇒ newParser(requestMethod)) - .splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) - .headAndTail - .collect { - case (ParserOutput.ResponseStart(statusCode, protocol, headers, createEntity, close), entityParts) ⇒ - closeAfterResponseCompletion :+= close - Right(HttpResponse(statusCode, headers, createEntity(entityParts), protocol)) - case (x: ParseError, _) ⇒ Left(x) - }.map { x ⇒ - Source { - x match { - case Right(response) ⇒ compactEntity(response.entity).fast.map(x ⇒ Right(response.withEntity(x))) - case Left(error) ⇒ FastFuture.successful(Left(error.info.formatPretty)) - } + .matcher[Seq[Either[ResponseOutput, StrictEqualHttpResponse]]] compose { input: Seq[String] ⇒ + val future = + Source(input.toList) + .map(ByteString.apply) + .transform("parser", () ⇒ newParser(requestMethod)) + .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) + .headAndTail + .collect { + case (ResponseStart(statusCode, protocol, headers, createEntity, close), entityParts) ⇒ + closeAfterResponseCompletion :+= close + Right(HttpResponse(statusCode, headers, createEntity(entityParts), protocol)) + case (x @ (MessageStartError(_, _) | EntityStreamError(_)), _) ⇒ Left(x) + }.map { x ⇒ + Source { + x match { + case Right(response) ⇒ compactEntity(response.entity).fast.map(x ⇒ Right(response.withEntity(x))) + case Left(error) ⇒ FastFuture.successful(Left(error)) } } - .flatten(FlattenStrategy.concat) - .map(strictEqualify) - .grouped(1000).runWith(Sink.head) - Await.result(future, 500.millis) + } + .flatten(FlattenStrategy.concat) + .map(strictEqualify) + .grouped(1000).runWith(Sink.head) + Await.result(future, 500.millis) } def parserSettings: ParserSettings = ParserSettings(system) def newParser(requestMethod: HttpMethod = GET) = { - val parser = new HttpResponseParser(parserSettings, - dequeueRequestMethodForNextResponse = () ⇒ requestMethod)() + val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)(), () ⇒ requestMethod) parser } diff --git a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala index 84cbb078f8..4269310194 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala @@ -4,7 +4,6 @@ package akka.http.engine.rendering -import akka.http.model.HttpMethods._ import com.typesafe.config.{ Config, ConfigFactory } import scala.concurrent.duration._ import scala.concurrent.Await @@ -18,7 +17,6 @@ import akka.http.util._ import akka.util.ByteString import akka.stream.scaladsl._ import akka.stream.FlowMaterializer -import akka.stream.impl.SynchronousIterablePublisher import HttpEntity._ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll { @@ -51,7 +49,6 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll |Age: 0 |Server: akka-http/1.0.0 |Date: Thu, 25 Aug 2011 09:10:29 GMT - |Content-Length: 0 | |""" } diff --git a/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerPipelineSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerPipelineSpec.scala index 2b4be61c2b..43f2fa82ed 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerPipelineSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerPipelineSpec.scala @@ -5,23 +5,26 @@ package akka.http.engine.server import scala.concurrent.duration._ +import org.scalatest.{ Inside, BeforeAndAfterAll, Matchers } import akka.event.NoLogging -import akka.http.model.HttpEntity.{ Chunk, ChunkStreamPart, LastChunk } -import akka.http.model._ -import akka.http.model.headers.{ ProductVersion, Server, Host } -import akka.http.util._ -import akka.http.Http +import akka.util.ByteString import akka.stream.scaladsl._ import akka.stream.FlowMaterializer import akka.stream.io.StreamTcp import akka.stream.testkit.{ AkkaSpec, StreamTestKit } -import akka.util.ByteString -import org.scalatest._ +import akka.http.Http +import akka.http.model._ +import akka.http.util._ +import headers._ +import HttpEntity._ +import MediaTypes._ +import HttpMethods._ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside { implicit val materializer = FlowMaterializer() - "The server implementation should" should { + "The server implementation" should { + "deliver an empty request as soon as all headers are received" in new TestSetup { send("""GET / HTTP/1.1 |Host: example.com @@ -30,6 +33,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) } + "deliver a request as soon as all headers are received" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com @@ -38,7 +42,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ByteString] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -53,18 +57,26 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA dataProbe.expectNoMsg(50.millis) } } - "deliver an error as soon as a parsing error occurred" in new TestSetup { - pending - // POST should require Content-Length header - send("""POST / HTTP/1.1 + + "deliver an error response as soon as a parsing error occurred" in new TestSetup { + send("""GET / HTTP/1.2 |Host: example.com | |""".stripMarginWithNewline("\r\n")) - requests.expectError() + netOutSub.request(1) + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 505 HTTP Version Not Supported + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 74 + | + |The server does not support the HTTP protocol version used in the request.""".stripMarginWithNewline("\r\n") } - "report a invalid Chunked stream" in new TestSetup { - pending + + "report an invalid Chunked stream" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com |Transfer-Encoding: chunked @@ -74,7 +86,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -83,8 +95,23 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA dataProbe.expectNoMsg(50.millis) send("3ghi\r\n") // missing "\r\n" after the number of bytes - dataProbe.expectError() - requests.expectError() + val error = dataProbe.expectError() + error.getMessage shouldEqual "Illegal character 'g' in chunk start" + requests.expectComplete() + + netOutSub.request(1) + responsesSub.expectRequest() + responsesSub.sendError(error.asInstanceOf[Exception]) + + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 400 Bad Request + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 36 + | + |Illegal character 'g' in chunk start""".stripMarginWithNewline("\r\n") } } @@ -97,11 +124,12 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA expectRequest shouldEqual HttpRequest( - method = HttpMethods.POST, + method = POST, uri = "http://example.com/strict", headers = List(Host("example.com")), entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) } + "deliver the request entity as it comes in for a Default entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com @@ -110,7 +138,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |abcdef""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ByteString] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -122,6 +150,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA dataProbe.expectNoMsg(50.millis) } } + "deliver the request entity as it comes in for a chunked entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com @@ -132,7 +161,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -154,7 +183,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA expectRequest shouldEqual HttpRequest( - method = HttpMethods.POST, + method = POST, uri = "http://example.com/strict", headers = List(Host("example.com")), entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) @@ -167,11 +196,12 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA expectRequest shouldEqual HttpRequest( - method = HttpMethods.POST, + method = POST, uri = "http://example.com/next-strict", headers = List(Host("example.com")), entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("mnopqrstuvwx"))) } + "deliver the second message properly after a Default entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com @@ -180,7 +210,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |abcdef""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ByteString] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -202,10 +232,11 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |abcde""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) ⇒ - data shouldEqual (ByteString("abcde")) + case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ + data shouldEqual ByteString("abcde") } } + "deliver the second message properly after a Chunked entity" in new TestSetup { send("""POST /chunked HTTP/1.1 |Host: example.com @@ -216,7 +247,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -239,8 +270,8 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |abcde""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) ⇒ - data shouldEqual (ByteString("abcde")) + case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ + data shouldEqual ByteString("abcde") } } @@ -252,7 +283,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |abcdef""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ByteString] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -264,6 +295,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA dataProbe.expectComplete() } } + "close the request entity stream when the entity is complete for a Chunked entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com @@ -274,7 +306,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |""".stripMarginWithNewline("\r\n")) inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() @@ -288,27 +320,26 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA } } - "report a truncated entity stream on the entity data stream and the main stream for a Default entity" in pendingUntilFixed(new TestSetup { + "report a truncated entity stream on the entity data stream and the main stream for a Default entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com |Content-Length: 12 | |abcdef""".stripMarginWithNewline("\r\n")) - inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ByteString] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() sub.request(10) dataProbe.expectNext(ByteString("abcdef")) dataProbe.expectNoMsg(50.millis) - closeNetworkInput() - dataProbe.expectError() + dataProbe.expectError().getMessage shouldEqual "Entity stream truncation" } - }) - "report a truncated entity stream on the entity data stream and the main stream for a Chunked entity" in pendingUntilFixed(new TestSetup { + } + + "report a truncated entity stream on the entity data stream and the main stream for a Chunked entity" in new TestSetup { send("""POST / HTTP/1.1 |Host: example.com |Transfer-Encoding: chunked @@ -316,20 +347,18 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |6 |abcdef |""".stripMarginWithNewline("\r\n")) - inside(expectRequest) { - case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val sub = dataProbe.expectSubscription() sub.request(10) dataProbe.expectNext(Chunk(ByteString("abcdef"))) dataProbe.expectNoMsg(50.millis) - closeNetworkInput() - dataProbe.expectError() + dataProbe.expectError().getMessage shouldEqual "Entity stream truncation" } - }) + } "translate HEAD request to GET request when transparent-head-requests are enabled" in new TestSetup { override def settings = ServerSettings(system).copy(transparentHeadRequests = true) @@ -337,8 +366,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |Host: example.com | |""".stripMarginWithNewline("\r\n")) - - expectRequest shouldEqual HttpRequest(HttpMethods.GET, uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com"))) } "keep HEAD request when transparent-head-requests are disabled" in new TestSetup { @@ -347,21 +375,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA |Host: example.com | |""".stripMarginWithNewline("\r\n")) - - expectRequest shouldEqual HttpRequest(HttpMethods.HEAD, uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com"))) } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Strict)" in new TestSetup { - override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) send("""HEAD / HTTP/1.1 |Host: example.com | |""".stripMarginWithNewline("\r\n")) - inside(expectRequest) { - case HttpRequest(HttpMethods.GET, _, _, _, _) ⇒ + case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd")))) - netOutSub.request(1) wipeDate(netOut.expectNext().utf8String) shouldEqual """|HTTP/1.1 200 OK @@ -375,23 +399,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Default)" in new TestSetup { - override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) send("""HEAD / HTTP/1.1 |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = StreamTestKit.PublisherProbe[ByteString] - inside(expectRequest) { - case HttpRequest(HttpMethods.GET, _, _, _, _) ⇒ + case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) - netOutSub.request(1) - val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual """|HTTP/1.1 200 OK |Server: akka-http/test @@ -404,23 +422,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with CloseDelimited)" in new TestSetup { - override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) send("""HEAD / HTTP/1.1 |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = StreamTestKit.PublisherProbe[ByteString] - inside(expectRequest) { - case HttpRequest(HttpMethods.GET, _, _, _, _) ⇒ + case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) - netOutSub.request(1) - val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual """|HTTP/1.1 200 OK |Server: akka-http/test @@ -429,29 +441,22 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA | |""".stripMarginWithNewline("\r\n") } - // No close should happen here since this was a HEAD request netOut.expectNoMsg(50.millis) } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Chunked)" in new TestSetup { - override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) send("""HEAD / HTTP/1.1 |Host: example.com | |""".stripMarginWithNewline("\r\n")) - val data = StreamTestKit.PublisherProbe[ChunkStreamPart] - inside(expectRequest) { - case HttpRequest(HttpMethods.GET, _, _, _, _) ⇒ + case HttpRequest(GET, _, _, _, _) ⇒ responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) - netOutSub.request(1) - val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual """|HTTP/1.1 200 OK |Server: akka-http/test @@ -464,29 +469,146 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA } "respect Connection headers of HEAD requests if transparent-head-requests is enabled" in new TestSetup { - override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) send("""HEAD / HTTP/1.1 |Host: example.com |Connection: close | |""".stripMarginWithNewline("\r\n")) - val data = StreamTestKit.PublisherProbe[ByteString] - inside(expectRequest) { - case HttpRequest(HttpMethods.GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) - + case HttpRequest(GET, _, _, _, _) ⇒ + responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) netOutSub.request(1) - val dataSub = data.expectSubscription() dataSub.expectCancellation() - netOut.expectNext() } - netOut.expectComplete() } + + "produce a `100 Continue` response when requested by a `Default` entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Expect: 100-continue + |Content-Length: 16 + | + |""".stripMarginWithNewline("\r\n")) + inside(expectRequest) { + case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.to(Sink(dataProbe)).run() + val dataSub = dataProbe.expectSubscription() + netOutSub.request(2) + netOut.expectNoMsg(50.millis) + dataSub.request(1) // triggers `100 Continue` response + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 100 Continue + |Server: akka-http/test + |Date: XXXX + | + |""".stripMarginWithNewline("\r\n") + dataProbe.expectNoMsg(50.millis) + send("0123456789ABCDEF") + dataProbe.expectNext(ByteString("0123456789ABCDEF")) + dataProbe.expectComplete() + responsesSub.sendNext(HttpResponse(entity = "Yeah")) + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 200 OK + |Server: akka-http/test + |Date: XXXX + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 4 + | + |Yeah""".stripMarginWithNewline("\r\n") + } + } + + "produce a `100 Continue` response when requested by a `Chunked` entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Expect: 100-continue + |Transfer-Encoding: chunked + | + |""".stripMarginWithNewline("\r\n")) + inside(expectRequest) { + case HttpRequest(POST, _, _, Chunked(ContentType(`application/octet-stream`, None), data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.to(Sink(dataProbe)).run() + val dataSub = dataProbe.expectSubscription() + netOutSub.request(2) + netOut.expectNoMsg(50.millis) + dataSub.request(2) // triggers `100 Continue` response + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 100 Continue + |Server: akka-http/test + |Date: XXXX + | + |""".stripMarginWithNewline("\r\n") + dataProbe.expectNoMsg(50.millis) + send("""10 + |0123456789ABCDEF + |0 + | + |""".stripMarginWithNewline("\r\n")) + dataProbe.expectNext(Chunk(ByteString("0123456789ABCDEF"))) + dataProbe.expectNext(LastChunk) + dataProbe.expectComplete() + responsesSub.sendNext(HttpResponse(entity = "Yeah")) + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 200 OK + |Server: akka-http/test + |Date: XXXX + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 4 + | + |Yeah""".stripMarginWithNewline("\r\n") + } + } + + "render a closing response instead of `100 Continue` if request entity is not requested" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Expect: 100-continue + |Content-Length: 16 + | + |""".stripMarginWithNewline("\r\n")) + inside(expectRequest) { + case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ + netOutSub.request(1) + responsesSub.sendNext(HttpResponse(entity = "Yeah")) + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 200 OK + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 4 + | + |Yeah""".stripMarginWithNewline("\r\n") + } + } + + "render a 500 response on response stream errors from the application" in new TestSetup { + send("""GET / HTTP/1.1 + |Host: example.com + | + |""".stripMarginWithNewline("\r\n")) + + expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) + + netOutSub.request(1) + responsesSub.expectRequest() + responsesSub.sendError(new RuntimeException("CRASH BOOM BANG")) + + wipeDate(netOut.expectNext().utf8String) shouldEqual + """HTTP/1.1 500 Internal Server Error + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Length: 0 + | + |""".stripMarginWithNewline("\r\n") + } } class TestSetup { @@ -494,7 +616,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA val netOut = StreamTestKit.SubscriberProbe[ByteString] val tcpConnection = StreamTcp.IncomingTcpConnection(null, netIn, netOut) - def settings = ServerSettings(system) + def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) val pipeline = new HttpServerPipeline(settings, NoLogging) val Http.IncomingConnection(_, requestsIn, responsesOut) = pipeline(tcpConnection)