=htc #18121 enable UTF-8 decoding for HTTP header values

If we detect a potential UTF-8 byte sequence we attempt decoding, otherwise we leave the bytes as they are.
This commit is contained in:
Mathias 2015-10-06 17:28:15 +02:00
parent 13ee16c72b
commit 3d3fe70293
4 changed files with 84 additions and 19 deletions

View file

@ -4,6 +4,7 @@
package akka.http.impl.engine.parsing
import java.nio.{ CharBuffer, ByteBuffer }
import java.util.Arrays.copyOf
import java.lang.{ StringBuilder JStringBuilder }
import scala.annotation.tailrec
@ -98,7 +99,7 @@ private[engine] final class HttpHeaderParser private (
@tailrec
def parseHeaderLine(input: ByteString, lineStart: Int = 0)(cursor: Int = lineStart, nodeIx: Int = 0): Int = {
def startValueBranch(rootValueIx: Int, valueParser: HeaderValueParser) = {
val (header, endIx) = valueParser(input, cursor, onIllegalHeader)
val (header, endIx) = valueParser(this, input, cursor, onIllegalHeader)
if (valueParser.cachingEnabled)
try {
val valueIx = newValueIndex // compute early in order to trigger OutOfTrieSpaceExceptions before any change
@ -149,7 +150,7 @@ private[engine] final class HttpHeaderParser private (
parseHeaderLine(input, lineStart)(cursor, nodeIx)
} catch {
case OutOfTrieSpaceException // if we cannot insert we drop back to simply creating new header instances
val (headerValue, endIx) = scanHeaderValue(input, colonIx + 1, colonIx + maxHeaderValueLength + 3)()
val (headerValue, endIx) = scanHeaderValue(this, input, colonIx + 1, colonIx + maxHeaderValueLength + 3)()
resultHeader = RawHeader(headerName, headerValue.trim)
endIx
}
@ -158,7 +159,7 @@ private[engine] final class HttpHeaderParser private (
@tailrec
private def parseHeaderValue(input: ByteString, valueStart: Int, branch: ValueBranch)(cursor: Int = valueStart, nodeIx: Int = branch.branchRootNodeIx): Int = {
def parseAndInsertHeader() = {
val (header, endIx) = branch.parser(input, valueStart, onIllegalHeader)
val (header, endIx) = branch.parser(this, input, valueStart, onIllegalHeader)
if (branch.spaceLeft)
try {
insert(input, header)(cursor, endIx, nodeIx, colonIx = 0)
@ -375,6 +376,30 @@ private[engine] final class HttpHeaderParser private (
* Returns a string representation of the trie structure size.
*/
def formatSizes: String = s"$nodeCount nodes, ${branchDataCount / 3} branchData rows, $valueCount values"
// helpers for UTF-8 decoding,
// since they are only accessed when an UTF8 byte sequence is actually hit and UTF-8 sequences in header values are
// rare these fields can be lazy, the overhead of the lazy access should be overcompensated for by the saved
// allocations in the majority of cases
private lazy val byteBuffer = ByteBuffer.allocate(4)
private lazy val charBuffer = CharBuffer.allocate(2)
private lazy val decoder = UTF8.newDecoder()
// returns the decoded character as a simple 16-bit Char value or a 32-bit surrogate pair
// or -1 if the byteBuffer bytes are not a complete and legal UTF-8 byte sequence
private def decodeByteBuffer(): Int = {
byteBuffer.flip()
val coderResult = decoder.decode(byteBuffer, charBuffer, false)
charBuffer.flip()
val result =
if (coderResult.isUnderflow & charBuffer.hasRemaining) {
val c = charBuffer.get()
if (charBuffer.hasRemaining) (charBuffer.get() << 16) | c else c
} else -1
byteBuffer.clear()
charBuffer.clear()
result
}
}
/**
@ -446,17 +471,17 @@ private[http] object HttpHeaderParser {
def insertRemainingCharsAsNewNodes(parser: HttpHeaderParser, input: ByteString, value: AnyRef): Unit =
parser.insertRemainingCharsAsNewNodes(input, value)()
abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) {
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int)
private[parsing] abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) {
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int)
override def toString: String = s"HeaderValueParser[$headerName]"
def cachingEnabled = maxValueCount > 0
}
class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings)
private[parsing] class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings)
extends HeaderValueParser(headerName, maxValueCount) {
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
// TODO: optimize by running the header value parser directly on the input ByteString (rather than an extracted String)
val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)()
val (headerValue, endIx) = scanHeaderValue(hhp, input, valueStart, valueStart + maxHeaderValueLength + 2)()
val trimmedHeaderValue = headerValue.trim
val header = HeaderParser.parseFull(headerName, trimmedHeaderValue, settings) match {
case Right(h) h
@ -468,10 +493,10 @@ private[http] object HttpHeaderParser {
}
}
class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int)
private[parsing] class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int)
extends HeaderValueParser(headerName, maxValueCount) {
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)()
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
val (headerValue, endIx) = scanHeaderValue(hhp, input, valueStart, valueStart + maxHeaderValueLength + 2)()
RawHeader(headerName, headerValue.trim) -> endIx
}
}
@ -485,16 +510,49 @@ private[http] object HttpHeaderParser {
}
else fail(s"HTTP header name exceeds the configured limit of ${limit - start - 1} characters")
@tailrec private def scanHeaderValue(input: ByteString, start: Int, limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = {
def spaceAppended = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(' ')
@tailrec private def scanHeaderValue(hhp: HttpHeaderParser, input: ByteString, start: Int,
limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = {
def appended(c: Char) = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(c)
def appended2(c: Int) = if ((c >> 16) != 0) appended(c.toChar).append((c >> 16).toChar) else appended(c.toChar)
if (ix < limit)
byteChar(input, ix) match {
case '\t' scanHeaderValue(input, start, limit)(spaceAppended, ix + 1)
case '\t' scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 1)
case '\r' if byteChar(input, ix + 1) == '\n'
if (WSP(byteChar(input, ix + 2))) scanHeaderValue(input, start, limit)(spaceAppended, ix + 3)
if (WSP(byteChar(input, ix + 2))) scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 3)
else (if (sb != null) sb.toString else asciiString(input, start, ix), ix + 2)
case c if c >= ' ' scanHeaderValue(input, start, limit)(if (sb != null) sb.append(c) else sb, ix + 1)
case c fail(s"Illegal character '${escape(c)}' in header value")
case c
var nix = ix + 1
val nsb =
if (' ' <= c && c <= '\u007F') if (sb != null) sb.append(c) else null // legal 7-Bit ASCII
else if ((c & 0xE0) == 0xC0) { // 2-byte UTF-8 sequence?
hhp.byteBuffer.put(c.toByte)
hhp.byteBuffer.put(byteAt(input, ix + 1))
nix = ix + 2
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
case -1 if (sb != null) sb.append(c).append(byteChar(input, ix + 1)) else null
case cc appended2(cc)
}
} else if ((c & 0xF0) == 0xE0) { // 3-byte UTF-8 sequence?
hhp.byteBuffer.put(c.toByte)
hhp.byteBuffer.put(byteAt(input, ix + 1))
hhp.byteBuffer.put(byteAt(input, ix + 2))
nix = ix + 3
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
case -1 if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)) else null
case cc appended2(cc)
}
} else if ((c & 0xF8) == 0xF0) { // 4-byte UTF-8 sequence?
hhp.byteBuffer.put(c.toByte)
hhp.byteBuffer.put(byteAt(input, ix + 1))
hhp.byteBuffer.put(byteAt(input, ix + 2))
hhp.byteBuffer.put(byteAt(input, ix + 3))
nix = ix + 4
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
case -1 if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)).append(byteChar(input, ix + 3)) else null
case cc appended2(cc)
}
} else fail(s"Illegal character '${escape(c)}' in header value")
scanHeaderValue(hhp, input, start, limit)(nsb, nix)
}
else fail(s"HTTP header value exceeds the configured limit of ${limit - start - 2} characters")
}

View file

@ -19,7 +19,7 @@ private object SpecializedHeaderValueParsers {
def specializedHeaderValueParsers = Seq(ContentLengthParser)
object ContentLengthParser extends HeaderValueParser("Content-Length", maxValueCount = 1) {
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo Unit): (HttpHeader, Int) = {
@tailrec def recurse(ix: Int = valueStart, result: Long = 0): (HttpHeader, Int) = {
val c = byteChar(input, ix)
if (result < 0) fail("`Content-Length` header value must not exceed 63-bit integer range")

View file

@ -30,7 +30,7 @@ class ContentLengthHeaderParserSpec extends WordSpec with Matchers {
}
def parse(bigint: String): Long = {
val (`Content-Length`(length), _) = ContentLengthParser(ByteString(bigint + "\r\n").compact, 0, _ ())
val (`Content-Length`(length), _) = ContentLengthParser(null, ByteString(bigint + "\r\n").compact, 0, _ ())
length
}

View file

@ -150,6 +150,13 @@ class HttpHeaderParserSpec extends WordSpec with Matchers with BeforeAndAfterAll
parseAndCache("Fancy: foo\tbar\r\nx")() shouldEqual RawHeader("Fancy", "foo bar")
}
"parse and cache a header with UTF8 chars in the value" in new TestSetup() {
parseAndCache("2-UTF8-Bytes: árvíztűrő ütvefúrógép\r\nx")() shouldEqual RawHeader("2-UTF8-Bytes", "árvíztűrő ütvefúrógép")
parseAndCache("3-UTF8-Bytes: The € or the $?\r\nx")() shouldEqual RawHeader("3-UTF8-Bytes", "The € or the $?")
parseAndCache("4-UTF8-Bytes: Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!\r\nx")() shouldEqual
RawHeader("4-UTF8-Bytes", "Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!")
}
"produce an error message for lines with an illegal header name" in new TestSetup() {
the[ParsingException] thrownBy parseLine(" Connection: close\r\nx") should have message "Illegal character ' ' in header name"
the[ParsingException] thrownBy parseLine("Connection : close\r\nx") should have message "Illegal character ' ' in header name"