diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpHeaderParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpHeaderParser.scala index b63d110ea2..0ab613fc90 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpHeaderParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpHeaderParser.scala @@ -4,6 +4,7 @@ package akka.http.impl.engine.parsing +import java.nio.{ CharBuffer, ByteBuffer } import java.util.Arrays.copyOf import java.lang.{ StringBuilder ⇒ JStringBuilder } import scala.annotation.tailrec @@ -98,7 +99,7 @@ private[engine] final class HttpHeaderParser private ( @tailrec def parseHeaderLine(input: ByteString, lineStart: Int = 0)(cursor: Int = lineStart, nodeIx: Int = 0): Int = { def startValueBranch(rootValueIx: Int, valueParser: HeaderValueParser) = { - val (header, endIx) = valueParser(input, cursor, onIllegalHeader) + val (header, endIx) = valueParser(this, input, cursor, onIllegalHeader) if (valueParser.cachingEnabled) try { val valueIx = newValueIndex // compute early in order to trigger OutOfTrieSpaceExceptions before any change @@ -149,7 +150,7 @@ private[engine] final class HttpHeaderParser private ( parseHeaderLine(input, lineStart)(cursor, nodeIx) } catch { case OutOfTrieSpaceException ⇒ // if we cannot insert we drop back to simply creating new header instances - val (headerValue, endIx) = scanHeaderValue(input, colonIx + 1, colonIx + maxHeaderValueLength + 3)() + val (headerValue, endIx) = scanHeaderValue(this, input, colonIx + 1, colonIx + maxHeaderValueLength + 3)() resultHeader = RawHeader(headerName, headerValue.trim) endIx } @@ -158,7 +159,7 @@ private[engine] final class HttpHeaderParser private ( @tailrec private def parseHeaderValue(input: ByteString, valueStart: Int, branch: ValueBranch)(cursor: Int = valueStart, nodeIx: Int = branch.branchRootNodeIx): Int = { def parseAndInsertHeader() = { - val (header, endIx) = branch.parser(input, 
valueStart, onIllegalHeader) + val (header, endIx) = branch.parser(this, input, valueStart, onIllegalHeader) if (branch.spaceLeft) try { insert(input, header)(cursor, endIx, nodeIx, colonIx = 0) @@ -375,6 +376,30 @@ private[engine] final class HttpHeaderParser private ( * Returns a string representation of the trie structure size. */ def formatSizes: String = s"$nodeCount nodes, ${branchDataCount / 3} branchData rows, $valueCount values" + + // Helpers for UTF-8 decoding. + // They are only accessed when a UTF-8 byte sequence is actually hit; as UTF-8 sequences in header values are + // rare these fields can be lazy: the overhead of the lazy access should be more than compensated for by the saved + // allocations in the majority of cases. + private lazy val byteBuffer = ByteBuffer.allocate(4) + private lazy val charBuffer = CharBuffer.allocate(2) + private lazy val decoder = UTF8.newDecoder() + + // Decodes the bytes collected in byteBuffer and returns the character as a simple 16-bit Char value or, for a surrogate pair, as a 32-bit value (first decoded Char in the lower, second in the upper 16 bits), + // or -1 if the byteBuffer bytes are not a complete and legal UTF-8 byte sequence. Both buffers are cleared before returning. + private def decodeByteBuffer(): Int = { + byteBuffer.flip() + val coderResult = decoder.decode(byteBuffer, charBuffer, false) + charBuffer.flip() + val result = + if (coderResult.isUnderflow & charBuffer.hasRemaining) { + val c = charBuffer.get() + if (charBuffer.hasRemaining) (charBuffer.get() << 16) | c else c + } else -1 + byteBuffer.clear() + charBuffer.clear() + result + } } /** @@ -446,17 +471,17 @@ private[http] object HttpHeaderParser { def insertRemainingCharsAsNewNodes(parser: HttpHeaderParser, input: ByteString, value: AnyRef): Unit = parser.insertRemainingCharsAsNewNodes(input, value)() - abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) { - def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) + private[parsing] abstract class HeaderValueParser(val headerName: String, val maxValueCount: Int) { + def apply(hhp: 
HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) override def toString: String = s"HeaderValueParser[$headerName]" def cachingEnabled = maxValueCount > 0 } - class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings) + private[parsing] class ModelledHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int, settings: HeaderParser.Settings) extends HeaderValueParser(headerName, maxValueCount) { - def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { + def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { // TODO: optimize by running the header value parser directly on the input ByteString (rather than an extracted String) - val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)() + val (headerValue, endIx) = scanHeaderValue(hhp, input, valueStart, valueStart + maxHeaderValueLength + 2)() val trimmedHeaderValue = headerValue.trim val header = HeaderParser.parseFull(headerName, trimmedHeaderValue, settings) match { case Right(h) ⇒ h @@ -468,10 +493,10 @@ private[http] object HttpHeaderParser { } } - class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int) + private[parsing] class RawHeaderValueParser(headerName: String, maxHeaderValueLength: Int, maxValueCount: Int) extends HeaderValueParser(headerName, maxValueCount) { - def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { - val (headerValue, endIx) = scanHeaderValue(input, valueStart, valueStart + maxHeaderValueLength + 2)() + def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { + val (headerValue, endIx) = scanHeaderValue(hhp, input, 
valueStart, valueStart + maxHeaderValueLength + 2)() RawHeader(headerName, headerValue.trim) -> endIx } } @@ -485,16 +510,49 @@ private[http] object HttpHeaderParser { } else fail(s"HTTP header name exceeds the configured limit of ${limit - start - 1} characters") - @tailrec private def scanHeaderValue(input: ByteString, start: Int, limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = { - def spaceAppended = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(' ') + @tailrec private def scanHeaderValue(hhp: HttpHeaderParser, input: ByteString, start: Int, + limit: Int)(sb: JStringBuilder = null, ix: Int = start): (String, Int) = { + def appended(c: Char) = (if (sb != null) sb else new JStringBuilder(asciiString(input, start, ix))).append(c) + def appended2(c: Int) = if ((c >> 16) != 0) appended(c.toChar).append((c >> 16).toChar) else appended(c.toChar) if (ix < limit) byteChar(input, ix) match { - case '\t' ⇒ scanHeaderValue(input, start, limit)(spaceAppended, ix + 1) + case '\t' ⇒ scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 1) case '\r' if byteChar(input, ix + 1) == '\n' ⇒ - if (WSP(byteChar(input, ix + 2))) scanHeaderValue(input, start, limit)(spaceAppended, ix + 3) + if (WSP(byteChar(input, ix + 2))) scanHeaderValue(hhp, input, start, limit)(appended(' '), ix + 3) else (if (sb != null) sb.toString else asciiString(input, start, ix), ix + 2) - case c if c >= ' ' ⇒ scanHeaderValue(input, start, limit)(if (sb != null) sb.append(c) else sb, ix + 1) - case c ⇒ fail(s"Illegal character '${escape(c)}' in header value") + case c ⇒ + var nix = ix + 1 + val nsb = + if (' ' <= c && c <= '\u007F') if (sb != null) sb.append(c) else null // legal 7-Bit ASCII + else if ((c & 0xE0) == 0xC0) { // 2-byte UTF-8 sequence? 
+ hhp.byteBuffer.put(c.toByte) + hhp.byteBuffer.put(byteAt(input, ix + 1)) + nix = ix + 2 + hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy; NOTE(review): the follow-up bytes are copied unchecked, so a CR, LF or control char among them bypasses the checks above - confirm this is intended + case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)) else null + case cc ⇒ appended2(cc) + } + } else if ((c & 0xF0) == 0xE0) { // 3-byte UTF-8 sequence? + hhp.byteBuffer.put(c.toByte) + hhp.byteBuffer.put(byteAt(input, ix + 1)) + hhp.byteBuffer.put(byteAt(input, ix + 2)) + nix = ix + 3 + hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy + case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)) else null + case cc ⇒ appended2(cc) + } + } else if ((c & 0xF8) == 0xF0) { // 4-byte UTF-8 sequence? + hhp.byteBuffer.put(c.toByte) + hhp.byteBuffer.put(byteAt(input, ix + 1)) + hhp.byteBuffer.put(byteAt(input, ix + 2)) + hhp.byteBuffer.put(byteAt(input, ix + 3)) + nix = ix + 4 + hhp.decodeByteBuffer() match { // if we cannot decode as UTF8 we don't decode but simply copy + case -1 ⇒ if (sb != null) sb.append(c).append(byteChar(input, ix + 1)).append(byteChar(input, ix + 2)).append(byteChar(input, ix + 3)) else null + case cc ⇒ appended2(cc) + } + } else fail(s"Illegal character '${escape(c)}' in header value") + scanHeaderValue(hhp, input, start, limit)(nsb, nix) } else fail(s"HTTP header value exceeds the configured limit of ${limit - start - 2} characters") } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/SpecializedHeaderValueParsers.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/SpecializedHeaderValueParsers.scala index 229a013135..bde90d02e4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/SpecializedHeaderValueParsers.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/SpecializedHeaderValueParsers.scala @@ -19,7 +19,7 @@ private object SpecializedHeaderValueParsers { def 
specializedHeaderValueParsers = Seq(ContentLengthParser) object ContentLengthParser extends HeaderValueParser("Content-Length", maxValueCount = 1) { - def apply(input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { + def apply(hhp: HttpHeaderParser, input: ByteString, valueStart: Int, onIllegalHeader: ErrorInfo ⇒ Unit): (HttpHeader, Int) = { @tailrec def recurse(ix: Int = valueStart, result: Long = 0): (HttpHeader, Int) = { val c = byteChar(input, ix) if (result < 0) fail("`Content-Length` header value must not exceed 63-bit integer range") diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ContentLengthHeaderParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ContentLengthHeaderParserSpec.scala index 687ffb2bf1..54b3e1d2c0 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ContentLengthHeaderParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ContentLengthHeaderParserSpec.scala @@ -30,7 +30,7 @@ class ContentLengthHeaderParserSpec extends WordSpec with Matchers { } def parse(bigint: String): Long = { - val (`Content-Length`(length), _) = ContentLengthParser(ByteString(bigint + "\r\n").compact, 0, _ ⇒ ()) + val (`Content-Length`(length), _) = ContentLengthParser(null, ByteString(bigint + "\r\n").compact, 0, _ ⇒ ()) length } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/HttpHeaderParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/HttpHeaderParserSpec.scala index 513dd9bb32..e0e631c5c7 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/HttpHeaderParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/HttpHeaderParserSpec.scala @@ -150,6 +150,13 @@ class HttpHeaderParserSpec extends WordSpec with Matchers with BeforeAndAfterAll parseAndCache("Fancy: foo\tbar\r\nx")() shouldEqual RawHeader("Fancy", "foo bar") } + "parse and 
cache a header with UTF8 chars in the value" in new TestSetup() { + parseAndCache("2-UTF8-Bytes: árvíztűrő ütvefúrógép\r\nx")() shouldEqual RawHeader("2-UTF8-Bytes", "árvíztűrő ütvefúrógép") + parseAndCache("3-UTF8-Bytes: The € or the $?\r\nx")() shouldEqual RawHeader("3-UTF8-Bytes", "The € or the $?") + parseAndCache("4-UTF8-Bytes: Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!\r\nx")() shouldEqual + RawHeader("4-UTF8-Bytes", "Surrogate pairs: \uD801\uDC1B\uD801\uDC04\uD801\uDC1B!") + } + "produce an error message for lines with an illegal header name" in new TestSetup() { the[ParsingException] thrownBy parseLine(" Connection: close\r\nx") should have message "Illegal character ' ' in header name" the[ParsingException] thrownBy parseLine("Connection : close\r\nx") should have message "Illegal character ' ' in header name"