=htc #16388 fix StackOverflowException during chunk parsing

This commit is contained in:
Johannes Rudolph 2014-11-25 15:39:46 +01:00
parent 616e8b7ad9
commit dd27fdd485
2 changed files with 62 additions and 38 deletions

View file

@ -6,7 +6,6 @@ package akka.http.engine.parsing
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import scala.collection.immutable
import akka.parboiled2.CharUtils import akka.parboiled2.CharUtils
import akka.util.ByteString import akka.util.ByteString
import akka.stream.scaladsl.Source import akka.stream.scaladsl.Source
@ -25,6 +24,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
import settings._ import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString StateResult) extends StateResult
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs
private[this] var state: ByteString StateResult = startNewMessage(_, 0) private[this] var state: ByteString StateResult = startNewMessage(_, 0)
@ -34,13 +34,20 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
override def initial = new State { override def initial = new State {
override def onPush(input: ByteString, ctx: Context[Output]): Directive = { override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
result.clear() result.clear()
try state(input)
catch { @tailrec def run(next: ByteString StateResult): StateResult =
case e: ParsingException fail(e.status, e.info) (try next(input)
case NotEnoughDataException catch {
// we are missing a try/catch{continue} wrapper somewhere case e: ParsingException fail(e.status, e.info)
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) case NotEnoughDataException
} // we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
}) match {
case Trampoline(next) run(next)
case x x
}
run(state)
val resultIterator = result.iterator val resultIterator = result.iterator
if (terminated) emitAndFinish(resultIterator, ctx) if (terminated) emitAndFinish(resultIterator, ctx)
else emit(resultIterator, ctx) else emit(resultIterator, ctx)
@ -162,7 +169,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
val chunkBodyEnd = cursor + chunkSize val chunkBodyEnd = cursor + chunkSize
def result(terminatorLen: Int) = { def result(terminatorLen: Int) = {
emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage) trampoline(_ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
} }
byteChar(input, chunkBodyEnd) match { byteChar(input, chunkBodyEnd) match {
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2) case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2)
@ -213,6 +220,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
state = next(_, 0) state = next(_, 0)
done() done()
} }
def trampoline(next: ByteString StateResult): StateResult = Trampoline(next)
def fail(summary: String): StateResult = fail(summary, "") def fail(summary: String): StateResult = fail(summary, "")
def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail) def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail)

View file

@ -215,6 +215,21 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
"don't overflow the stack for large buffers of chunks" in new Test {
override val awaitAtMost = 3000.millis
val x = NotEnoughDataException
val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m`
val oneChunk = "1\r\nz\n"
val manyChunks = (oneChunk * numChunks) + "0\r\n"
val parser = newParser
val result = multiParse(newParser)(Seq(prep(start + manyChunks)))
val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity
val strictChunks = chunks.collectAll.awaitResult(awaitAtMost)
strictChunks.size shouldEqual numChunks
}
} }
"properly parse a chunked request with additional transfer encodings" in new Test { "properly parse a chunked request with additional transfer encodings" in new Test {
@ -374,14 +389,15 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
override def afterAll() = system.shutdown() override def afterAll() = system.shutdown()
private class Test { private class Test {
def awaitAtMost: FiniteDuration = 250.millis
var closeAfterResponseCompletion = Seq.empty[Boolean] var closeAfterResponseCompletion = Seq.empty[Boolean]
class StrictEqualHttpRequest(val req: HttpRequest) { class StrictEqualHttpRequest(val req: HttpRequest) {
override def equals(other: scala.Any): Boolean = other match { override def equals(other: scala.Any): Boolean = other match {
case other: StrictEqualHttpRequest case other: StrictEqualHttpRequest
this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) && this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) &&
Await.result(this.req.entity.toStrict(250.millis), 250.millis) == this.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) ==
Await.result(other.req.entity.toStrict(250.millis), 250.millis) other.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost)
} }
override def toString = req.toString override def toString = req.toString
@ -414,32 +430,32 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def generalRawMultiParseTo(parser: HttpRequestParser, def generalRawMultiParseTo(parser: HttpRequestParser,
expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] =
equal(expected.map(strictEqualify)) equal(expected.map(strictEqualify))
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose { input: Seq[String] .matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser)
val future =
Source(input.toList) def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] =
.map(ByteString.apply) Source(input.toList)
.transform("parser", () parser) .map(ByteString.apply)
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) .transform("parser", () parser)
.headAndTail .splitWhen(_.isInstanceOf[ParserOutput.MessageStart])
.collect { .headAndTail
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) .collect {
closeAfterResponseCompletion :+= close case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts)
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) closeAfterResponseCompletion :+= close
case (x: ParseError, _) Left(x) Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
} case (x: ParseError, _) Left(x)
.map { x
Source {
x match {
case Right(request) compactEntity(request.entity).fast.map(x Right(request.withEntity(x)))
case Left(error) Future.successful(Left(error))
}
}
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.grouped(1000).runWith(Sink.head)
Await.result(future, 250.millis)
} }
.map { x
Source {
x match {
case Right(request) compactEntity(request.entity).fast.map(x Right(request.withEntity(x)))
case Left(error) Future.successful(Left(error))
}
}
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.collectAll
.awaitResult(awaitAtMost)
protected def parserSettings: ParserSettings = ParserSettings(system) protected def parserSettings: ParserSettings = ParserSettings(system)
protected def newParser = new HttpRequestParser(parserSettings, false)() protected def newParser = new HttpRequestParser(parserSettings, false)()
@ -447,11 +463,11 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
private def compactEntity(entity: RequestEntity): Future[RequestEntity] = private def compactEntity(entity: RequestEntity): Future[RequestEntity] =
entity match { entity match {
case x: Chunked compactEntityChunks(x.chunks).fast.map(compacted x.copy(chunks = source(compacted: _*))) case x: Chunked compactEntityChunks(x.chunks).fast.map(compacted x.copy(chunks = source(compacted: _*)))
case _ entity.toStrict(250.millis) case _ entity.toStrict(awaitAtMost)
} }
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] = private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
data.grouped(1000).runWith(Sink.head) data.collectAll
.fast.recover { case _: NoSuchElementException Nil } .fast.recover { case _: NoSuchElementException Nil }
def prep(response: String) = response.stripMarginWithNewline("\r\n") def prep(response: String) = response.stripMarginWithNewline("\r\n")