=htc #18121 enable UTF-8 decoding for HTTP header values
If we detect a potential UTF-8 byte sequence we attempt decoding, otherwise we leave the bytes as they are.
This commit is contained in:
parent
13ee16c72b
commit
3d3fe70293
4 changed files with 84 additions and 19 deletions
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
package akka.http.impl.engine.parsing
|
||||
|
||||
import java.nio.{ CharBuffer, ByteBuffer }
|
||||
import java.util.Arrays.copyOf
|
||||
import java.lang.{ StringBuilder ⇒ JStringBuilder }
|
||||
import scala.annotation.tailrec
|
||||
|
|
@ -98,7 +99,7 @@ private[engine] final class HttpHeaderParser private (
|
|||
@tailrec
|
||||
def parseHeaderLine(input: ByteString, lineStart: Int = 0)(cursor: Int = lineStart, nodeIx: Int = 0): Int = {
|
||||
def startValueBranch(rootValueIx: Int, valueParser: HeaderValueParser) = {
|
||||
val (header, endIx) = valueParser(input, cursor, onIllegalHeader)
|
||||
val (header, endIx) = valueParser(this, input, cursor, onIllegalHeader)
|
||||
if (valueParser.cachingEnabled)
|
||||
try {
|
||||
val valueIx = newValueIndex // compute early in order to trigger OutOfTrieSpaceExceptions before any change
|
||||
|
|
@ -149,7 +150,7 @@ private[engine] final class HttpHeaderParser private (
|
|||
parseHeaderLine(input, lineStart)(cursor, nodeIx)
|
||||
} catch {
|
||||
case OutOfTrieSpaceException ⇒ // if we cannot insert we drop back to simply creating new header instances
|
||||
val (headerValue, endIx) = scanHeaderValue(input, colonIx + 1, colonIx + maxHeaderValueLength + 3)()
|
||||
val (headerValue, endIx) = scanHeaderValue(this, input, colonIx + 1, colonIx + maxHeaderValueLength + 3)()
|
||||
resultHeader = RawHeader(headerName, headerValue.trim)
|
||||
endIx
|
||||
}
|
||||
|
|
@ -158,7 +159,7 @@ private[engine] final class HttpHeaderParser private (
|
|||
@tailrec
|
||||
private def parseHeaderValue(input: ByteString, valueStart: Int, branch: ValueBranch)(cursor: Int = valueStart, nodeIx: Int = branch.branchRootNodeIx): Int = {
|
||||
def parseAndInsertHeader() = {
|
||||
val (header, endIx) = branch.parser(input, valueStart, onIllegalHeader)
|
||||
val (header, endIx) = branch.parser(this, input, valueStart, onIllegalHeader)
|
||||
if (branch.spaceLeft)
|
||||
try {
|
||||
insert(input, header)(cursor, endIx, nodeIx, colonIx = 0)
|
||||
|
|
@ -375,6 +376,30 @@ private[engine] final class HttpHeaderParser private (
|
|||
* Returns a string representation of the trie structure size.
|
||||
*/
|
||||
def formatSizes: String = s"$nodeCount nodes, ${branchDataCount / 3} branchData rows, $valueCount values"
|
||||
|
||||
// helpers for UTF-8 decoding,
|
||||
// since they are only accessed when an UTF8 byte sequence is actually hit and UTF-8 sequences in header values are
|
||||
// rare these fields can be lazy, the overhead of the lazy access should be overcompensated for by the saved
|
||||
// allocations in the majority of cases
|
||||
private lazy val byteBuffer = ByteBuffer.allocate(4)
|
||||
private lazy val charBuffer = CharBuffer.allocate(2)
|
||||
private lazy val decoder = UTF8.newDecoder()
|
||||
|
||||
// returns the decoded character as a simple 16-bit Char value or a 32-bit surrogate pair
|
||||
// or -1 if the byteBuffer bytes are not a complete and legal UTF-8 byte sequence
|
||||
private def decodeByteBuffer(): Int = {
|
||||
byteBuffer.flip()
|
||||
val coderResult = decoder.decode(byteBuffer, charBuffer, false)
|
||||
charBuffer.flip()
|
||||
val result =
|
||||
if (coderResult.isUnderflow & charBuffer.hasRemaining) {
|
||||
val c = charBuffer.get()
|
||||
if (charBuffer.hasRemaining) (charBuffer.get() << 16) | c else c
|
||||
} else -1
|
||||
byteBuffer.clear()
|
||||
charBuffer.clear()
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -446,17 +471,17 @@ private[http] object HttpHeaderParser {
|
|||
def insertRemainingCharsAsNewNodes(parser: HttpHeaderParser, input: ByteString, value: AnyRef): Unit =
|
||||
parser.insertRemainingCharsAsNewNodes(input, value)()
|
||||
|
||||
abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) {
|
||||
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int)
|
||||
private[parsing] abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) {
|
||||
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int)
|
||||
override def toString: String = s"HeaderValueParser[$headerName]"
|
||||
def cachingEnabled = maxValueCount > 0
|
||||
}
|
||||
|
||||
class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings)
|
||||
private[parsing] class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings)
|
||||
extends HeaderValueParser(headerName, maxValueCount) {
|
||||
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
// TODO: optimize by running the header value parser directly on the input ByteString (rather than an extracted String)
|
||||
val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)()
|
||||
val (headerValue, endIx) = scanHeaderValue(hhp, input, valueStart, valueStart + maxHeaderValueLength + 2)()
|
||||
val trimmedHeaderValue = headerValue.trim
|
||||
val header = HeaderParser.parseFull(headerName, trimmedHeaderValue, settings) match {
|
||||
case Right(h) ⇒ h
|
||||
|
|
@ -468,10 +493,10 @@ private[http] object HttpHeaderParser {
|
|||
}
|
||||
}
|
||||
|
||||
class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int)
|
||||
private[parsing] class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int)
|
||||
extends HeaderValueParser(headerName, maxValueCount) {
|
||||
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)()
|
||||
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
val (headerValue, endIx) = scanHeaderValue(hhp, input, valueStart, valueStart + maxHeaderValueLength + 2)()
|
||||
RawHeader(headerName, headerValue.trim) -> endIx
|
||||
}
|
||||
}
|
||||
|
|
@ -485,16 +510,49 @@ private[http] object HttpHeaderParser {
|
|||
}
|
||||
else fail(s"HTTP header name exceeds the configured limit of ${limit - start - 1} characters")
|
||||
|
||||
@tailrec private def scanHeaderValue(input: ByteString, start: Int, limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = {
|
||||
def spaceAppended = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(' ')
|
||||
@tailrec private def scanHeaderValue(hhp: HttpHeaderParser, input: ByteString, start: Int,
|
||||
limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = {
|
||||
def appended(c: Char) = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(c)
|
||||
def appended2(c: Int) = if ((c >> 16) != 0) appended(c.toChar).append((c >> 16).toChar) else appended(c.toChar)
|
||||
if (ix < limit)
|
||||
byteChar(input, ix) match {
|
||||
case '\t' ⇒ scanHeaderValue(input, start, limit)(spaceAppended, ix + 1)
|
||||
case '\t' ⇒ scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 1)
|
||||
case '\r' if byteChar(input, ix + 1) == '\n' ⇒
|
||||
if (WSP(byteChar(input, ix + 2))) scanHeaderValue(input, start, limit)(spaceAppended, ix + 3)
|
||||
if (WSP(byteChar(input, ix + 2))) scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 3)
|
||||
else (if (sb != null) sb.toString else asciiString(input, start, ix), ix + 2)
|
||||
case c if c >= ' ' ⇒ scanHeaderValue(input, start, limit)(if (sb != null) sb.append(c) else sb, ix + 1)
|
||||
case c ⇒ fail(s"Illegal character '${escape(c)}' in header value")
|
||||
case c ⇒
|
||||
var nix = ix + 1
|
||||
val nsb =
|
||||
if (' ' <= c && c <= '\u007F') if (sb != null) sb.append(c) else null // legal 7-Bit ASCII
|
||||
else if ((c & 0xE0) == 0xC0) { // 2-byte UTF-8 sequence?
|
||||
hhp.byteBuffer.put(c.toByte)
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 1))
|
||||
nix = ix + 2
|
||||
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
|
||||
case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)) else null
|
||||
case cc ⇒ appended2(cc)
|
||||
}
|
||||
} else if ((c & 0xF0) == 0xE0) { // 3-byte UTF-8 sequence?
|
||||
hhp.byteBuffer.put(c.toByte)
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 1))
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 2))
|
||||
nix = ix + 3
|
||||
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
|
||||
case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)) else null
|
||||
case cc ⇒ appended2(cc)
|
||||
}
|
||||
} else if ((c & 0xF8) == 0xF0) { // 4-byte UTF-8 sequence?
|
||||
hhp.byteBuffer.put(c.toByte)
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 1))
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 2))
|
||||
hhp.byteBuffer.put(byteAt(input, ix + 3))
|
||||
nix = ix + 4
|
||||
hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy
|
||||
case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)).append(byteChar(input, ix + 3)) else null
|
||||
case cc ⇒ appended2(cc)
|
||||
}
|
||||
} else fail(s"Illegal character '${escape(c)}' in header value")
|
||||
scanHeaderValue(hhp, input, start, limit)(nsb, nix)
|
||||
}
|
||||
else fail(s"HTTP header value exceeds the configured limit of ${limit - start - 2} characters")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ private object SpecializedHeaderValueParsers {
|
|||
def specializedHeaderValueParsers = Seq(ContentLengthParser)
|
||||
|
||||
object ContentLengthParser extends HeaderValueParser("Content-Length", maxValueCount = 1) {
|
||||
def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = {
|
||||
@tailrec def recurse(ix: Int = valueStart, result: Long = 0): (HttpHeader, Int) = {
|
||||
val c = byteChar(input, ix)
|
||||
if (result < 0) fail("`Content-Length` header value must not exceed 63-bit integer range")
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class ContentLengthHeaderParserSpec extends WordSpec with Matchers {
|
|||
}
|
||||
|
||||
def parse(bigint: String): Long = {
|
||||
val (`Content-Length`(length), _) = ContentLengthParser(ByteString(bigint + "\r\n").compact, 0, _ ⇒ ())
|
||||
val (`Content-Length`(length), _) = ContentLengthParser(null, ByteString(bigint + "\r\n").compact, 0, _ ⇒ ())
|
||||
length
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -150,6 +150,13 @@ class HttpHeaderParserSpec extends WordSpec with Matchers with BeforeAndAfterAll
|
|||
parseAndCache("Fancy: foo\tbar\r\nx")() shouldEqual RawHeader("Fancy", "foo bar")
|
||||
}
|
||||
|
||||
"parse and cache a header with UTF8 chars in the value" in new TestSetup() {
|
||||
parseAndCache("2-UTF8-Bytes: árvíztűrő ütvefúrógép\r\nx")() shouldEqual RawHeader("2-UTF8-Bytes", "árvíztűrő ütvefúrógép")
|
||||
parseAndCache("3-UTF8-Bytes: The € or the $?\r\nx")() shouldEqual RawHeader("3-UTF8-Bytes", "The € or the $?")
|
||||
parseAndCache("4-UTF8-Bytes: Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!\r\nx")() shouldEqual
|
||||
RawHeader("4-UTF8-Bytes", "Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!")
|
||||
}
|
||||
|
||||
"produce an error message for lines with an illegal header name" in new TestSetup() {
|
||||
the[ParsingException] thrownBy parseLine(" Connection: close\r\nx") should have message "Illegal character ' ' in header name"
|
||||
the[ParsingException] thrownBy parseLine("Connection : close\r\nx") should have message "Illegal character ' ' in header name"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue