Merge pull request #16337 from spray/wip-15799-mathias

Add server-side support for `Expect: 100-continue`, other improvements
This commit is contained in:
Björn Antonsson 2014-11-27 10:56:10 +01:00
commit 5e04d11d66
19 changed files with 769 additions and 399 deletions

View file

@ -11,9 +11,9 @@ import org.reactivestreams.{ Publisher, Subscriber }
import scala.collection.immutable import scala.collection.immutable
import akka.io.Inet import akka.io.Inet
import akka.stream.MaterializerSettings import akka.stream.MaterializerSettings
import akka.http.engine.client.{ HttpClientProcessor, ClientConnectionSettings } import akka.http.engine.client.ClientConnectionSettings
import akka.http.engine.server.ServerSettings import akka.http.engine.server.ServerSettings
import akka.http.model.{ HttpResponse, HttpRequest, japi } import akka.http.model.{ ErrorInfo, HttpResponse, HttpRequest, japi }
import akka.http.util._ import akka.http.util._
import akka.actor._ import akka.actor._
@ -141,6 +141,8 @@ object Http extends ExtensionKey[HttpExt] {
class ConnectionAttemptFailedException(val endpoint: InetSocketAddress) extends ConnectionException(s"Connection attempt to $endpoint failed") class ConnectionAttemptFailedException(val endpoint: InetSocketAddress) extends ConnectionException(s"Connection attempt to $endpoint failed")
class RequestTimeoutException(val request: HttpRequest, message: String) extends ConnectionException(message) class RequestTimeoutException(val request: HttpRequest, message: String) extends ConnectionException(message)
class StreamException(val info: ErrorInfo) extends RuntimeException(info.summary)
} }
class HttpExt(system: ExtendedActorSystem) extends akka.io.IO.Extension { class HttpExt(system: ExtendedActorSystem) extends akka.io.IO.Extension {

View file

@ -5,18 +5,17 @@
package akka.http.engine.client package akka.http.engine.client
import java.net.InetSocketAddress import java.net.InetSocketAddress
import akka.util.ByteString
import scala.collection.immutable.Queue import scala.collection.immutable.Queue
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.stream.FlowMaterializer import akka.stream.FlowMaterializer
import akka.stream.FlattenStrategy import akka.stream.FlattenStrategy
import akka.stream.io.StreamTcp import akka.stream.io.StreamTcp
import akka.util.ByteString
import akka.http.Http import akka.http.Http
import akka.http.model.{ HttpMethod, HttpRequest, ErrorInfo, HttpResponse } import akka.http.model.{ HttpMethod, HttpRequest, ErrorInfo, HttpResponse }
import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory } import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
import akka.http.engine.parsing.HttpResponseParser import akka.http.engine.parsing.{ HttpRequestParser, HttpHeaderParser, HttpResponseParser }
import akka.http.engine.parsing.ParserOutput._ import akka.http.engine.parsing.ParserOutput._
import akka.http.util._ import akka.http.util._
@ -29,10 +28,13 @@ private[http] class HttpClientPipeline(effectiveSettings: ClientConnectionSettin
import effectiveSettings._ import effectiveSettings._
val rootParser = new HttpResponseParser(parserSettings)() // the initial header parser we initially use for every connection,
val warnOnIllegalHeader: ErrorInfo Unit = errorInfo // will not be mutated, all "shared copy" parsers copy on first-write into the header cache
if (parserSettings.illegalHeaderWarnings) val rootParser = new HttpResponseParser(
log.warning(errorInfo.withSummaryPrepended("Illegal response header").formatPretty) parserSettings,
HttpHeaderParser(parserSettings) { errorInfo
if (parserSettings.illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal response header").formatPretty)
})
val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log) val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
@ -60,7 +62,10 @@ private[http] class HttpClientPipeline(effectiveSettings: ClientConnectionSettin
val responsePipeline = val responsePipeline =
Flow[ByteString] Flow[ByteString]
.transform("rootParser", () rootParser.copyWith(warnOnIllegalHeader, requestMethodByPass)) .transform("rootParser", ()
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy(requestMethodByPass))
.splitWhen(_.isInstanceOf[MessageStart]) .splitWhen(_.isInstanceOf[MessageStart])
.headAndTail .headAndTail
.collect { .collect {

View file

@ -50,7 +50,11 @@ private[http] final class BodyPartParser(defaultContentType: ContentType,
// see: http://www.cgjennings.ca/fjs/ and http://ijes.info/4/1/42544103.pdf // see: http://www.cgjennings.ca/fjs/ and http://ijes.info/4/1/42544103.pdf
private[this] val boyerMoore = new BoyerMoore(needle) private[this] val boyerMoore = new BoyerMoore(needle)
private[this] val headerParser = HttpHeaderParser(settings, warnOnIllegalHeader) // TODO: prevent re-priming header parser from scratch // TODO: prevent re-priming header parser from scratch
private[this] val headerParser = HttpHeaderParser(settings) { errorInfo
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
}
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs
private[this] var resultIterator: Iterator[Output] = Iterator.empty private[this] var resultIterator: Iterator[Output] = Iterator.empty
private[this] var state: ByteString StateResult = tryParseInitialBoundary private[this] var state: ByteString StateResult = tryParseInitialBoundary

View file

@ -57,7 +57,7 @@ import akka.http.model.parser.CharacterClasses._
* Since we address them via the nodes MSB and zero is reserved the trie * Since we address them via the nodes MSB and zero is reserved the trie
* cannot hold more then 255 items, so this array has a fixed size of 255. * cannot hold more then 255 items, so this array has a fixed size of 255.
*/ */
private[parsing] final class HttpHeaderParser private ( private[engine] final class HttpHeaderParser private (
val settings: HttpHeaderParser.Settings, val settings: HttpHeaderParser.Settings,
warnOnIllegalHeader: ErrorInfo Unit, warnOnIllegalHeader: ErrorInfo Unit,
private[this] var nodes: Array[Char] = new Array(512), // initial size, can grow as needed private[this] var nodes: Array[Char] = new Array(512), // initial size, can grow as needed
@ -83,7 +83,7 @@ private[parsing] final class HttpHeaderParser private (
/** /**
* Returns a copy of this parser that shares the trie data with this instance. * Returns a copy of this parser that shares the trie data with this instance.
*/ */
def copyWith(warnOnIllegalHeader: ErrorInfo Unit) = def createShallowCopy(): HttpHeaderParser =
new HttpHeaderParser(settings, warnOnIllegalHeader, nodes, nodeCount, branchData, branchDataCount, values, valueCount) new HttpHeaderParser(settings, warnOnIllegalHeader, nodes, nodeCount, branchData, branchDataCount, values, valueCount)
/** /**
@ -402,12 +402,10 @@ private[http] object HttpHeaderParser {
"Cache-Control: no-cache", "Cache-Control: no-cache",
"Expect: 100-continue") "Expect: 100-continue")
private val defaultIllegalHeaderWarning: ErrorInfo Unit = info throw new IllegalHeaderException(info) def apply(settings: HttpHeaderParser.Settings)(warnOnIllegalHeader: ErrorInfo Unit = info throw new IllegalHeaderException(info)) =
def apply(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) =
prime(unprimed(settings, warnOnIllegalHeader)) prime(unprimed(settings, warnOnIllegalHeader))
def unprimed(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) = def unprimed(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo Unit) =
new HttpHeaderParser(settings, warnOnIllegalHeader) new HttpHeaderParser(settings, warnOnIllegalHeader)
def prime(parser: HttpHeaderParser): HttpHeaderParser = { def prime(parser: HttpHeaderParser): HttpHeaderParser = {

View file

@ -10,48 +10,66 @@ import akka.parboiled2.CharUtils
import akka.util.ByteString import akka.util.ByteString
import akka.stream.scaladsl.Source import akka.stream.scaladsl.Source
import akka.stream.stage._ import akka.stream.stage._
import akka.http.Http.StreamException
import akka.http.model.parser.CharacterClasses import akka.http.model.parser.CharacterClasses
import akka.http.util._
import akka.http.model._ import akka.http.model._
import headers._ import headers._
import HttpProtocols._ import HttpProtocols._
import ParserOutput._
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOutput <: ParserOutput](val settings: ParserSettings, private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings,
val headerParser: HttpHeaderParser) val headerParser: HttpHeaderParser)
extends StatefulStage[ByteString, Output] { extends PushPullStage[ByteString, Output] {
import HttpMessageParser._
import settings._ import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString StateResult) extends StateResult final case class Trampoline(f: ByteString StateResult) extends StateResult
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs private[this] val result = new ListBuffer[Output]
private[this] var state: ByteString StateResult = startNewMessage(_, 0) private[this] var state: ByteString StateResult = startNewMessage(_, 0)
private[this] var protocol: HttpProtocol = `HTTP/1.1` private[this] var protocol: HttpProtocol = `HTTP/1.1`
private[this] var completionHandling: CompletionHandling = CompletionOk
private[this] var terminated = false private[this] var terminated = false
override def initial = new State { override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
override def onPush(input: ByteString, ctx: Context[Output]): Directive = { @tailrec def run(next: ByteString StateResult): StateResult =
result.clear() (try next(input)
catch {
case e: ParsingException failMessageStart(e.status, e.info)
case NotEnoughDataException
// we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
}) match {
case Trampoline(x) run(x)
case x x
}
@tailrec def run(next: ByteString StateResult): StateResult = if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`")
(try next(input) run(state)
catch { pushResultHeadAndFinishOrPull(ctx)
case e: ParsingException fail(e.status, e.info) }
case NotEnoughDataException
// we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
}) match {
case Trampoline(next) run(next)
case x x
}
run(state) def onPull(ctx: Context[Output]): Directive = pushResultHeadAndFinishOrPull(ctx)
val resultIterator = result.iterator
if (terminated) emitAndFinish(resultIterator, ctx) def pushResultHeadAndFinishOrPull(ctx: Context[Output]): Directive =
else emit(resultIterator, ctx) if (result.nonEmpty) {
val head = result.head
result.remove(0) // faster than `ListBuffer::drop`
ctx.push(head)
} else if (terminated) ctx.finish() else ctx.pull()
override def onUpstreamFinish(ctx: Context[Output]) = {
completionHandling() match {
case Some(x) emit(x.asInstanceOf[Output])
case None // nothing to do
} }
terminated = true
if (result.isEmpty) ctx.finish() else ctx.absorbTermination()
} }
def startNewMessage(input: ByteString, offset: Int): StateResult = { def startNewMessage(input: ByteString, offset: Int): StateResult = {
@ -59,6 +77,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
try parseMessage(input, offset) try parseMessage(input, offset)
catch { case NotEnoughDataException continue(input, offset)(_startNewMessage) } catch { case NotEnoughDataException continue(input, offset)(_startNewMessage) }
if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError)
_startNewMessage(input, offset) _startNewMessage(input, offset)
} }
@ -81,65 +100,75 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
@tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: List[HttpHeader] = Nil, @tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: List[HttpHeader] = Nil,
headerCount: Int = 0, ch: Option[Connection] = None, headerCount: Int = 0, ch: Option[Connection] = None,
clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None, clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None,
teh: Option[`Transfer-Encoding`] = None, hh: Boolean = false): StateResult = { teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false,
var lineEnd = 0 hh: Boolean = false): StateResult =
val resultHeader = if (headerCount < maxHeaderCount) {
try { var lineEnd = 0
lineEnd = headerParser.parseHeaderLine(input, lineStart)() val resultHeader =
headerParser.resultHeader try {
} catch { lineEnd = headerParser.parseHeaderLine(input, lineStart)()
case NotEnoughDataException null headerParser.resultHeader
} catch {
case NotEnoughDataException null
}
resultHeader match {
case null continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, ch, clh, cth, teh, e100c, hh))
case HttpHeaderParser.EmptyHeader
val close = HttpMessage.connectionCloseExpected(protocol, ch)
setCompletionHandling(CompletionIsEntityStreamError)
parseEntity(headers, protocol, input, lineEnd, clh, cth, teh, e100c, hh, close)
case h: `Content-Length` clh match {
case None parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, Some(h), cth, teh, e100c, hh)
case Some(`h`) parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, e100c, hh)
case _ failMessageStart("HTTP message must not contain more than one Content-Length header")
}
case h: `Content-Type` cth match {
case None parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, Some(h), teh, e100c, hh)
case Some(`h`) parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, e100c, hh)
case _ failMessageStart("HTTP message must not contain more than one Content-Type header")
}
case h: `Transfer-Encoding` teh match {
case None parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, cth, Some(h), e100c, hh)
case Some(x) parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, Some(x append h.encodings), e100c, hh)
}
case h: Connection ch match {
case None parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, Some(h), clh, cth, teh, e100c, hh)
case Some(x) parseHeaderLines(input, lineEnd, headers, headerCount, Some(x append h.tokens), clh, cth, teh, e100c, hh)
}
case h: Host
if (!hh) parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c, hh = true)
else failMessageStart("HTTP message must not contain more than one Host header")
case h: Expect parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c = true, hh)
case h parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, e100c, hh)
} }
resultHeader match { } else failMessageStart(s"HTTP message contains more than the configured limit of $maxHeaderCount headers")
case null continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, ch, clh, cth, teh, hh))
case HttpHeaderParser.EmptyHeader
val close = HttpMessage.connectionCloseExpected(protocol, ch)
parseEntity(headers, protocol, input, lineEnd, clh, cth, teh, hh, close)
case h: Connection
parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, Some(h), clh, cth, teh, hh)
case h: `Content-Length`
if (clh.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, Some(h), cth, teh, hh)
else fail("HTTP message must not contain more than one Content-Length header")
case h: `Content-Type`
if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, Some(h), teh, hh)
else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, ch, clh, cth, teh, hh)
else fail("HTTP message must not contain more than one Content-Type header")
case h: `Transfer-Encoding`
parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, cth, Some(h), hh)
case h if headerCount < maxHeaderCount
parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, hh || h.isInstanceOf[Host])
case _ fail(s"HTTP message contains more than the configured limit of $maxHeaderCount headers")
}
}
// work-around for compiler complaining about non-tail-recursion if we inline this method // work-around for compiler complaining about non-tail-recursion if we inline this method
def parseHeaderLinesAux(headers: List[HttpHeader], headerCount: Int, ch: Option[Connection], def parseHeaderLinesAux(headers: List[HttpHeader], headerCount: Int, ch: Option[Connection],
clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`],
hh: Boolean)(input: ByteString, lineStart: Int): StateResult = e100c: Boolean, hh: Boolean)(input: ByteString, lineStart: Int): StateResult =
parseHeaderLines(input, lineStart, headers, headerCount, ch, clh, cth, teh, hh) parseHeaderLines(input, lineStart, headers, headerCount, ch, clh, cth, teh, e100c, hh)
def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int,
clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`],
hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult
def parseFixedLengthBody(remainingBodyBytes: Long, def parseFixedLengthBody(remainingBodyBytes: Long,
isLastMessage: Boolean)(input: ByteString, bodyStart: Int): StateResult = { isLastMessage: Boolean)(input: ByteString, bodyStart: Int): StateResult = {
val remainingInputBytes = input.length - bodyStart val remainingInputBytes = input.length - bodyStart
if (remainingInputBytes > 0) { if (remainingInputBytes > 0) {
if (remainingInputBytes < remainingBodyBytes) { if (remainingInputBytes < remainingBodyBytes) {
emit(ParserOutput.EntityPart(input drop bodyStart)) emit(EntityPart(input drop bodyStart))
continue(parseFixedLengthBody(remainingBodyBytes - remainingInputBytes, isLastMessage)) continue(parseFixedLengthBody(remainingBodyBytes - remainingInputBytes, isLastMessage))
} else { } else {
val offset = bodyStart + remainingBodyBytes.toInt val offset = bodyStart + remainingBodyBytes.toInt
emit(ParserOutput.EntityPart(input.slice(bodyStart, offset))) emit(EntityPart(input.slice(bodyStart, offset)))
emit(ParserOutput.MessageEnd) emit(MessageEnd)
setCompletionHandling(CompletionOk)
if (isLastMessage) terminate() if (isLastMessage) terminate()
else startNewMessage(input, offset) else startNewMessage(input, offset)
} }
@ -149,32 +178,38 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
def parseChunk(input: ByteString, offset: Int, isLastMessage: Boolean): StateResult = { def parseChunk(input: ByteString, offset: Int, isLastMessage: Boolean): StateResult = {
@tailrec def parseTrailer(extension: String, lineStart: Int, headers: List[HttpHeader] = Nil, @tailrec def parseTrailer(extension: String, lineStart: Int, headers: List[HttpHeader] = Nil,
headerCount: Int = 0): StateResult = { headerCount: Int = 0): StateResult = {
val lineEnd = headerParser.parseHeaderLine(input, lineStart)() var errorInfo: ErrorInfo = null
headerParser.resultHeader match { val lineEnd =
case HttpHeaderParser.EmptyHeader try headerParser.parseHeaderLine(input, lineStart)()
val lastChunk = catch { case e: ParsingException errorInfo = e.info; 0 }
if (extension.isEmpty && headers.isEmpty) HttpEntity.LastChunk else HttpEntity.LastChunk(extension, headers) if (errorInfo eq null) {
emit(ParserOutput.EntityChunk(lastChunk)) headerParser.resultHeader match {
emit(ParserOutput.MessageEnd) case HttpHeaderParser.EmptyHeader
if (isLastMessage) terminate() val lastChunk =
else startNewMessage(input, lineEnd) if (extension.isEmpty && headers.isEmpty) HttpEntity.LastChunk else HttpEntity.LastChunk(extension, headers)
case header if headerCount < maxHeaderCount emit(EntityChunk(lastChunk))
parseTrailer(extension, lineEnd, header :: headers, headerCount + 1) emit(MessageEnd)
case _ fail(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers") setCompletionHandling(CompletionOk)
} if (isLastMessage) terminate()
else startNewMessage(input, lineEnd)
case header if headerCount < maxHeaderCount
parseTrailer(extension, lineEnd, header :: headers, headerCount + 1)
case _ failEntityStream(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers")
}
} else failEntityStream(errorInfo)
} }
def parseChunkBody(chunkSize: Int, extension: String, cursor: Int): StateResult = def parseChunkBody(chunkSize: Int, extension: String, cursor: Int): StateResult =
if (chunkSize > 0) { if (chunkSize > 0) {
val chunkBodyEnd = cursor + chunkSize val chunkBodyEnd = cursor + chunkSize
def result(terminatorLen: Int) = { def result(terminatorLen: Int) = {
emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
trampoline(_ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) trampoline(_ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
} }
byteChar(input, chunkBodyEnd) match { byteChar(input, chunkBodyEnd) match {
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2) case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2)
case '\n' result(1) case '\n' result(1)
case x fail("Illegal chunk termination") case x failEntityStream("Illegal chunk termination")
} }
} else parseTrailer(extension, cursor) } else parseTrailer(extension, cursor)
@ -186,7 +221,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
case '\n' parseChunkBody(chunkSize, extension, cursor + 1) case '\n' parseChunkBody(chunkSize, extension, cursor + 1)
case _ parseChunkExtensions(chunkSize, cursor + 1)(startIx) case _ parseChunkExtensions(chunkSize, cursor + 1)(startIx)
} }
} else fail(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters") } else failEntityStream(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters")
@tailrec def parseSize(cursor: Int, size: Long): StateResult = @tailrec def parseSize(cursor: Int, size: Long): StateResult =
if (size <= maxChunkSize) { if (size <= maxChunkSize) {
@ -194,9 +229,9 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
case c if CharacterClasses.HEXDIG(c) parseSize(cursor + 1, size * 16 + CharUtils.hexValue(c)) case c if CharacterClasses.HEXDIG(c) parseSize(cursor + 1, size * 16 + CharUtils.hexValue(c))
case ';' if cursor > offset parseChunkExtensions(size.toInt, cursor + 1)() case ';' if cursor > offset parseChunkExtensions(size.toInt, cursor + 1)()
case '\r' if cursor > offset && byteChar(input, cursor + 1) == '\n' parseChunkBody(size.toInt, "", cursor + 2) case '\r' if cursor > offset && byteChar(input, cursor + 1) == '\n' parseChunkBody(size.toInt, "", cursor + 2)
case c fail(s"Illegal character '${escape(c)}' in chunk start") case c failEntityStream(s"Illegal character '${escape(c)}' in chunk start")
} }
} else fail(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes") } else failEntityStream(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes")
try parseSize(offset, 0) try parseSize(offset, 0)
catch { catch {
@ -222,12 +257,21 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
} }
def trampoline(next: ByteString StateResult): StateResult = Trampoline(next) def trampoline(next: ByteString StateResult): StateResult = Trampoline(next)
def fail(summary: String): StateResult = fail(summary, "") def failMessageStart(summary: String): StateResult = failMessageStart(summary, "")
def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail) def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail)
def fail(status: StatusCode): StateResult = fail(status, status.defaultMessage) def failMessageStart(status: StatusCode): StateResult = failMessageStart(status, status.defaultMessage)
def fail(status: StatusCode, summary: String, detail: String = ""): StateResult = fail(status, ErrorInfo(summary, detail)) def failMessageStart(status: StatusCode, summary: String, detail: String = ""): StateResult = failMessageStart(status, ErrorInfo(summary, detail))
def fail(status: StatusCode, info: ErrorInfo): StateResult = { def failMessageStart(status: StatusCode, info: ErrorInfo): StateResult = {
emit(ParserOutput.ParseError(status, info)) emit(MessageStartError(status, info))
setCompletionHandling(CompletionOk)
terminate()
}
def failEntityStream(summary: String): StateResult = failEntityStream(summary, "")
def failEntityStream(summary: String, detail: String): StateResult = failEntityStream(ErrorInfo(summary, detail))
def failEntityStream(info: ErrorInfo): StateResult = {
emit(EntityStreamError(info))
setCompletionHandling(CompletionOk)
terminate() terminate()
} }
@ -250,14 +294,23 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
contentLength: Int)(entityParts: Any): UniversalEntity = contentLength: Int)(entityParts: Any): UniversalEntity =
HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength)) HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength))
def defaultEntity(cth: Option[`Content-Type`], contentLength: Long)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = { def defaultEntity(cth: Option[`Content-Type`],
val data = entityParts.collect { case ParserOutput.EntityPart(bytes) bytes } contentLength: Long,
HttpEntity.Default(contentType(cth), contentLength, data) transformData: Source[ByteString] Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = {
val data = entityParts.collect {
case EntityPart(bytes) bytes
case EntityStreamError(info) throw new StreamException(info)
}
HttpEntity.Default(contentType(cth), contentLength, transformData(data))
} }
def chunkedEntity(cth: Option[`Content-Type`])(entityChunks: Source[_ <: ParserOutput]): RequestEntity with ResponseEntity = { def chunkedEntity(cth: Option[`Content-Type`],
val chunks = entityChunks.collect { case ParserOutput.EntityChunk(chunk) chunk } transformChunks: Source[HttpEntity.ChunkStreamPart] Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = {
HttpEntity.Chunked(contentType(cth), chunks) val chunks = entityChunks.collect {
case EntityChunk(chunk) chunk
case EntityStreamError(info) throw new StreamException(info)
}
HttpEntity.Chunked(contentType(cth), transformChunks(chunks))
} }
def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] =
@ -265,4 +318,16 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
case Some(x) x :: headers case Some(x) x :: headers
case None headers case None headers
} }
def setCompletionHandling(completionHandling: CompletionHandling): Unit =
this.completionHandling = completionHandling
} }
private[http] object HttpMessageParser {
type CompletionHandling = () Option[ParserOutput]
val CompletionOk: CompletionHandling = () None
val CompletionIsMessageStartError: CompletionHandling =
() Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start")))
val CompletionIsEntityStreamError: CompletionHandling =
() Some(ParserOutput.EntityStreamError(ErrorInfo("Entity stream truncation")))
}

View file

@ -6,27 +6,34 @@ package akka.http.engine.parsing
import java.lang.{ StringBuilder JStringBuilder } import java.lang.{ StringBuilder JStringBuilder }
import scala.annotation.tailrec import scala.annotation.tailrec
import akka.http.model.parser.CharacterClasses import akka.actor.ActorRef
import akka.stream.stage.{ Context, PushPullStage }
import akka.stream.scaladsl.Source import akka.stream.scaladsl.Source
import akka.util.ByteString import akka.util.ByteString
import akka.http.engine.server.OneHundredContinue
import akka.http.model.parser.CharacterClasses
import akka.http.util.identityFunc
import akka.http.model._ import akka.http.model._
import headers._ import headers._
import StatusCodes._ import StatusCodes._
import ParserOutput._
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] class HttpRequestParser(_settings: ParserSettings, private[http] class HttpRequestParser(_settings: ParserSettings,
rawRequestUriHeader: Boolean)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings)) rawRequestUriHeader: Boolean,
extends HttpMessageParser[ParserOutput.RequestOutput](_settings, _headerParser) { _headerParser: HttpHeaderParser,
oneHundredContinueRef: () Option[ActorRef] = () None)
extends HttpMessageParser[RequestOutput](_settings, _headerParser) {
import settings._ import settings._
private[this] var method: HttpMethod = _ private[this] var method: HttpMethod = _
private[this] var uri: Uri = _ private[this] var uri: Uri = _
private[this] var uriBytes: Array[Byte] = _ private[this] var uriBytes: Array[Byte] = _
def copyWith(warnOnIllegalHeader: ErrorInfo Unit): HttpRequestParser = def createShallowCopy(oneHundredContinueRef: () Option[ActorRef]): HttpRequestParser =
new HttpRequestParser(settings, rawRequestUriHeader)(headerParser.copyWith(warnOnIllegalHeader)) new HttpRequestParser(settings, rawRequestUriHeader, headerParser.createShallowCopy(), oneHundredContinueRef)
def parseMessage(input: ByteString, offset: Int): StateResult = { def parseMessage(input: ByteString, offset: Int): StateResult = {
var cursor = parseMethod(input, offset) var cursor = parseMethod(input, offset)
@ -107,16 +114,32 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
// http://tools.ietf.org/html/rfc7230#section-3.3 // http://tools.ietf.org/html/rfc7230#section-3.3
def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int,
clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`],
hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult =
if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) {
def emitRequestStart(createEntity: Source[ParserOutput.RequestOutput] RequestEntity, def emitRequestStart(createEntity: Source[RequestOutput] RequestEntity,
headers: List[HttpHeader] = headers) = { headers: List[HttpHeader] = headers) = {
val allHeaders = val allHeaders =
if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers
else headers else headers
emit(ParserOutput.RequestStart(method, uri, protocol, allHeaders, createEntity, closeAfterResponseCompletion)) emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion))
} }
def expect100continueHandling[T]: Source[T] Source[T] =
if (expect100continue) {
_.transform("expect100continueTrigger", () new PushPullStage[T, T] {
private var oneHundredContinueSent = false
def onPush(elem: T, ctx: Context[T]) = ctx.push(elem)
def onPull(ctx: Context[T]) = {
if (!oneHundredContinueSent) {
val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable"))
ref ! OneHundredContinue
oneHundredContinueSent = true
}
ctx.pull()
}
})
} else identityFunc
teh match { teh match {
case None case None
val contentLength = clh match { val contentLength = clh match {
@ -124,17 +147,19 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
case None 0 case None 0
} }
if (contentLength > maxContentLength) if (contentLength > maxContentLength)
fail(RequestEntityTooLarge, failMessageStart(RequestEntityTooLarge,
s"Request Content-Length $contentLength exceeds the configured limit of $maxContentLength") s"Request Content-Length $contentLength exceeds the configured limit of $maxContentLength")
else if (contentLength == 0) { else if (contentLength == 0) {
emitRequestStart(emptyEntity(cth)) emitRequestStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk)
startNewMessage(input, bodyStart) startNewMessage(input, bodyStart)
} else if (contentLength <= input.size - bodyStart) { } else if (contentLength <= input.size - bodyStart) {
val cl = contentLength.toInt val cl = contentLength.toInt
emitRequestStart(strictEntity(cth, input, bodyStart, cl)) emitRequestStart(strictEntity(cth, input, bodyStart, cl))
setCompletionHandling(HttpMessageParser.CompletionOk)
startNewMessage(input, bodyStart + cl) startNewMessage(input, bodyStart + cl)
} else { } else {
emitRequestStart(defaultEntity(cth, contentLength)) emitRequestStart(defaultEntity(cth, contentLength, expect100continueHandling))
parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart)
} }
@ -142,11 +167,11 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te) val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te)
if (te.isChunked) { if (te.isChunked) {
if (clh.isEmpty) { if (clh.isEmpty) {
emitRequestStart(chunkedEntity(cth), completedHeaders) emitRequestStart(chunkedEntity(cth, expect100continueHandling), completedHeaders)
parseChunk(input, bodyStart, closeAfterResponseCompletion) parseChunk(input, bodyStart, closeAfterResponseCompletion)
} else fail("A chunked request must not contain a Content-Length header.") } else failMessageStart("A chunked request must not contain a Content-Length header.")
} else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, hostHeaderPresent, } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None,
closeAfterResponseCompletion) expect100continue, hostHeaderPresent, closeAfterResponseCompletion)
} }
} else fail("Request is missing required `Host` header") } else failMessageStart("Request is missing required `Host` header")
} }

View file

@ -11,20 +11,22 @@ import akka.util.ByteString
import akka.http.model._ import akka.http.model._
import headers._ import headers._
import HttpResponseParser.NoMethod import HttpResponseParser.NoMethod
import ParserOutput._
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] class HttpResponseParser(_settings: ParserSettings, private[http] class HttpResponseParser(_settings: ParserSettings,
dequeueRequestMethodForNextResponse: () HttpMethod = () NoMethod)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings)) _headerParser: HttpHeaderParser,
extends HttpMessageParser[ParserOutput.ResponseOutput](_settings, _headerParser) { dequeueRequestMethodForNextResponse: () HttpMethod = () NoMethod)
extends HttpMessageParser[ResponseOutput](_settings, _headerParser) {
import settings._ import settings._
private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod
private[this] var statusCode: StatusCode = StatusCodes.OK private[this] var statusCode: StatusCode = StatusCodes.OK
def copyWith(warnOnIllegalHeader: ErrorInfo Unit, dequeueRequestMethodForNextResponse: () HttpMethod): HttpResponseParser = def createShallowCopy(dequeueRequestMethodForNextResponse: () HttpMethod): HttpResponseParser =
new HttpResponseParser(settings, dequeueRequestMethodForNextResponse)(headerParser.copyWith(warnOnIllegalHeader)) new HttpResponseParser(settings, headerParser.createShallowCopy(), dequeueRequestMethodForNextResponse)
override def startNewMessage(input: ByteString, offset: Int): StateResult = { override def startNewMessage(input: ByteString, offset: Int): StateResult = {
requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse() requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse()
@ -39,7 +41,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
cursor = parseReason(input, cursor)() cursor = parseReason(input, cursor)()
parseHeaderLines(input, cursor) parseHeaderLines(input, cursor)
} else badProtocol } else badProtocol
} else fail("Unexpected server response", input.drop(offset).utf8String) } else failMessageStart("Unexpected server response", input.drop(offset).utf8String)
def badProtocol = throw new ParsingException("The server-side HTTP version is not supported") def badProtocol = throw new ParsingException("The server-side HTTP version is not supported")
@ -72,12 +74,13 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
// http://tools.ietf.org/html/rfc7230#section-3.3 // http://tools.ietf.org/html/rfc7230#section-3.3
def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int,
clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`],
hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = {
def emitResponseStart(createEntity: Source[ParserOutput.ResponseOutput] ResponseEntity, def emitResponseStart(createEntity: Source[ResponseOutput] ResponseEntity,
headers: List[HttpHeader] = headers) = headers: List[HttpHeader] = headers) =
emit(ParserOutput.ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion)) emit(ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion))
def finishEmptyResponse() = { def finishEmptyResponse() = {
emitResponseStart(emptyEntity(cth)) emitResponseStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk)
startNewMessage(input, bodyStart) startNewMessage(input, bodyStart)
} }
@ -86,11 +89,12 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
case None clh match { case None clh match {
case Some(`Content-Length`(contentLength)) case Some(`Content-Length`(contentLength))
if (contentLength > maxContentLength) if (contentLength > maxContentLength)
fail(s"Response Content-Length $contentLength exceeds the configured limit of $maxContentLength") failMessageStart(s"Response Content-Length $contentLength exceeds the configured limit of $maxContentLength")
else if (contentLength == 0) finishEmptyResponse() else if (contentLength == 0) finishEmptyResponse()
else if (contentLength < input.size - bodyStart) { else if (contentLength < input.size - bodyStart) {
val cl = contentLength.toInt val cl = contentLength.toInt
emitResponseStart(strictEntity(cth, input, bodyStart, cl)) emitResponseStart(strictEntity(cth, input, bodyStart, cl))
setCompletionHandling(HttpMessageParser.CompletionOk)
startNewMessage(input, bodyStart + cl) startNewMessage(input, bodyStart + cl)
} else { } else {
emitResponseStart(defaultEntity(cth, contentLength)) emitResponseStart(defaultEntity(cth, contentLength))
@ -98,9 +102,10 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
} }
case None case None
emitResponseStart { entityParts emitResponseStart { entityParts
val data = entityParts.collect { case ParserOutput.EntityPart(bytes) bytes } val data = entityParts.collect { case EntityPart(bytes) bytes }
HttpEntity.CloseDelimited(contentType(cth), data) HttpEntity.CloseDelimited(contentType(cth), data)
} }
setCompletionHandling(HttpMessageParser.CompletionOk)
parseToCloseBody(input, bodyStart) parseToCloseBody(input, bodyStart)
} }
@ -110,9 +115,9 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
if (clh.isEmpty) { if (clh.isEmpty) {
emitResponseStart(chunkedEntity(cth), completedHeaders) emitResponseStart(chunkedEntity(cth), completedHeaders)
parseChunk(input, bodyStart, closeAfterResponseCompletion) parseChunk(input, bodyStart, closeAfterResponseCompletion)
} else fail("A chunked response must not contain a Content-Length header.") } else failMessageStart("A chunked response must not contain a Content-Length header.")
} else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, hostHeaderPresent, } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None,
closeAfterResponseCompletion) expect100continue, hostHeaderPresent, closeAfterResponseCompletion)
} }
} else finishEmptyResponse() } else finishEmptyResponse()
} }
@ -120,7 +125,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
// currently we do not check for `settings.maxContentLength` overflow // currently we do not check for `settings.maxContentLength` overflow
def parseToCloseBody(input: ByteString, bodyStart: Int): StateResult = { def parseToCloseBody(input: ByteString, bodyStart: Int): StateResult = {
if (input.length > bodyStart) if (input.length > bodyStart)
emit(ParserOutput.EntityPart(input drop bodyStart)) emit(EntityPart(input drop bodyStart))
continue(parseToCloseBody) continue(parseToCloseBody)
} }
} }

View file

@ -28,6 +28,7 @@ private[http] object ParserOutput {
protocol: HttpProtocol, protocol: HttpProtocol,
headers: List[HttpHeader], headers: List[HttpHeader],
createEntity: Source[RequestOutput] RequestEntity, createEntity: Source[RequestOutput] RequestEntity,
expect100ContinueResponsePending: Boolean,
closeAfterResponseCompletion: Boolean) extends MessageStart with RequestOutput closeAfterResponseCompletion: Boolean) extends MessageStart with RequestOutput
final case class ResponseStart( final case class ResponseStart(
@ -43,5 +44,7 @@ private[http] object ParserOutput {
final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput
final case class ParseError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput
final case class EntityStreamError(info: ErrorInfo) extends MessageOutput
} }

View file

@ -98,23 +98,21 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.`
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
} }
def renderContentLength(contentLength: Long): Unit = { def renderContentLength(contentLength: Long) =
if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf if (method.isEntityAccepted) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
r ~~ CrLf
}
def completeRequestRendering(): Source[ByteString] = def completeRequestRendering(): Source[ByteString] =
entity match { entity match {
case x if x.isKnownEmpty case x if x.isKnownEmpty
renderContentLength(0) renderContentLength(0) ~~ CrLf
Source.singleton(r.get) Source.singleton(r.get)
case HttpEntity.Strict(_, data) case HttpEntity.Strict(_, data)
renderContentLength(data.length) renderContentLength(data.length) ~~ CrLf
Source.singleton(r.get ++ data) Source.singleton(r.get ++ data)
case HttpEntity.Default(_, contentLength, data) case HttpEntity.Default(_, contentLength, data)
renderContentLength(contentLength) renderContentLength(contentLength) ~~ CrLf
renderByteStrings(r, renderByteStrings(r,
data.transform("checkContentLength", () new CheckContentLengthTransformer(contentLength))) data.transform("checkContentLength", () new CheckContentLengthTransformer(contentLength)))

View file

@ -136,6 +136,9 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
} }
def renderContentLengthHeader(contentLength: Long) =
if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
def byteStrings(entityBytes: Source[ByteString]): Source[ByteString] = def byteStrings(entityBytes: Source[ByteString]): Source[ByteString] =
renderByteStrings(r, entityBytes, skipEntity = noEntity) renderByteStrings(r, entityBytes, skipEntity = noEntity)
@ -144,20 +147,19 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
case HttpEntity.Strict(_, data) case HttpEntity.Strict(_, data)
renderHeaders(headers.toList) renderHeaders(headers.toList)
renderEntityContentType(r, entity) renderEntityContentType(r, entity)
r ~~ `Content-Length` ~~ data.length ~~ CrLf ~~ CrLf renderContentLengthHeader(data.length) ~~ CrLf
val entityBytes = if (noEntity) ByteString.empty else data val entityBytes = if (noEntity) ByteString.empty else data
Source.singleton(r.get ++ entityBytes) Source.singleton(r.get ++ entityBytes)
case HttpEntity.Default(_, contentLength, data) case HttpEntity.Default(_, contentLength, data)
renderHeaders(headers.toList) renderHeaders(headers.toList)
renderEntityContentType(r, entity) renderEntityContentType(r, entity)
r ~~ `Content-Length` ~~ contentLength ~~ CrLf ~~ CrLf renderContentLengthHeader(contentLength) ~~ CrLf
byteStrings(data.transform("checkContentLength", () new CheckContentLengthTransformer(contentLength))) byteStrings(data.transform("checkContentLength", () new CheckContentLengthTransformer(contentLength)))
case HttpEntity.CloseDelimited(_, data) case HttpEntity.CloseDelimited(_, data)
renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD) renderHeaders(headers.toList, alwaysClose = ctx.requestMethod != HttpMethods.HEAD)
renderEntityContentType(r, entity) renderEntityContentType(r, entity) ~~ CrLf
r ~~ CrLf
byteStrings(data) byteStrings(data)
case HttpEntity.Chunked(contentType, chunks) case HttpEntity.Chunked(contentType, chunks)
@ -165,8 +167,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
completeResponseRendering(HttpEntity.CloseDelimited(contentType, chunks.map(_.data))) completeResponseRendering(HttpEntity.CloseDelimited(contentType, chunks.map(_.data)))
else { else {
renderHeaders(headers.toList) renderHeaders(headers.toList)
renderEntityContentType(r, entity) renderEntityContentType(r, entity) ~~ CrLf
r ~~ CrLf
byteStrings(chunks.transform("renderChunks", () new ChunkTransformer)) byteStrings(chunks.transform("renderChunks", () new ChunkTransformer))
} }
} }

View file

@ -40,9 +40,9 @@ private object RenderSupport {
} }
} }
def renderEntityContentType(r: Rendering, entity: HttpEntity): Unit = def renderEntityContentType(r: Rendering, entity: HttpEntity) =
if (entity.contentType != ContentTypes.NoContentType) if (entity.contentType != ContentTypes.NoContentType) r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf
r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf else r
def renderByteStrings(r: ByteStringRendering, entityBytes: Source[ByteString], def renderByteStrings(r: ByteStringRendering, entityBytes: Source[ByteString],
skipEntity: Boolean = false): Source[ByteString] = { skipEntity: Boolean = false): Source[ByteString] = {

View file

@ -4,30 +4,39 @@
package akka.http.engine.server package akka.http.engine.server
import akka.actor.{ Props, ActorRef }
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.stream.stage.PushPullStage
import akka.util.ByteString
import akka.stream.io.StreamTcp import akka.stream.io.StreamTcp
import akka.stream.FlattenStrategy import akka.stream.FlattenStrategy
import akka.stream.FlowMaterializer import akka.stream.FlowMaterializer
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.stage._ import akka.http.engine.parsing.{ HttpHeaderParser, HttpRequestParser }
import akka.http.engine.parsing.HttpRequestParser
import akka.http.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory } import akka.http.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory }
import akka.http.model.{ StatusCode, ErrorInfo, HttpRequest, HttpResponse, HttpMethods } import akka.http.model._
import akka.http.engine.parsing.ParserOutput._ import akka.http.engine.parsing.ParserOutput._
import akka.http.Http import akka.http.Http
import akka.http.util._ import akka.http.util._
import akka.util.ByteString
import scala.util.control.NonFatal
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAdapter)(implicit fm: FlowMaterializer) private[http] class HttpServerPipeline(settings: ServerSettings,
log: LoggingAdapter)(implicit fm: FlowMaterializer)
extends (StreamTcp.IncomingTcpConnection Http.IncomingConnection) { extends (StreamTcp.IncomingTcpConnection Http.IncomingConnection) {
import settings.parserSettings
val rootParser = new HttpRequestParser(settings.parserSettings, settings.rawRequestUriHeader)() // the initial header parser we initially use for every connection,
val warnOnIllegalHeader: ErrorInfo Unit = errorInfo // will not be mutated, all "shared copy" parsers copy on first-write into the header cache
if (settings.parserSettings.illegalHeaderWarnings) val rootParser = new HttpRequestParser(
log.warning(errorInfo.withSummaryPrepended("Illegal request header").formatPretty) parserSettings,
settings.rawRequestUriHeader,
HttpHeaderParser(parserSettings) { errorInfo
if (parserSettings.illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal request header").formatPretty)
})
val responseRendererFactory = new HttpResponseRendererFactory(settings.serverHeader, settings.responseHeaderSizeHint, log) val responseRendererFactory = new HttpResponseRendererFactory(settings.serverHeader, settings.responseHeaderSizeHint, log)
@ -40,88 +49,125 @@ private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAda
val userIn = Sink.publisher[HttpRequest] val userIn = Sink.publisher[HttpRequest]
val userOut = Source.subscriber[HttpResponse] val userOut = Source.subscriber[HttpResponse]
val pipeline = FlowGraph { implicit b val oneHundredContinueSource = Source[OneHundredContinue.type](Props[OneHundredContinueSourceActor])
val bypassFanout = Broadcast[(RequestOutput, Source[RequestOutput])]("bypassFanout") @volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168
val bypassFanin = Merge[Any]("merge")
val rootParsePipeline = val pipeline = FlowGraph { implicit b
Flow[ByteString] val bypassFanout = Broadcast[RequestOutput]("bypassFanout")
.transform("rootParser", () rootParser.copyWith(warnOnIllegalHeader)) val bypassMerge = new BypassMerge
val requestParsing = Flow[ByteString].transform("rootParser", ()
// each connection uses a single (private) request parser instance for all its requests
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy(() oneHundredContinueRef))
val requestPreparation =
Flow[RequestOutput]
.splitWhen(x x.isInstanceOf[MessageStart] || x == MessageEnd) .splitWhen(x x.isInstanceOf[MessageStart] || x == MessageEnd)
.headAndTail .headAndTail
.collect {
case (RequestStart(method, uri, protocol, headers, createEntity, _, _), entityParts)
val effectiveUri = HttpRequest.effectiveUri(uri, headers, securedConnection = false, settings.defaultHostHeader)
val effectiveMethod = if (method == HttpMethods.HEAD && settings.transparentHeadRequests) HttpMethods.GET else method
HttpRequest(effectiveMethod, effectiveUri, headers, createEntity(entityParts), protocol)
}
val rendererPipeline = val rendererPipeline =
Flow[Any] Flow[ResponseRenderingContext]
.transform("applyApplicationBypass", () applyApplicationBypass) .transform("recover", () new ErrorsTo500ResponseRecovery(log)) // FIXME: simplify after #16394 is closed
.transform("renderer", () responseRendererFactory.newRenderer) .transform("renderer", () responseRendererFactory.newRenderer)
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
.transform("errorLogger", () errorLogger(log, "Outgoing response stream error")) .transform("errorLogger", () errorLogger(log, "Outgoing response stream error"))
val requestTweaking = Flow[(RequestOutput, Source[RequestOutput])].collect {
case (RequestStart(method, uri, protocol, headers, createEntity, _), entityParts)
val effectiveUri = HttpRequest.effectiveUri(uri, headers, securedConnection = false, settings.defaultHostHeader)
val effectiveMethod = if (method == HttpMethods.HEAD && settings.transparentHeadRequests) HttpMethods.GET else method
HttpRequest(effectiveMethod, effectiveUri, headers, createEntity(entityParts), protocol)
}
val bypass =
Flow[(RequestOutput, Source[RequestOutput])]
.collect[MessageStart with RequestOutput] { case (x: MessageStart, _) x }
//FIXME: the graph is unnecessary after fixing #15957 //FIXME: the graph is unnecessary after fixing #15957
networkIn ~> rootParsePipeline ~> bypassFanout ~> requestTweaking ~> userIn networkIn ~> requestParsing ~> bypassFanout ~> requestPreparation ~> userIn
bypassFanout ~> bypass ~> bypassFanin bypassFanout ~> bypassMerge.bypassInput
userOut ~> bypassFanin ~> rendererPipeline ~> networkOut userOut ~> bypassMerge.applicationInput ~> rendererPipeline ~> networkOut
oneHundredContinueSource ~> bypassMerge.oneHundredContinueInput
}.run() }.run()
oneHundredContinueRef = Some(pipeline.get(oneHundredContinueSource))
Http.IncomingConnection(tcpConn.remoteAddress, pipeline.get(userIn), pipeline.get(userOut)) Http.IncomingConnection(tcpConn.remoteAddress, pipeline.get(userIn), pipeline.get(userOut))
} }
/** class BypassMerge extends FlexiMerge[ResponseRenderingContext]("BypassMerge") {
* Combines the HttpResponse coming in from the application with the ParserOutput.RequestStart import FlexiMerge._
* produced by the request parser into a ResponseRenderingContext. val bypassInput = createInputPort[RequestOutput]()
* If the parser produced a ParserOutput.ParseError the error response is immediately dispatched to downstream. val oneHundredContinueInput = createInputPort[OneHundredContinue.type]()
*/ val applicationInput = createInputPort[HttpResponse]()
def applyApplicationBypass =
new PushStage[Any, ResponseRenderingContext] {
var applicationResponse: HttpResponse = _
var requestStart: RequestStart = _
override def onPush(elem: Any, ctx: Context[ResponseRenderingContext]): Directive = elem match { def createMergeLogic() = new MergeLogic[ResponseRenderingContext] {
case response: HttpResponse override def inputHandles(inputCount: Int) = {
requestStart match { require(inputCount == 3, s"BypassMerge must have 3 connected inputs, was $inputCount")
case null Vector(bypassInput, oneHundredContinueInput, applicationInput)
applicationResponse = response
ctx.pull()
case x: RequestStart
requestStart = null
ctx.push(dispatch(x, response))
}
case requestStart: RequestStart
applicationResponse match {
case null
this.requestStart = requestStart
ctx.pull()
case response
applicationResponse = null
ctx.push(dispatch(requestStart, response))
}
case ParseError(status, info)
ctx.push(errorResponse(status, info))
} }
def dispatch(requestStart: RequestStart, response: HttpResponse): ResponseRenderingContext = { override val initialState = State[Any](Read(bypassInput)) {
import requestStart._ case (ctx, _, requestStart: RequestStart) waitingForApplicationResponse(requestStart)
ResponseRenderingContext(response, method, protocol, closeAfterResponseCompletion) case (ctx, _, MessageStartError(status, info)) finishWithError(ctx, "request", status, info)
case _ SameState // drop other parser output
} }
def errorResponse(status: StatusCode, info: ErrorInfo): ResponseRenderingContext = { def waitingForApplicationResponse(requestStart: RequestStart): State[Any] =
log.warning("Illegal request, responding with status '{}': {}", status, info.formatPretty) State[Any](ReadAny(oneHundredContinueInput, applicationInput)) {
case (ctx, _, response: HttpResponse)
// see the comment on [[OneHundredContinue]] for an explanation of the closing logic here (and more)
val close = requestStart.closeAfterResponseCompletion || requestStart.expect100ContinueResponsePending
ctx.emit(ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close))
if (close) finish(ctx) else initialState
case (ctx, _, OneHundredContinue)
assert(requestStart.expect100ContinueResponsePending)
ctx.emit(ResponseRenderingContext(HttpResponse(StatusCodes.Continue)))
waitingForApplicationResponse(requestStart.copy(expect100ContinueResponsePending = false))
}
override def initialCompletionHandling = CompletionHandling(
onComplete = (ctx, _) { ctx.complete(); SameState },
onError = {
case (ctx, _, error: Http.StreamException)
// the application has forwarded a request entity stream error to the response stream
finishWithError(ctx, "request", StatusCodes.BadRequest, error.info)
case (ctx, _, error)
ctx.error(error)
SameState
})
def finishWithError(ctx: MergeLogicContext, target: String, status: StatusCode, info: ErrorInfo): State[Any] = {
log.warning("Illegal {}, responding with status '{}': {}", target, status, info.formatPretty)
val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary
ResponseRenderingContext(HttpResponse(status, entity = msg), closeAfterResponseCompletion = true) ctx.emit(ResponseRenderingContext(HttpResponse(status, entity = msg), closeAfterResponseCompletion = true))
finish(ctx)
}
def finish(ctx: MergeLogicContext): State[Any] = {
ctx.complete() // shouldn't this return a `State` rather than `Unit`?
SameState // it seems weird to stay in the same state after completion
} }
} }
}
} }
private[server] class ErrorsTo500ResponseRecovery(log: LoggingAdapter)
extends PushPullStage[ResponseRenderingContext, ResponseRenderingContext] {
import akka.stream.stage.Context
private[this] var errorResponse: ResponseRenderingContext = _
override def onPush(elem: ResponseRenderingContext, ctx: Context[ResponseRenderingContext]) = ctx.push(elem)
override def onPull(ctx: Context[ResponseRenderingContext]) =
if (ctx.isFinishing) ctx.pushAndFinish(errorResponse)
else ctx.pull()
override def onUpstreamFailure(error: Throwable, ctx: Context[ResponseRenderingContext]) =
error match {
case NonFatal(e)
log.error(e, "Internal server error, sending 500 response")
errorResponse = ResponseRenderingContext(HttpResponse(StatusCodes.InternalServerError),
closeAfterResponseCompletion = true)
ctx.absorbTermination()
case _ ctx.fail(error)
}
}

View file

@ -0,0 +1,66 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.server
import scala.annotation.tailrec
import akka.stream.actor.{ ActorPublisherMessage, ActorPublisher }
/**
* The `Expect: 100-continue` header has a special status in HTTP.
* It allows the client to send an `Expect: 100-continue` header with the request and then pause request sending
* (i.e. hold back sending the request entity). The server reads the request headers, determines whether it wants to
* accept the request and responds with
*
* - `417 Expectation Failed`, if it doesn't support the `100-continue` expectation
* (or if the `Expect` header contains other, unsupported expectations).
* - a `100 Continue` response,
* if it is ready to accept the request entity and the client should go ahead with sending it
* - a final response (like a 4xx to signal some client-side error
* (e.g. if the request entity length is beyond the configured limit) or a 3xx redirect)
*
* Only if the client receives a `100 Continue` response from the server is it allowed to continue sending the request
* entity. In this case it will receive another response after having completed request sending.
* So this special feature breaks the normal "one request - one response" logic of HTTP!
* It therefore requires special handling in all HTTP stacks (client- and server-side).
*
* For us this means:
*
* - on the server-side:
* After having read a `Expect: 100-continue` header with the request we package up an `HttpRequest` instance and send
* it through to the application. Only when (and if) the application then requests data from the entity stream do we
* send out a `100 Continue` response and continue reading the request entity.
* The application can therefore determine itself whether it wants the client to send the request entity
* by deciding whether to look at the request entity data stream or not.
* If the application sends a response *without* having looked at the request entity the client receives this
* response *instead of* the `100 Continue` response and the server closes the connection afterwards.
*
* - on the client-side:
* If the user adds a `Expect: 100-continue` header to the request we need to hold back sending the entity until
* we've received a `100 Continue` response.
*/
private[engine] case object OneHundredContinue
private[engine] class OneHundredContinueSourceActor extends ActorPublisher[OneHundredContinue.type] {
private var triggered = 0
def receive = {
case OneHundredContinue
triggered += 1
tryDispatch()
case ActorPublisherMessage.Request(_)
tryDispatch()
case ActorPublisherMessage.Cancel
context.stop(self)
}
@tailrec private def tryDispatch(): Unit =
if (triggered > 0 && totalDemand > 0) {
onNext(OneHundredContinue)
triggered -= 1
tryDispatch()
}
}

View file

@ -5,6 +5,8 @@
package akka.http.model package akka.http.model
import java.lang.{ Iterable JIterable } import java.lang.{ Iterable JIterable }
import akka.parboiled2.CharUtils
import scala.concurrent.duration.FiniteDuration import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ Future, ExecutionContext } import scala.concurrent.{ Future, ExecutionContext }
import scala.collection.immutable import scala.collection.immutable
@ -129,9 +131,9 @@ final case class HttpRequest(method: HttpMethod = HttpMethods.GET,
headers: immutable.Seq[HttpHeader] = Nil, headers: immutable.Seq[HttpHeader] = Nil,
entity: RequestEntity = HttpEntity.Empty, entity: RequestEntity = HttpEntity.Empty,
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpRequest with HttpMessage { protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpRequest with HttpMessage {
require(!uri.isEmpty, "An HttpRequest must not have an empty Uri") HttpRequest.verifyUri(uri)
require(entity.isKnownEmpty || method.isEntityAccepted, "Requests with this method must have an empty entity") require(entity.isKnownEmpty || method.isEntityAccepted, "Requests with this method must have an empty entity")
require(protocol == HttpProtocols.`HTTP/1.1` || !entity.isInstanceOf[HttpEntity.Chunked], require(protocol != HttpProtocols.`HTTP/1.0` || !entity.isInstanceOf[HttpEntity.Chunked],
"HTTP/1.0 requests must not have a chunked entity") "HTTP/1.0 requests must not have a chunked entity")
type Self = HttpRequest type Self = HttpRequest
@ -281,6 +283,22 @@ object HttpRequest {
else throw new IllegalUriException(s"'Host' header value of request to `$uri` doesn't match request target authority", else throw new IllegalUriException(s"'Host' header value of request to `$uri` doesn't match request target authority",
s"Host header: $hostHeader\nrequest target authority: ${uri.authority}") s"Host header: $hostHeader\nrequest target authority: ${uri.authority}")
} }
/**
* Verifies that the given [[Uri]] is non-empty and has either scheme `http`, `https` or no scheme at all.
* If any of these conditions is not met the method throws an [[IllegalArgumentException]].
*/
def verifyUri(uri: Uri): Unit =
if (uri.isEmpty) throw new IllegalArgumentException("`uri` must not be empty")
else {
def c(i: Int) = CharUtils.toLowerCase(uri.scheme charAt i)
uri.scheme.length match {
case 0 // ok
case 4 if c(0) == 'h' && c(1) == 't' && c(2) == 't' && c(3) == 'p' // ok
case 5 if c(0) == 'h' && c(1) == 't' && c(2) == 't' && c(3) == 'p' && c(4) == 's' // ok
case _ throw new IllegalArgumentException("""`uri` must have scheme "http", "https" or no scheme""")
}
}
} }
/** /**
@ -290,6 +308,10 @@ final case class HttpResponse(status: StatusCode = StatusCodes.OK,
headers: immutable.Seq[HttpHeader] = Nil, headers: immutable.Seq[HttpHeader] = Nil,
entity: ResponseEntity = HttpEntity.Empty, entity: ResponseEntity = HttpEntity.Empty,
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpResponse with HttpMessage { protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) extends japi.HttpResponse with HttpMessage {
require(entity.isKnownEmpty || status.allowsEntity, "Responses with this status code must have an empty entity")
require(protocol == HttpProtocols.`HTTP/1.1` || !entity.isInstanceOf[HttpEntity.Chunked],
"HTTP/1.0 responses must not have a chunked entity")
type Self = HttpResponse type Self = HttpResponse
def self = this def self = this

View file

@ -8,8 +8,6 @@ package headers
import java.lang.Iterable import java.lang.Iterable
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util import java.util
import akka.http.model.japi
import akka.http.model.japi.JavaMapping
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.collection.immutable import scala.collection.immutable
import akka.http.util._ import akka.http.util._
@ -54,6 +52,7 @@ final case class Connection(tokens: immutable.Seq[String]) extends ModeledHeader
def renderValue[R <: Rendering](r: R): r.type = r ~~ tokens def renderValue[R <: Rendering](r: R): r.type = r ~~ tokens
def hasClose = has("close") def hasClose = has("close")
def hasKeepAlive = has("keep-alive") def hasKeepAlive = has("keep-alive")
def append(tokens: immutable.Seq[String]) = Connection(this.tokens ++ tokens)
@tailrec private def has(item: String, ix: Int = 0): Boolean = @tailrec private def has(item: String, ix: Int = 0): Boolean =
if (ix < tokens.length) if (ix < tokens.length)
if (tokens(ix) equalsIgnoreCase item) true if (tokens(ix) equalsIgnoreCase item) true
@ -563,6 +562,7 @@ final case class `Transfer-Encoding`(encodings: immutable.Seq[TransferEncoding])
case remaining Some(`Transfer-Encoding`(remaining)) case remaining Some(`Transfer-Encoding`(remaining))
} }
} else Some(this) } else Some(this)
def append(encodings: immutable.Seq[TransferEncoding]) = `Transfer-Encoding`(this.encodings ++ encodings)
def renderValue[R <: Rendering](r: R): r.type = r ~~ encodings def renderValue[R <: Rendering](r: R): r.type = r ~~ encodings
protected def companion = `Transfer-Encoding` protected def companion = `Transfer-Encoding`

View file

@ -22,7 +22,7 @@ import HttpMethods._
import HttpProtocols._ import HttpProtocols._
import StatusCodes._ import StatusCodes._
import HttpEntity._ import HttpEntity._
import ParserOutput.ParseError import ParserOutput._
import FastFuture._ import FastFuture._
class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
@ -48,13 +48,14 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
"with no headers and no body but remaining content" in new Test { "with no headers and no body but remaining content" in new Test {
"""GET / HTTP/1.0 Seq("""GET / HTTP/1.0
| |
|POST /foo HTTP/1.0 |POST /foo HTTP/1.0
| |
|TRA""" /* beginning of TRACE request */ should parseTo( |TRA""") /* beginning of TRACE request */ should generalMultiParseTo(
HttpRequest(GET, "/", protocol = `HTTP/1.0`), Right(HttpRequest(GET, "/", protocol = `HTTP/1.0`)),
HttpRequest(POST, "/foo", protocol = `HTTP/1.0`)) Right(HttpRequest(POST, "/foo", protocol = `HTTP/1.0`)),
Left(MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start"))))
closeAfterResponseCompletion shouldEqual Seq(true, true) closeAfterResponseCompletion shouldEqual Seq(true, true)
} }
@ -168,7 +169,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
"request start" in new Test { "request start" in new Test {
Seq(start, "rest") should generalMultiParseTo( Seq(start, "rest") should generalMultiParseTo(
Right(baseRequest.withEntity(HttpEntity.Chunked(`application/pdf`, source()))), Right(baseRequest.withEntity(HttpEntity.Chunked(`application/pdf`, source()))),
Left(ParseError(400: StatusCode, ErrorInfo("Illegal character 'r' in chunk start")))) Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -182,15 +183,18 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
"10;foo=", "10;foo=",
"""bar """bar
|0123456789ABCDEF |0123456789ABCDEF
|10 |A
|0123456789""", |0123456789""",
"""ABCDEF """
|dead""") should generalMultiParseTo( |0
|
|""") should generalMultiParseTo(
Right(baseRequest.withEntity(Chunked(`application/pdf`, source( Right(baseRequest.withEntity(Chunked(`application/pdf`, source(
Chunk(ByteString("abc")), Chunk(ByteString("abc")),
Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"),
Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), Chunk(ByteString("0123456789ABCDEF"), "foo=bar"),
Chunk(ByteString("0123456789ABCDEF"), "")))))) Chunk(ByteString("0123456789"), ""),
LastChunk)))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -203,14 +207,14 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"message end with extension, trailer and remaining content" in new Test { "message end with extension and trailer" in new Test {
Seq(start, Seq(start,
"""000;nice=true """000;nice=true
|Foo: pip |Foo: pip
| apo | apo
|Bar: xyz |Bar: xyz
| |
|GE""") should generalMultiParseTo( |""") should generalMultiParseTo(
Right(baseRequest.withEntity(Chunked(`application/pdf`, Right(baseRequest.withEntity(Chunked(`application/pdf`,
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
@ -238,13 +242,16 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|Content-Type: application/pdf |Content-Type: application/pdf
|Host: ping |Host: ping
| |
|0
|
|""" should parseTo(HttpRequest(PATCH, "/data", List(`Transfer-Encoding`(TransferEncodings.Extension("fancy")), |""" should parseTo(HttpRequest(PATCH, "/data", List(`Transfer-Encoding`(TransferEncodings.Extension("fancy")),
Host("ping")), HttpEntity.Chunked(`application/pdf`, source()))) Host("ping")), HttpEntity.Chunked(`application/pdf`, source(LastChunk))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"support `rawRequestUriHeader` setting" in new Test { "support `rawRequestUriHeader` setting" in new Test {
override protected def newParser: HttpRequestParser = new HttpRequestParser(parserSettings, rawRequestUriHeader = true)() override protected def newParser: HttpRequestParser =
new HttpRequestParser(parserSettings, rawRequestUriHeader = true, _headerParser = HttpHeaderParser(parserSettings)())
"""GET /f%6f%6fbar?q=b%61z HTTP/1.1 """GET /f%6f%6fbar?q=b%61z HTTP/1.1
|Host: ping |Host: ping
@ -275,19 +282,19 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
Seq(start, Seq(start,
"""15 ; """15 ;
|""") should generalMultiParseTo(Right(baseRequest), |""") should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("Illegal character ' ' in chunk start")))) Left(EntityStreamError(ErrorInfo("Illegal character ' ' in chunk start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"an illegal char in chunk size" in new Test { "an illegal char in chunk size" in new Test {
Seq(start, "bla") should generalMultiParseTo(Right(baseRequest), Seq(start, "bla") should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("Illegal character 'l' in chunk start")))) Left(EntityStreamError(ErrorInfo("Illegal character 'l' in chunk start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"too-long chunk extension" in new Test { "too-long chunk extension" in new Test {
Seq(start, "3;" + ("x" * 257)) should generalMultiParseTo(Right(baseRequest), Seq(start, "3;" + ("x" * 257)) should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters")))) Left(EntityStreamError(ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -295,7 +302,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
Seq(start, Seq(start,
"""1a2b3c4d5e """1a2b3c4d5e
|""") should generalMultiParseTo(Right(baseRequest), |""") should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes")))) Left(EntityStreamError(ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -303,7 +310,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
Seq(start, Seq(start,
"""3 """3
|abcde""") should generalMultiParseTo(Right(baseRequest), |abcde""") should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("Illegal chunk termination")))) Left(EntityStreamError(ErrorInfo("Illegal chunk termination"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -311,7 +318,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
Seq(start, Seq(start,
"""0 """0
|F@oo: pip""") should generalMultiParseTo(Right(baseRequest), |F@oo: pip""") should generalMultiParseTo(Right(baseRequest),
Left(ParseError(400: StatusCode, ErrorInfo("Illegal character '@' in header name")))) Left(EntityStreamError(ErrorInfo("Illegal character '@' in header name"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
} }
@ -333,7 +340,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
"two Content-Length headers" in new Test { "two Content-Length headers" in new Test {
"""GET / HTTP/1.1 """GET / HTTP/1.1
|Content-Length: 3 |Content-Length: 3
|Content-Length: 3 |Content-Length: 4
| |
|foo""" should parseToError(BadRequest, |foo""" should parseToError(BadRequest,
ErrorInfo("HTTP message must not contain more than one Content-Length header")) ErrorInfo("HTTP message must not contain more than one Content-Length header"))
@ -403,9 +410,8 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
override def toString = req.toString override def toString = req.toString
} }
def strictEqualify(x: Either[ParseError, HttpRequest]): Either[ParseError, StrictEqualHttpRequest] = { def strictEqualify[T](x: Either[T, HttpRequest]): Either[T, StrictEqualHttpRequest] =
x.right.map(new StrictEqualHttpRequest(_)) x.right.map(new StrictEqualHttpRequest(_))
}
def parseTo(expected: HttpRequest*): Matcher[String] = def parseTo(expected: HttpRequest*): Matcher[String] =
multiParseTo(expected: _*).compose(_ :: Nil) multiParseTo(expected: _*).compose(_ :: Nil)
@ -420,35 +426,35 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
generalRawMultiParseTo(parser, expected.map(Right(_)): _*) generalRawMultiParseTo(parser, expected.map(Right(_)): _*)
def parseToError(status: StatusCode, info: ErrorInfo): Matcher[String] = def parseToError(status: StatusCode, info: ErrorInfo): Matcher[String] =
generalMultiParseTo(Left(ParseError(status, info))).compose(_ :: Nil) generalMultiParseTo(Left(MessageStartError(status, info))).compose(_ :: Nil)
def generalMultiParseTo(expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = def generalMultiParseTo(expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] =
generalRawMultiParseTo(expected: _*).compose(_ map prep) generalRawMultiParseTo(expected: _*).compose(_ map prep)
def generalRawMultiParseTo(expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = def generalRawMultiParseTo(expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] =
generalRawMultiParseTo(newParser, expected: _*) generalRawMultiParseTo(newParser, expected: _*)
def generalRawMultiParseTo(parser: HttpRequestParser, def generalRawMultiParseTo(parser: HttpRequestParser,
expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = expected: Either[RequestOutput, HttpRequest]*): Matcher[Seq[String]] =
equal(expected.map(strictEqualify)) equal(expected.map(strictEqualify))
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser) .matcher[Seq[Either[RequestOutput, StrictEqualHttpRequest]]] compose multiParse(parser)
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] = def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] =
Source(input.toList) Source(input.toList)
.map(ByteString.apply) .map(ByteString.apply)
.transform("parser", () parser) .transform("parser", () parser)
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) .splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.headAndTail .headAndTail
.collect { .collect {
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) case (RequestStart(method, uri, protocol, headers, createEntity, _, close), entityParts)
closeAfterResponseCompletion :+= close closeAfterResponseCompletion :+= close
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
case (x: ParseError, _) Left(x) case (x @ (MessageStartError(_, _) | EntityStreamError(_)), _) Left(x)
} }
.map { x .map { x
Source { Source {
x match { x match {
case Right(request) compactEntity(request.entity).fast.map(x Right(request.withEntity(x))) case Right(request) compactEntity(request.entity).fast.map(x Right(request.withEntity(x)))
case Left(error) Future.successful(Left(error)) case Left(error) FastFuture.successful(Left(error))
} }
} }
} }
@ -458,7 +464,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
.awaitResult(awaitAtMost) .awaitResult(awaitAtMost)
protected def parserSettings: ParserSettings = ParserSettings(system) protected def parserSettings: ParserSettings = ParserSettings(system)
protected def newParser = new HttpRequestParser(parserSettings, false)() protected def newParser = new HttpRequestParser(parserSettings, false, HttpHeaderParser(parserSettings)())
private def compactEntity(entity: RequestEntity): Future[RequestEntity] = private def compactEntity(entity: RequestEntity): Future[RequestEntity] =
entity match { entity match {

View file

@ -22,7 +22,7 @@ import HttpMethods._
import HttpProtocols._ import HttpProtocols._
import StatusCodes._ import StatusCodes._
import HttpEntity._ import HttpEntity._
import ParserOutput.ParseError import ParserOutput._
import FastFuture._ import FastFuture._
class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
@ -43,7 +43,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
"a 200 response to a HEAD request" in new Test { "a 200 response to a HEAD request" in new Test {
"""HTTP/1.1 200 OK """HTTP/1.1 200 OK
| |
|HTT""" should parseTo(HEAD, HttpResponse()) |""" should parseTo(HEAD, HttpResponse())
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -97,14 +97,15 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
"a response with 3 headers, a body and remaining content" in new Test { "a response with 3 headers, a body and remaining content" in new Test {
"""HTTP/1.1 500 Internal Server Error Seq("""HTTP/1.1 500 Internal Server Error
|User-Agent: curl/7.19.7 xyz |User-Agent: curl/7.19.7 xyz
|Connection:close |Connection:close
|Content-Length: 17 |Content-Length: 17
|Content-Type: text/plain; charset=UTF-8 |Content-Type: text/plain; charset=UTF-8
| |
|Shake your BOODY!HTTP/1.""" should parseTo(HttpResponse(InternalServerError, List(Connection("close"), |Sh""", "ake your BOODY!HTTP/1.") should generalMultiParseTo(
`User-Agent`("curl/7.19.7 xyz")), "Shake your BOODY!")) Right(HttpResponse(InternalServerError, List(Connection("close"), `User-Agent`("curl/7.19.7 xyz")),
"Shake your BOODY!")))
closeAfterResponseCompletion shouldEqual Seq(true) closeAfterResponseCompletion shouldEqual Seq(true)
} }
@ -133,7 +134,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
"response start" in new Test { "response start" in new Test {
Seq(start, "rest") should generalMultiParseTo( Seq(start, "rest") should generalMultiParseTo(
Right(baseResponse.withEntity(Chunked(`application/pdf`, source()))), Right(baseResponse.withEntity(Chunked(`application/pdf`, source()))),
Left("Illegal character 'r' in chunk start")) Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -150,12 +151,14 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|10 |10
|0123456789""", |0123456789""",
"""ABCDEF """ABCDEF
|dead""") should generalMultiParseTo( |0
|
|""") should generalMultiParseTo(
Right(baseResponse.withEntity(Chunked(`application/pdf`, source( Right(baseResponse.withEntity(Chunked(`application/pdf`, source(
Chunk(ByteString("abc")), Chunk(ByteString("abc")),
Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"),
Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), Chunk(ByteString("0123456789ABCDEF"), "foo=bar"),
Chunk(ByteString("0123456789ABCDEF"), "")))))) Chunk(ByteString("0123456789ABCDEF")), LastChunk)))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -177,33 +180,38 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
| |
|HT""") should generalMultiParseTo( |HT""") should generalMultiParseTo(
Right(baseResponse.withEntity(Chunked(`application/pdf`, Right(baseResponse.withEntity(Chunked(`application/pdf`,
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))),
Left(MessageStartError(400: StatusCode, ErrorInfo("Illegal HTTP message start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"response with additional transfer encodings" in new Test { "response with additional transfer encodings" in new Test {
"""HTTP/1.1 200 OK Seq("""HTTP/1.1 200 OK
|Transfer-Encoding: fancy, chunked |Transfer-Encoding: fancy, chunked
|Content-Type: application/pdf |Cont""", """ent-Type: application/pdf
| |
|""" should parseTo(HttpResponse(headers = List(`Transfer-Encoding`(TransferEncodings.Extension("fancy"))), |""") should generalMultiParseTo(
entity = HttpEntity.Chunked(`application/pdf`, source()))) Right(HttpResponse(headers = List(`Transfer-Encoding`(TransferEncodings.Extension("fancy"))),
entity = HttpEntity.Chunked(`application/pdf`, source()))),
Left(EntityStreamError(ErrorInfo("Entity stream truncation"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
} }
"reject a response with" - { "reject a response with" - {
"HTTP version 1.2" in new Test { "HTTP version 1.2" in new Test {
Seq("HTTP/1.2 200 OK\r\n") should generalMultiParseTo(Left("The server-side HTTP version is not supported")) Seq("HTTP/1.2 200 OK\r\n") should generalMultiParseTo(Left(MessageStartError(
400: StatusCode, ErrorInfo("The server-side HTTP version is not supported"))))
} }
"an illegal status code" in new Test { "an illegal status code" in new Test {
Seq("HTTP/1", ".1 2000 Something") should generalMultiParseTo(Left("Illegal response status code")) Seq("HTTP/1", ".1 2000 Something") should generalMultiParseTo(Left(MessageStartError(
400: StatusCode, ErrorInfo("Illegal response status code"))))
} }
"a too-long response status reason" in new Test { "a too-long response status reason" in new Test {
Seq("HTTP/1.1 204 12345678", "90123456789012\r\n") should generalMultiParseTo( Seq("HTTP/1.1 204 12345678", "90123456789012\r\n") should generalMultiParseTo(Left(
Left("Response reason phrase exceeds the configured limit of 21 characters")) MessageStartError(400: StatusCode, ErrorInfo("Response reason phrase exceeds the configured limit of 21 characters"))))
} }
} }
} }
@ -224,9 +232,8 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
override def toString = resp.toString override def toString = resp.toString
} }
def strictEqualify(x: Either[String, HttpResponse]): Either[String, StrictEqualHttpResponse] = { def strictEqualify[T](x: Either[T, HttpResponse]): Either[T, StrictEqualHttpResponse] =
x.right.map(new StrictEqualHttpResponse(_)) x.right.map(new StrictEqualHttpResponse(_))
}
def parseTo(expected: HttpResponse*): Matcher[String] = parseTo(GET, expected: _*) def parseTo(expected: HttpResponse*): Matcher[String] = parseTo(GET, expected: _*)
def parseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[String] = def parseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[String] =
@ -240,46 +247,44 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def rawMultiParseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[Seq[String]] = def rawMultiParseTo(requestMethod: HttpMethod, expected: HttpResponse*): Matcher[Seq[String]] =
generalRawMultiParseTo(requestMethod, expected.map(Right(_)): _*) generalRawMultiParseTo(requestMethod, expected.map(Right(_)): _*)
def parseToError(error: String): Matcher[String] = generalMultiParseTo(Left(error)).compose(_ :: Nil) def parseToError(error: ResponseOutput): Matcher[String] = generalMultiParseTo(Left(error)).compose(_ :: Nil)
def generalMultiParseTo(expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = def generalMultiParseTo(expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] =
generalRawMultiParseTo(expected: _*).compose(_ map prep) generalRawMultiParseTo(expected: _*).compose(_ map prep)
def generalRawMultiParseTo(expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = def generalRawMultiParseTo(expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] =
generalRawMultiParseTo(GET, expected: _*) generalRawMultiParseTo(GET, expected: _*)
def generalRawMultiParseTo(requestMethod: HttpMethod, expected: Either[String, HttpResponse]*): Matcher[Seq[String]] = def generalRawMultiParseTo(requestMethod: HttpMethod, expected: Either[ResponseOutput, HttpResponse]*): Matcher[Seq[String]] =
equal(expected.map(strictEqualify)) equal(expected.map(strictEqualify))
.matcher[Seq[Either[String, StrictEqualHttpResponse]]] compose { .matcher[Seq[Either[ResponseOutput, StrictEqualHttpResponse]]] compose { input: Seq[String]
input: Seq[String] val future =
val future = Source(input.toList)
Source(input.toList) .map(ByteString.apply)
.map(ByteString.apply) .transform("parser", () newParser(requestMethod))
.transform("parser", () newParser(requestMethod)) .splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) .headAndTail
.headAndTail .collect {
.collect { case (ResponseStart(statusCode, protocol, headers, createEntity, close), entityParts)
case (ParserOutput.ResponseStart(statusCode, protocol, headers, createEntity, close), entityParts) closeAfterResponseCompletion :+= close
closeAfterResponseCompletion :+= close Right(HttpResponse(statusCode, headers, createEntity(entityParts), protocol))
Right(HttpResponse(statusCode, headers, createEntity(entityParts), protocol)) case (x @ (MessageStartError(_, _) | EntityStreamError(_)), _) Left(x)
case (x: ParseError, _) Left(x) }.map { x
}.map { x Source {
Source { x match {
x match { case Right(response) compactEntity(response.entity).fast.map(x Right(response.withEntity(x)))
case Right(response) compactEntity(response.entity).fast.map(x Right(response.withEntity(x))) case Left(error) FastFuture.successful(Left(error))
case Left(error) FastFuture.successful(Left(error.info.formatPretty))
}
} }
} }
.flatten(FlattenStrategy.concat) }
.map(strictEqualify) .flatten(FlattenStrategy.concat)
.grouped(1000).runWith(Sink.head) .map(strictEqualify)
Await.result(future, 500.millis) .grouped(1000).runWith(Sink.head)
Await.result(future, 500.millis)
} }
def parserSettings: ParserSettings = ParserSettings(system) def parserSettings: ParserSettings = ParserSettings(system)
def newParser(requestMethod: HttpMethod = GET) = { def newParser(requestMethod: HttpMethod = GET) = {
val parser = new HttpResponseParser(parserSettings, val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)(), () requestMethod)
dequeueRequestMethodForNextResponse = () requestMethod)()
parser parser
} }

View file

@ -4,7 +4,6 @@
package akka.http.engine.rendering package akka.http.engine.rendering
import akka.http.model.HttpMethods._
import com.typesafe.config.{ Config, ConfigFactory } import com.typesafe.config.{ Config, ConfigFactory }
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.concurrent.Await import scala.concurrent.Await
@ -18,7 +17,6 @@ import akka.http.util._
import akka.util.ByteString import akka.util.ByteString
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.FlowMaterializer import akka.stream.FlowMaterializer
import akka.stream.impl.SynchronousIterablePublisher
import HttpEntity._ import HttpEntity._
class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll { class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
@ -51,7 +49,6 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
|Age: 0 |Age: 0
|Server: akka-http/1.0.0 |Server: akka-http/1.0.0
|Date: Thu, 25 Aug 2011 09:10:29 GMT |Date: Thu, 25 Aug 2011 09:10:29 GMT
|Content-Length: 0
| |
|""" |"""
} }

View file

@ -5,23 +5,26 @@
package akka.http.engine.server package akka.http.engine.server
import scala.concurrent.duration._ import scala.concurrent.duration._
import org.scalatest.{ Inside, BeforeAndAfterAll, Matchers }
import akka.event.NoLogging import akka.event.NoLogging
import akka.http.model.HttpEntity.{ Chunk, ChunkStreamPart, LastChunk } import akka.util.ByteString
import akka.http.model._
import akka.http.model.headers.{ ProductVersion, Server, Host }
import akka.http.util._
import akka.http.Http
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.FlowMaterializer import akka.stream.FlowMaterializer
import akka.stream.io.StreamTcp import akka.stream.io.StreamTcp
import akka.stream.testkit.{ AkkaSpec, StreamTestKit } import akka.stream.testkit.{ AkkaSpec, StreamTestKit }
import akka.util.ByteString import akka.http.Http
import org.scalatest._ import akka.http.model._
import akka.http.util._
import headers._
import HttpEntity._
import MediaTypes._
import HttpMethods._
class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside { class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside {
implicit val materializer = FlowMaterializer() implicit val materializer = FlowMaterializer()
"The server implementation should" should { "The server implementation" should {
"deliver an empty request as soon as all headers are received" in new TestSetup { "deliver an empty request as soon as all headers are received" in new TestSetup {
send("""GET / HTTP/1.1 send("""GET / HTTP/1.1
|Host: example.com |Host: example.com
@ -30,6 +33,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com")))
} }
"deliver a request as soon as all headers are received" in new TestSetup { "deliver a request as soon as all headers are received" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
@ -38,7 +42,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString] val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -53,18 +57,26 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
dataProbe.expectNoMsg(50.millis) dataProbe.expectNoMsg(50.millis)
} }
} }
"deliver an error as soon as a parsing error occurred" in new TestSetup {
pending "deliver an error response as soon as a parsing error occurred" in new TestSetup {
// POST should require Content-Length header send("""GET / HTTP/1.2
send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
requests.expectError() netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 505 HTTP Version Not Supported
|Server: akka-http/test
|Date: XXXX
|Connection: close
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 74
|
|The server does not support the HTTP protocol version used in the request.""".stripMarginWithNewline("\r\n")
} }
"report a invalid Chunked stream" in new TestSetup {
pending "report an invalid Chunked stream" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
|Transfer-Encoding: chunked |Transfer-Encoding: chunked
@ -74,7 +86,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -83,8 +95,23 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
dataProbe.expectNoMsg(50.millis) dataProbe.expectNoMsg(50.millis)
send("3ghi\r\n") // missing "\r\n" after the number of bytes send("3ghi\r\n") // missing "\r\n" after the number of bytes
dataProbe.expectError() val error = dataProbe.expectError()
requests.expectError() error.getMessage shouldEqual "Illegal character 'g' in chunk start"
requests.expectComplete()
netOutSub.request(1)
responsesSub.expectRequest()
responsesSub.sendError(error.asInstanceOf[Exception])
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 400 Bad Request
|Server: akka-http/test
|Date: XXXX
|Connection: close
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 36
|
|Illegal character 'g' in chunk start""".stripMarginWithNewline("\r\n")
} }
} }
@ -97,11 +124,12 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
expectRequest shouldEqual expectRequest shouldEqual
HttpRequest( HttpRequest(
method = HttpMethods.POST, method = POST,
uri = "http://example.com/strict", uri = "http://example.com/strict",
headers = List(Host("example.com")), headers = List(Host("example.com")),
entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl")))
} }
"deliver the request entity as it comes in for a Default entity" in new TestSetup { "deliver the request entity as it comes in for a Default entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
@ -110,7 +138,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|abcdef""".stripMarginWithNewline("\r\n")) |abcdef""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString] val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -122,6 +150,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
dataProbe.expectNoMsg(50.millis) dataProbe.expectNoMsg(50.millis)
} }
} }
"deliver the request entity as it comes in for a chunked entity" in new TestSetup { "deliver the request entity as it comes in for a chunked entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
@ -132,7 +161,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -154,7 +183,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
expectRequest shouldEqual expectRequest shouldEqual
HttpRequest( HttpRequest(
method = HttpMethods.POST, method = POST,
uri = "http://example.com/strict", uri = "http://example.com/strict",
headers = List(Host("example.com")), headers = List(Host("example.com")),
entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl")))
@ -167,11 +196,12 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
expectRequest shouldEqual expectRequest shouldEqual
HttpRequest( HttpRequest(
method = HttpMethods.POST, method = POST,
uri = "http://example.com/next-strict", uri = "http://example.com/next-strict",
headers = List(Host("example.com")), headers = List(Host("example.com")),
entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("mnopqrstuvwx"))) entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("mnopqrstuvwx")))
} }
"deliver the second message properly after a Default entity" in new TestSetup { "deliver the second message properly after a Default entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
@ -180,7 +210,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|abcdef""".stripMarginWithNewline("\r\n")) |abcdef""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString] val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -202,10 +232,11 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|abcde""".stripMarginWithNewline("\r\n")) |abcde""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _)
data shouldEqual (ByteString("abcde")) data shouldEqual ByteString("abcde")
} }
} }
"deliver the second message properly after a Chunked entity" in new TestSetup { "deliver the second message properly after a Chunked entity" in new TestSetup {
send("""POST /chunked HTTP/1.1 send("""POST /chunked HTTP/1.1
|Host: example.com |Host: example.com
@ -216,7 +247,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -239,8 +270,8 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|abcde""".stripMarginWithNewline("\r\n")) |abcde""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _)
data shouldEqual (ByteString("abcde")) data shouldEqual ByteString("abcde")
} }
} }
@ -252,7 +283,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|abcdef""".stripMarginWithNewline("\r\n")) |abcdef""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString] val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -264,6 +295,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
dataProbe.expectComplete() dataProbe.expectComplete()
} }
} }
"close the request entity stream when the entity is complete for a Chunked entity" in new TestSetup { "close the request entity stream when the entity is complete for a Chunked entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
@ -274,7 +306,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
@ -288,27 +320,26 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
} }
} }
"report a truncated entity stream on the entity data stream and the main stream for a Default entity" in pendingUntilFixed(new TestSetup { "report a truncated entity stream on the entity data stream and the main stream for a Default entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
|Content-Length: 12 |Content-Length: 12
| |
|abcdef""".stripMarginWithNewline("\r\n")) |abcdef""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString] val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
sub.request(10) sub.request(10)
dataProbe.expectNext(ByteString("abcdef")) dataProbe.expectNext(ByteString("abcdef"))
dataProbe.expectNoMsg(50.millis) dataProbe.expectNoMsg(50.millis)
closeNetworkInput() closeNetworkInput()
dataProbe.expectError() dataProbe.expectError().getMessage shouldEqual "Entity stream truncation"
} }
}) }
"report a truncated entity stream on the entity data stream and the main stream for a Chunked entity" in pendingUntilFixed(new TestSetup {
"report a truncated entity stream on the entity data stream and the main stream for a Chunked entity" in new TestSetup {
send("""POST / HTTP/1.1 send("""POST / HTTP/1.1
|Host: example.com |Host: example.com
|Transfer-Encoding: chunked |Transfer-Encoding: chunked
@ -316,20 +347,18 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|6 |6
|abcdef |abcdef
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run() data.to(Sink(dataProbe)).run()
val sub = dataProbe.expectSubscription() val sub = dataProbe.expectSubscription()
sub.request(10) sub.request(10)
dataProbe.expectNext(Chunk(ByteString("abcdef"))) dataProbe.expectNext(Chunk(ByteString("abcdef")))
dataProbe.expectNoMsg(50.millis) dataProbe.expectNoMsg(50.millis)
closeNetworkInput() closeNetworkInput()
dataProbe.expectError() dataProbe.expectError().getMessage shouldEqual "Entity stream truncation"
} }
}) }
"translate HEAD request to GET request when transparent-head-requests are enabled" in new TestSetup { "translate HEAD request to GET request when transparent-head-requests are enabled" in new TestSetup {
override def settings = ServerSettings(system).copy(transparentHeadRequests = true) override def settings = ServerSettings(system).copy(transparentHeadRequests = true)
@ -337,8 +366,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
expectRequest shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com")))
expectRequest shouldEqual HttpRequest(HttpMethods.GET, uri = "http://example.com/", headers = List(Host("example.com")))
} }
"keep HEAD request when transparent-head-requests are disabled" in new TestSetup { "keep HEAD request when transparent-head-requests are disabled" in new TestSetup {
@ -347,21 +375,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
expectRequest shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com")))
expectRequest shouldEqual HttpRequest(HttpMethods.HEAD, uri = "http://example.com/", headers = List(Host("example.com")))
} }
"not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Strict)" in new TestSetup { "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Strict)" in new TestSetup {
override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
send("""HEAD / HTTP/1.1 send("""HEAD / HTTP/1.1
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.GET, _, _, _, _) case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd")))) responsesSub.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd"))))
netOutSub.request(1) netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual wipeDate(netOut.expectNext().utf8String) shouldEqual
"""|HTTP/1.1 200 OK """|HTTP/1.1 200 OK
@ -375,23 +399,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
} }
"not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Default)" in new TestSetup { "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Default)" in new TestSetup {
override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
send("""HEAD / HTTP/1.1 send("""HEAD / HTTP/1.1
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
val data = StreamTestKit.PublisherProbe[ByteString] val data = StreamTestKit.PublisherProbe[ByteString]
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.GET, _, _, _, _) case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data))))
netOutSub.request(1) netOutSub.request(1)
val dataSub = data.expectSubscription() val dataSub = data.expectSubscription()
dataSub.expectCancellation() dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual wipeDate(netOut.expectNext().utf8String) shouldEqual
"""|HTTP/1.1 200 OK """|HTTP/1.1 200 OK
|Server: akka-http/test |Server: akka-http/test
@ -404,23 +422,17 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
} }
"not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with CloseDelimited)" in new TestSetup { "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with CloseDelimited)" in new TestSetup {
override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
send("""HEAD / HTTP/1.1 send("""HEAD / HTTP/1.1
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
val data = StreamTestKit.PublisherProbe[ByteString] val data = StreamTestKit.PublisherProbe[ByteString]
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.GET, _, _, _, _) case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1) netOutSub.request(1)
val dataSub = data.expectSubscription() val dataSub = data.expectSubscription()
dataSub.expectCancellation() dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual wipeDate(netOut.expectNext().utf8String) shouldEqual
"""|HTTP/1.1 200 OK """|HTTP/1.1 200 OK
|Server: akka-http/test |Server: akka-http/test
@ -429,29 +441,22 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
| |
|""".stripMarginWithNewline("\r\n") |""".stripMarginWithNewline("\r\n")
} }
// No close should happen here since this was a HEAD request // No close should happen here since this was a HEAD request
netOut.expectNoMsg(50.millis) netOut.expectNoMsg(50.millis)
} }
"not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Chunked)" in new TestSetup { "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Chunked)" in new TestSetup {
override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
send("""HEAD / HTTP/1.1 send("""HEAD / HTTP/1.1
|Host: example.com |Host: example.com
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
val data = StreamTestKit.PublisherProbe[ChunkStreamPart] val data = StreamTestKit.PublisherProbe[ChunkStreamPart]
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.GET, _, _, _, _) case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1) netOutSub.request(1)
val dataSub = data.expectSubscription() val dataSub = data.expectSubscription()
dataSub.expectCancellation() dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual wipeDate(netOut.expectNext().utf8String) shouldEqual
"""|HTTP/1.1 200 OK """|HTTP/1.1 200 OK
|Server: akka-http/test |Server: akka-http/test
@ -464,29 +469,146 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
} }
"respect Connection headers of HEAD requests if transparent-head-requests is enabled" in new TestSetup { "respect Connection headers of HEAD requests if transparent-head-requests is enabled" in new TestSetup {
override def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
send("""HEAD / HTTP/1.1 send("""HEAD / HTTP/1.1
|Host: example.com |Host: example.com
|Connection: close |Connection: close
| |
|""".stripMarginWithNewline("\r\n")) |""".stripMarginWithNewline("\r\n"))
val data = StreamTestKit.PublisherProbe[ByteString] val data = StreamTestKit.PublisherProbe[ByteString]
inside(expectRequest) { inside(expectRequest) {
case HttpRequest(HttpMethods.GET, _, _, _, _) case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1) netOutSub.request(1)
val dataSub = data.expectSubscription() val dataSub = data.expectSubscription()
dataSub.expectCancellation() dataSub.expectCancellation()
netOut.expectNext() netOut.expectNext()
} }
netOut.expectComplete() netOut.expectComplete()
} }
"produce a `100 Continue` response when requested by a `Default` entity" in new TestSetup {
send("""POST / HTTP/1.1
|Host: example.com
|Expect: 100-continue
|Content-Length: 16
|
|""".stripMarginWithNewline("\r\n"))
inside(expectRequest) {
case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ByteString]
data.to(Sink(dataProbe)).run()
val dataSub = dataProbe.expectSubscription()
netOutSub.request(2)
netOut.expectNoMsg(50.millis)
dataSub.request(1) // triggers `100 Continue` response
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 100 Continue
|Server: akka-http/test
|Date: XXXX
|
|""".stripMarginWithNewline("\r\n")
dataProbe.expectNoMsg(50.millis)
send("0123456789ABCDEF")
dataProbe.expectNext(ByteString("0123456789ABCDEF"))
dataProbe.expectComplete()
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
}
}
"produce a `100 Continue` response when requested by a `Chunked` entity" in new TestSetup {
send("""POST / HTTP/1.1
|Host: example.com
|Expect: 100-continue
|Transfer-Encoding: chunked
|
|""".stripMarginWithNewline("\r\n"))
inside(expectRequest) {
case HttpRequest(POST, _, _, Chunked(ContentType(`application/octet-stream`, None), data), _)
val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run()
val dataSub = dataProbe.expectSubscription()
netOutSub.request(2)
netOut.expectNoMsg(50.millis)
dataSub.request(2) // triggers `100 Continue` response
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 100 Continue
|Server: akka-http/test
|Date: XXXX
|
|""".stripMarginWithNewline("\r\n")
dataProbe.expectNoMsg(50.millis)
send("""10
|0123456789ABCDEF
|0
|
|""".stripMarginWithNewline("\r\n"))
dataProbe.expectNext(Chunk(ByteString("0123456789ABCDEF")))
dataProbe.expectNext(LastChunk)
dataProbe.expectComplete()
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
}
}
"render a closing response instead of `100 Continue` if request entity is not requested" in new TestSetup {
send("""POST / HTTP/1.1
|Host: example.com
|Expect: 100-continue
|Content-Length: 16
|
|""".stripMarginWithNewline("\r\n"))
inside(expectRequest) {
case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _)
netOutSub.request(1)
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Connection: close
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
}
}
"render a 500 response on response stream errors from the application" in new TestSetup {
send("""GET / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com")))
netOutSub.request(1)
responsesSub.expectRequest()
responsesSub.sendError(new RuntimeException("CRASH BOOM BANG"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
"""HTTP/1.1 500 Internal Server Error
|Server: akka-http/test
|Date: XXXX
|Connection: close
|Content-Length: 0
|
|""".stripMarginWithNewline("\r\n")
}
} }
class TestSetup { class TestSetup {
@ -494,7 +616,7 @@ class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterA
val netOut = StreamTestKit.SubscriberProbe[ByteString] val netOut = StreamTestKit.SubscriberProbe[ByteString]
val tcpConnection = StreamTcp.IncomingTcpConnection(null, netIn, netOut) val tcpConnection = StreamTcp.IncomingTcpConnection(null, netIn, netOut)
def settings = ServerSettings(system) def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
val pipeline = new HttpServerPipeline(settings, NoLogging) val pipeline = new HttpServerPipeline(settings, NoLogging)
val Http.IncomingConnection(_, requestsIn, responsesOut) = pipeline(tcpConnection) val Http.IncomingConnection(_, requestsIn, responsesOut) = pipeline(tcpConnection)