!hco prepare for multipart body parser implementation

Breakingness stems from
- renaming `EnhancedString::getAsciiBytes` to `EnhancedString::asciiBytes`
- removal of `HttpForm` (for now)
- refactoring of `MultipartFormData`
This commit is contained in:
Mathias 2014-07-19 00:06:53 +02:00
parent 52839bc11d
commit bdcccf925a
12 changed files with 106 additions and 96 deletions

View file

@ -0,0 +1,17 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model
/**
* Model for `application/x-www-form-urlencoded` form data.
*/
final case class FormData(fields: Uri.Query) {
type FieldType = (String, String)
}
object FormData {
val Empty = FormData(Uri.Query.Empty)
def apply(fields: Map[String, String]): FormData = this(Uri.Query(fields))
}

View file

@ -1,53 +0,0 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model
import scala.collection.immutable
import headers._
sealed trait HttpForm {
type FieldType
def fields: Seq[FieldType]
}
/**
* Model for `application/x-www-form-urlencoded` form data.
*/
final case class FormData(fields: Uri.Query) extends HttpForm {
type FieldType = (String, String)
}
object FormData {
val Empty = FormData(Uri.Query.Empty)
def apply(fields: Map[String, String]): FormData = this(Uri.Query(fields))
}
/**
* Model for `multipart/form-data` content as defined in RFC 2388.
* All parts must contain a Content-Disposition header with a type form-data
* and a name parameter that is unique
*/
final case class MultipartFormData(fields: immutable.Seq[BodyPart]) extends HttpForm {
type FieldType = BodyPart
def get(partName: String): Option[BodyPart] = fields.find(_.name.exists(_ == partName))
}
object MultipartFormData {
val Empty = MultipartFormData()
def apply(fields: BodyPart*): MultipartFormData = apply(immutable.Seq(fields: _*))
def apply(fields: Map[String, BodyPart]): MultipartFormData = apply {
fields.map {
case (key, value) value.copy(headers = `Content-Disposition`(ContentDispositionTypes.`form-data`, Map("name" -> key)) +: value.headers)
}(collection.breakOut): _*
}
}
final case class FormFile(name: Option[String], entity: HttpEntity.Default)
object FormFile {
def apply(name: String, entity: HttpEntity.Default): FormFile = apply(Some(name), entity)
}

View file

@ -41,6 +41,33 @@ object MultipartByteRanges {
def apply(parts: BodyPart*): MultipartByteRanges = apply(SynchronousPublisherFromIterable[BodyPart](parts.toList))
}
/**
* Model for `multipart/form-data` content as defined in RFC 2388.
* All parts must contain a Content-Disposition header with a type form-data
* and a name parameter that is unique.
*/
final case class MultipartFormData(parts: Producer[BodyPart]) extends MultipartParts {
// def get(partName: String): Option[BodyPart] = fields.find(_.name.exists(_ == partName))
}
object MultipartFormData {
val Empty = MultipartFormData()
def apply(parts: BodyPart*): MultipartFormData = apply(SynchronousProducerFromIterable[BodyPart](parts.toList))
def apply(fields: Map[String, BodyPart]): MultipartFormData = apply {
fields.map {
case (key, value) value.copy(headers = `Content-Disposition`(ContentDispositionTypes.`form-data`, Map("name" -> key)) +: value.headers)
}(collection.breakOut): _*
}
}
final case class FormFile(name: Option[String], entity: HttpEntity.Default)
object FormFile {
def apply(name: String, entity: HttpEntity.Default): FormFile = apply(Some(name), entity)
}
/**
* Model for one part of a multipart message.
*/

View file

@ -22,7 +22,7 @@ import ProtectedHeaderCreation.enable
sealed abstract class ModeledCompanion extends Renderable {
val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $
val lowercaseName = name.toLowerCase
private[this] val nameBytes = name.getAsciiBytes
private[this] val nameBytes = name.asciiBytes
def render[R <: Rendering](r: R): r.type = r ~~ nameBytes ~~ ':' ~~ ' '
}

View file

@ -58,7 +58,7 @@ import akka.http.model.parser.CharacterClasses._
* cannot hold more then 255 items, so this array has a fixed size of 255.
*/
private[parsing] final class HttpHeaderParser private (
val settings: ParserSettings,
val settings: HttpHeaderParser.Settings,
warnOnIllegalHeader: ErrorInfo Unit,
private[this] var nodes: Array[Char] = new Array(512), // initial size, can grow as needed
private[this] var nodeCount: Int = 0,
@ -141,7 +141,7 @@ private[parsing] final class HttpHeaderParser private (
val colonIx = scanHeaderNameAndReturnIndexOfColon(input, lineStart, lineStart + maxHeaderNameLength)(cursor)
val headerName = asciiString(input, lineStart, colonIx)
try {
val valueParser = new RawHeaderValueParser(headerName, maxHeaderValueLength, settings.headerValueCacheLimit(headerName))
val valueParser = new RawHeaderValueParser(headerName, maxHeaderValueLength, headerValueCacheLimit(headerName))
insert(input, valueParser)(cursor, colonIx + 1, nodeIx, colonIx)
parseHeaderLine(input, lineStart)(cursor, nodeIx)
} catch {
@ -373,9 +373,16 @@ private[parsing] final class HttpHeaderParser private (
/**
* INTERNAL API
*/
private object HttpHeaderParser {
private[http] object HttpHeaderParser {
import SpecializedHeaderValueParsers._
trait Settings {
def maxHeaderNameLength: Int
def maxHeaderValueLength: Int
def maxHeaderCount: Int
def headerValueCacheLimit(headerName: String): Int
}
object EmptyHeader extends HttpHeader {
def name = ""
def lowercaseName = ""
@ -397,10 +404,10 @@ private object HttpHeaderParser {
private val defaultIllegalHeaderWarning: ErrorInfo Unit = info throw new IllegalHeaderException(info)
def apply(settings: ParserSettings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) =
def apply(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) =
prime(unprimed(settings, warnOnIllegalHeader))
def unprimed(settings: ParserSettings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) =
def unprimed(settings: HttpHeaderParser.Settings, warnOnIllegalHeader: ErrorInfo Unit = defaultIllegalHeaderWarning) =
new HttpHeaderParser(settings, warnOnIllegalHeader)
def prime(parser: HttpHeaderParser): HttpHeaderParser = {

View file

@ -23,6 +23,7 @@ import akka.stream.scaladsl.Flow
private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOutput <: ParserOutput](val settings: ParserSettings,
val headerParser: HttpHeaderParser)
extends Transformer[ByteString, Output] {
import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
@ -100,10 +101,10 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
case h: `Transfer-Encoding`
parseHeaderLines(input, lineEnd, headers, headerCount + 1, ch, clh, cth, Some(h), hh)
case h if headerCount < settings.maxHeaderCount
case h if headerCount < maxHeaderCount
parseHeaderLines(input, lineEnd, h :: headers, headerCount + 1, ch, clh, cth, teh, hh || h.isInstanceOf[Host])
case _ fail(s"HTTP message contains more than the configured limit of ${settings.maxHeaderCount} headers")
case _ fail(s"HTTP message contains more than the configured limit of $maxHeaderCount headers")
}
}
@ -144,9 +145,9 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
emit(ParserOutput.EntityChunk(lastChunk))
if (isLastMessage) terminate()
else startNewMessage(input, lineEnd)
case header if headerCount < settings.maxHeaderCount
case header if headerCount < maxHeaderCount
parseTrailer(extension, lineEnd, header :: headers, headerCount + 1)
case _ fail(s"Chunk trailer contains more than the configured limit of ${settings.maxHeaderCount} headers")
case _ fail(s"Chunk trailer contains more than the configured limit of $maxHeaderCount headers")
}
}
@ -165,24 +166,24 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
} else parseTrailer(extension, cursor)
@tailrec def parseChunkExtensions(chunkSize: Int, cursor: Int)(startIx: Int = cursor): StateResult =
if (cursor - startIx <= settings.maxChunkExtLength) {
if (cursor - startIx <= maxChunkExtLength) {
def extension = asciiString(input, startIx, cursor)
byteChar(input, cursor) match {
case '\r' if byteChar(input, cursor + 1) == '\n' parseChunkBody(chunkSize, extension, cursor + 2)
case '\n' parseChunkBody(chunkSize, extension, cursor + 1)
case _ parseChunkExtensions(chunkSize, cursor + 1)(startIx)
}
} else fail(s"HTTP chunk extension length exceeds configured limit of ${settings.maxChunkExtLength} characters")
} else fail(s"HTTP chunk extension length exceeds configured limit of $maxChunkExtLength characters")
@tailrec def parseSize(cursor: Int, size: Long): StateResult =
if (size <= settings.maxChunkSize) {
if (size <= maxChunkSize) {
byteChar(input, cursor) match {
case c if CharacterClasses.HEXDIG(c) parseSize(cursor + 1, size * 16 + CharUtils.hexValue(c))
case ';' if cursor > offset parseChunkExtensions(size.toInt, cursor + 1)()
case '\r' if cursor > offset && byteChar(input, cursor + 1) == '\n' parseChunkBody(size.toInt, "", cursor + 2)
case c fail(s"Illegal character '${escape(c)}' in chunk start")
}
} else fail(s"HTTP chunk size exceeds the configured limit of ${settings.maxChunkSize} bytes")
} else fail(s"HTTP chunk size exceeds the configured limit of $maxChunkSize bytes")
try parseSize(offset, 0)
catch {

View file

@ -23,6 +23,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
rawRequestUriHeader: Boolean,
materializer: FlowMaterializer)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings))
extends HttpMessageParser[ParserOutput.RequestOutput](_settings, _headerParser) {
import settings._
private[this] var method: HttpMethod = _
private[this] var uri: Uri = _
@ -84,19 +85,19 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
def parseRequestTarget(input: ByteString, cursor: Int): Int = {
val uriStart = cursor
val uriEndLimit = cursor + settings.maxUriLength
val uriEndLimit = cursor + maxUriLength
@tailrec def findUriEnd(ix: Int = cursor): Int =
if (ix == input.length) throw NotEnoughDataException
else if (CharacterClasses.WSPCRLF(input(ix).toChar)) ix
else if (ix < uriEndLimit) findUriEnd(ix + 1)
else throw new ParsingException(RequestUriTooLong,
s"URI length exceeds the configured limit of ${settings.maxUriLength} characters")
s"URI length exceeds the configured limit of $maxUriLength characters")
val uriEnd = findUriEnd()
try {
uriBytes = input.iterator.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here?
uri = Uri.parseHttpRequestTarget(uriBytes, mode = settings.uriParsingMode)
uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode)
} catch {
case e: IllegalUriException throw new ParsingException(BadRequest, e.info)
}
@ -119,9 +120,9 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
case Some(`Content-Length`(len)) len
case None 0
}
if (contentLength > settings.maxContentLength)
if (contentLength > maxContentLength)
fail(RequestEntityTooLarge,
s"Request Content-Length $contentLength exceeds the configured limit of $settings.maxContentLength")
s"Request Content-Length $contentLength exceeds the configured limit of $maxContentLength")
else if (contentLength == 0) {
emitRequestStart(emptyEntity(cth))
startNewMessage(input, bodyStart)

View file

@ -21,6 +21,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
materializer: FlowMaterializer,
dequeueRequestMethodForNextResponse: () HttpMethod = () NoMethod)(_headerParser: HttpHeaderParser = HttpHeaderParser(_settings))
extends HttpMessageParser[ParserOutput.ResponseOutput](_settings, _headerParser) {
import settings._
private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod
private[this] var statusCode: StatusCode = StatusCodes.OK
@ -65,11 +66,11 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
}
@tailrec private def parseReason(input: ByteString, startIx: Int)(cursor: Int = startIx): Int =
if (cursor - startIx <= settings.maxResponseReasonLength)
if (cursor - startIx <= maxResponseReasonLength)
if (byteChar(input, cursor) == '\r' && byteChar(input, cursor + 1) == '\n') cursor + 2
else parseReason(input, startIx)(cursor + 1)
else throw new ParsingException("Response reason phrase exceeds the configured limit of " +
settings.maxResponseReasonLength + " characters")
maxResponseReasonLength + " characters")
// http://tools.ietf.org/html/rfc7230#section-3.3
def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int,
@ -86,8 +87,8 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
teh match {
case None clh match {
case Some(`Content-Length`(contentLength))
if (contentLength > settings.maxContentLength)
fail(s"Response Content-Length $contentLength exceeds the configured limit of ${settings.maxContentLength}")
if (contentLength > maxContentLength)
fail(s"Response Content-Length $contentLength exceeds the configured limit of $maxContentLength")
else if (contentLength == 0) finishEmptyResponse()
else if (contentLength < input.size - bodyStart) {
val cl = contentLength.toInt
@ -118,11 +119,9 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
// currently we do not check for `settings.maxContentLength` overflow
def parseToCloseBody(input: ByteString, bodyStart: Int): StateResult = {
val remainingInputBytes = input.length - bodyStart
if (remainingInputBytes > 0) {
if (input.length > bodyStart)
emit(ParserOutput.EntityPart(input drop bodyStart))
continue(parseToCloseBody)
} else continue(input, bodyStart)(parseToCloseBody)
continue(parseToCloseBody)
}
}

View file

@ -20,7 +20,7 @@ final case class ParserSettings(
maxChunkSize: Int,
uriParsingMode: Uri.ParsingMode,
illegalHeaderWarnings: Boolean,
headerValueCacheLimits: Map[String, Int]) {
headerValueCacheLimits: Map[String, Int]) extends HttpHeaderParser.Settings {
require(maxUriLength > 0, "max-uri-length must be > 0")
require(maxResponseReasonLength > 0, "max-response-reason-length must be > 0")
@ -33,7 +33,7 @@ final case class ParserSettings(
val defaultHeaderValueCacheLimit: Int = headerValueCacheLimits("default")
def headerValueCacheLimit(headerName: String) =
def headerValueCacheLimit(headerName: String): Int =
headerValueCacheLimits.getOrElse(headerName, defaultHeaderValueCacheLimit)
}

View file

@ -18,11 +18,11 @@ import akka.http.util._
* INTERNAL API
*/
private object RenderSupport {
val DefaultStatusLineBytes = "HTTP/1.1 200 OK\r\n".getAsciiBytes
val StatusLineStartBytes = "HTTP/1.1 ".getAsciiBytes
val ChunkedBytes = "chunked".getAsciiBytes
val KeepAliveBytes = "Keep-Alive".getAsciiBytes
val CloseBytes = "close".getAsciiBytes
val DefaultStatusLineBytes = "HTTP/1.1 200 OK\r\n".asciiBytes
val StatusLineStartBytes = "HTTP/1.1 ".asciiBytes
val ChunkedBytes = "chunked".asciiBytes
val KeepAliveBytes = "Keep-Alive".asciiBytes
val CloseBytes = "close".asciiBytes
def CrLf = Rendering.CrLf

View file

@ -66,13 +66,24 @@ private[http] class EnhancedString(val underlying: String) extends AnyVal {
/**
* Returns the ASCII encoded bytes of this string. Truncates characters to 8-bit byte value.
*/
def getAsciiBytes: Array[Byte] = {
@tailrec def bytes(array: Array[Byte] = new Array[Byte](underlying.length), ix: Int = 0): Array[Byte] =
def asciiBytes: Array[Byte] = {
val array = new Array[Byte](underlying.length)
getAsciiBytes(array, 0)
array
}
/**
* Copies the ASCII encoded bytes of this string into the given byte array starting at the `offset` index.
* Truncates characters to 8-bit byte value.
* If the array does not have enough space for the whole string only the portion that fits is copied.
*/
def getAsciiBytes(array: Array[Byte], offset: Int): Unit = {
@tailrec def rec(ix: Int): Unit =
if (ix < array.length) {
array(ix) = underlying.charAt(ix).asInstanceOf[Byte]
bytes(array, ix + 1)
} else array
bytes()
array(ix) = underlying.charAt(ix - offset).asInstanceOf[Byte]
rec(ix + 1)
}
rec(offset)
}
/**
@ -86,7 +97,7 @@ private[http] class EnhancedString(val underlying: String) extends AnyVal {
* @see [[http://rdist.root.org/2009/05/28/timing-attack-in-google-keyczar-library/]]
* @see [[http://emerose.com/timing-attacks-explained]]
*/
def secure_==(other: String): Boolean = getAsciiBytes secure_== other.getAsciiBytes
def secure_==(other: String): Boolean = asciiBytes secure_== other.asciiBytes
/**
* Determines whether the underlying String starts with the given character.

View file

@ -52,7 +52,7 @@ private[http] trait LazyValueBytesRenderable extends Renderable {
// that a synchronization overhead or even @volatile reads
private[this] var _valueBytes: Array[Byte] = _
private def valueBytes =
if (_valueBytes != null) _valueBytes else { _valueBytes = value.getAsciiBytes; _valueBytes }
if (_valueBytes != null) _valueBytes else { _valueBytes = value.asciiBytes; _valueBytes }
def value: String
def render[R <: Rendering](r: R): r.type = r ~~ valueBytes
@ -66,7 +66,7 @@ private[http] trait LazyValueBytesRenderable extends Renderable {
* Useful for common predefined singleton values.
*/
private[http] trait SingletonValueRenderable extends Product with Renderable {
private[this] val valueBytes = value.getAsciiBytes
private[this] val valueBytes = value.asciiBytes
def value = productPrefix
def render[R <: Rendering](r: R): r.type = r ~~ valueBytes
}