=htc #16388 fix StackOverflowException during chunk parsing

Johannes Rudolph 2014-11-25 15:39:46 +01:00
parent 616e8b7ad9
commit dd27fdd485
2 changed files with 62 additions and 38 deletions
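The fix replaces direct recursion between parsing states with an explicit trampoline: instead of invoking the next step immediately (one stack frame per chunk in the incoming buffer), a step can return a Trampoline value carrying its continuation, and a @tailrec driver loop in onPush keeps re-applying those continuations to the same input. A minimal, self-contained sketch of the pattern follows; the names (TrampolineSketch, step, run) are illustrative only, not the parser's actual API.

import scala.annotation.tailrec

// Sketch of the trampoline idea behind this fix (illustrative, not the real parser):
// a step returns a deferred continuation instead of recursing, and the driver
// unwinds continuations iteratively so stack depth stays constant.
object TrampolineSketch extends App {
  sealed trait StateResult
  case object Done extends StateResult
  final case class Trampoline(next: () ⇒ StateResult) extends StateResult

  // One "chunk" of work: calling step(remaining - 1) here directly would
  // add a stack frame per chunk.
  def step(remaining: Int): StateResult =
    if (remaining == 0) Done
    else Trampoline(() ⇒ step(remaining - 1)) // defer instead of calling

  // The driver bounces until a terminal result is produced.
  @tailrec def run(result: StateResult): StateResult = result match {
    case Trampoline(next) ⇒ run(next())
    case done             ⇒ done
  }

  println(run(step(1000000))) // completes without StackOverflowError
}

The cost is one small allocation per bounce, which only matters on the chunk-parsing path, where a single ByteString can carry thousands of chunks (the new test below drives 15000 through one buffer).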

HttpMessageParser.scala

@@ -6,7 +6,6 @@ package akka.http.engine.parsing
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.collection.immutable
import akka.parboiled2.CharUtils
import akka.util.ByteString
import akka.stream.scaladsl.Source
@@ -25,6 +24,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs
private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0)
@@ -34,13 +34,20 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
override def initial = new State {
override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
result.clear()
try state(input)
catch {
case e: ParsingException ⇒ fail(e.status, e.info)
case NotEnoughDataException ⇒
// we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
}
@tailrec def run(next: ByteString ⇒ StateResult): StateResult =
(try next(input)
catch {
case e: ParsingException ⇒ fail(e.status, e.info)
case NotEnoughDataException ⇒
// we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
}) match {
case Trampoline(next) ⇒ run(next)
case x ⇒ x
}
run(state)
val resultIterator = result.iterator
if (terminated) emitAndFinish(resultIterator, ctx)
else emit(resultIterator, ctx)
@@ -162,7 +169,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
val chunkBodyEnd = cursor + chunkSize
def result(terminatorLen: Int) = {
emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)
trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
}
byteChar(input, chunkBodyEnd) match {
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2)
@@ -213,6 +220,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
state = next(_, 0)
done()
}
def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next)
def fail(summary: String): StateResult = fail(summary, "")
def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail)
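Note that only the chunk-loop continuation in parseChunk is wrapped with the new trampoline helper; the other state transitions still call their successor directly, presumably because they cannot be driven arbitrarily deep by a single input buffer the way back-to-back chunks can. The @tailrec run loop in onPush unwinds each Trampoline before the collected results are emitted, so per-buffer stack usage is now bounded regardless of chunk count.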

RequestParserSpec.scala

@@ -215,6 +215,21 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))))
closeAfterResponseCompletion shouldEqual Seq(false)
}
"don't overflow the stack for large buffers of chunks" in new Test {
override val awaitAtMost = 3000.millis
val x = NotEnoughDataException
val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m`
val oneChunk = "1\r\nz\n"
val manyChunks = (oneChunk * numChunks) + "0\r\n"
val parser = newParser
val result = multiParse(newParser)(Seq(prep(start + manyChunks)))
val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity
val strictChunks = chunks.collectAll.awaitResult(awaitAtMost)
strictChunks.size shouldEqual numChunks
}
}
"properly parse a chunked request with additional transfer encodings" in new Test {
@@ -374,14 +389,15 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
override def afterAll() = system.shutdown()
private class Test {
def awaitAtMost: FiniteDuration = 250.millis
var closeAfterResponseCompletion = Seq.empty[Boolean]
class StrictEqualHttpRequest(val req: HttpRequest) {
override def equals(other: scala.Any): Boolean = other match {
case other: StrictEqualHttpRequest ⇒
this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) &&
Await.result(this.req.entity.toStrict(250.millis), 250.millis) ==
Await.result(other.req.entity.toStrict(250.millis), 250.millis)
this.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) ==
other.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost)
}
override def toString = req.toString
@@ -414,32 +430,32 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def generalRawMultiParseTo(parser: HttpRequestParser,
expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] =
equal(expected.map(strictEqualify))
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose { input: Seq[String] ⇒
val future =
Source(input.toList)
.map(ByteString.apply)
.transform("parser", () parser)
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart])
.headAndTail
.collect {
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒
closeAfterResponseCompletion :+= close
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
case (x: ParseError, _) ⇒ Left(x)
}
.map { x ⇒
Source {
x match {
case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x)))
case Left(error) ⇒ Future.successful(Left(error))
}
}
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.grouped(1000).runWith(Sink.head)
Await.result(future, 250.millis)
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser)
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] =
Source(input.toList)
.map(ByteString.apply)
.transform("parser", () parser)
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart])
.headAndTail
.collect {
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒
closeAfterResponseCompletion :+= close
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
case (x: ParseError, _) ⇒ Left(x)
}
.map { x ⇒
Source {
x match {
case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x)))
case Left(error) ⇒ Future.successful(Left(error))
}
}
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.collectAll
.awaitResult(awaitAtMost)
protected def parserSettings: ParserSettings = ParserSettings(system)
protected def newParser = new HttpRequestParser(parserSettings, false)()
@@ -447,11 +463,11 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
private def compactEntity(entity: RequestEntity): Future[RequestEntity] =
entity match {
case x: Chunked ⇒ compactEntityChunks(x.chunks).fast.map(compacted ⇒ x.copy(chunks = source(compacted: _*)))
case _ ⇒ entity.toStrict(250.millis)
case _ ⇒ entity.toStrict(awaitAtMost)
}
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
data.grouped(1000).runWith(Sink.head)
data.collectAll
.fast.recover { case _: NoSuchElementException ⇒ Nil }
def prep(response: String) = response.stripMarginWithNewline("\r\n")