=htc #16388 fix StackOverflowException during chunk parsing
This commit is contained in:
parent
616e8b7ad9
commit
dd27fdd485
2 changed files with 62 additions and 38 deletions
|
|
@ -6,7 +6,6 @@ package akka.http.engine.parsing
|
||||||
|
|
||||||
import scala.annotation.tailrec
|
import scala.annotation.tailrec
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import scala.collection.immutable
|
|
||||||
import akka.parboiled2.CharUtils
|
import akka.parboiled2.CharUtils
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import akka.stream.scaladsl.Source
|
import akka.stream.scaladsl.Source
|
||||||
|
|
@ -25,6 +24,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
|
||||||
import settings._
|
import settings._
|
||||||
|
|
||||||
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
|
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
|
||||||
|
final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult
|
||||||
|
|
||||||
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs
|
private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs
|
||||||
private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0)
|
private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0)
|
||||||
|
|
@ -34,13 +34,20 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
|
||||||
override def initial = new State {
|
override def initial = new State {
|
||||||
override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
|
override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
|
||||||
result.clear()
|
result.clear()
|
||||||
try state(input)
|
|
||||||
catch {
|
@tailrec def run(next: ByteString ⇒ StateResult): StateResult =
|
||||||
case e: ParsingException ⇒ fail(e.status, e.info)
|
(try next(input)
|
||||||
case NotEnoughDataException ⇒
|
catch {
|
||||||
// we are missing a try/catch{continue} wrapper somewhere
|
case e: ParsingException ⇒ fail(e.status, e.info)
|
||||||
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
|
case NotEnoughDataException ⇒
|
||||||
}
|
// we are missing a try/catch{continue} wrapper somewhere
|
||||||
|
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
|
||||||
|
}) match {
|
||||||
|
case Trampoline(next) ⇒ run(next)
|
||||||
|
case x ⇒ x
|
||||||
|
}
|
||||||
|
|
||||||
|
run(state)
|
||||||
val resultIterator = result.iterator
|
val resultIterator = result.iterator
|
||||||
if (terminated) emitAndFinish(resultIterator, ctx)
|
if (terminated) emitAndFinish(resultIterator, ctx)
|
||||||
else emit(resultIterator, ctx)
|
else emit(resultIterator, ctx)
|
||||||
|
|
@ -162,7 +169,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
|
||||||
val chunkBodyEnd = cursor + chunkSize
|
val chunkBodyEnd = cursor + chunkSize
|
||||||
def result(terminatorLen: Int) = {
|
def result(terminatorLen: Int) = {
|
||||||
emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
|
emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
|
||||||
parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)
|
trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
|
||||||
}
|
}
|
||||||
byteChar(input, chunkBodyEnd) match {
|
byteChar(input, chunkBodyEnd) match {
|
||||||
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2)
|
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2)
|
||||||
|
|
@ -213,6 +220,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut
|
||||||
state = next(_, 0)
|
state = next(_, 0)
|
||||||
done()
|
done()
|
||||||
}
|
}
|
||||||
|
def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next)
|
||||||
|
|
||||||
def fail(summary: String): StateResult = fail(summary, "")
|
def fail(summary: String): StateResult = fail(summary, "")
|
||||||
def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail)
|
def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail)
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,21 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|
||||||
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))))
|
source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo"))))))))
|
||||||
closeAfterResponseCompletion shouldEqual Seq(false)
|
closeAfterResponseCompletion shouldEqual Seq(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"don't overflow the stack for large buffers of chunks" in new Test {
|
||||||
|
override val awaitAtMost = 3000.millis
|
||||||
|
|
||||||
|
val x = NotEnoughDataException
|
||||||
|
val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m`
|
||||||
|
val oneChunk = "1\r\nz\n"
|
||||||
|
val manyChunks = (oneChunk * numChunks) + "0\r\n"
|
||||||
|
|
||||||
|
val parser = newParser
|
||||||
|
val result = multiParse(newParser)(Seq(prep(start + manyChunks)))
|
||||||
|
val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity
|
||||||
|
val strictChunks = chunks.collectAll.awaitResult(awaitAtMost)
|
||||||
|
strictChunks.size shouldEqual numChunks
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
"properly parse a chunked request with additional transfer encodings" in new Test {
|
"properly parse a chunked request with additional transfer encodings" in new Test {
|
||||||
|
|
@ -374,14 +389,15 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|
||||||
override def afterAll() = system.shutdown()
|
override def afterAll() = system.shutdown()
|
||||||
|
|
||||||
private class Test {
|
private class Test {
|
||||||
|
def awaitAtMost: FiniteDuration = 250.millis
|
||||||
var closeAfterResponseCompletion = Seq.empty[Boolean]
|
var closeAfterResponseCompletion = Seq.empty[Boolean]
|
||||||
|
|
||||||
class StrictEqualHttpRequest(val req: HttpRequest) {
|
class StrictEqualHttpRequest(val req: HttpRequest) {
|
||||||
override def equals(other: scala.Any): Boolean = other match {
|
override def equals(other: scala.Any): Boolean = other match {
|
||||||
case other: StrictEqualHttpRequest ⇒
|
case other: StrictEqualHttpRequest ⇒
|
||||||
this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) &&
|
this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) &&
|
||||||
Await.result(this.req.entity.toStrict(250.millis), 250.millis) ==
|
this.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) ==
|
||||||
Await.result(other.req.entity.toStrict(250.millis), 250.millis)
|
other.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString = req.toString
|
override def toString = req.toString
|
||||||
|
|
@ -414,32 +430,32 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|
||||||
def generalRawMultiParseTo(parser: HttpRequestParser,
|
def generalRawMultiParseTo(parser: HttpRequestParser,
|
||||||
expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] =
|
expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] =
|
||||||
equal(expected.map(strictEqualify))
|
equal(expected.map(strictEqualify))
|
||||||
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose { input: Seq[String] ⇒
|
.matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser)
|
||||||
val future =
|
|
||||||
Source(input.toList)
|
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] =
|
||||||
.map(ByteString.apply)
|
Source(input.toList)
|
||||||
.transform("parser", () ⇒ parser)
|
.map(ByteString.apply)
|
||||||
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart])
|
.transform("parser", () ⇒ parser)
|
||||||
.headAndTail
|
.splitWhen(_.isInstanceOf[ParserOutput.MessageStart])
|
||||||
.collect {
|
.headAndTail
|
||||||
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒
|
.collect {
|
||||||
closeAfterResponseCompletion :+= close
|
case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒
|
||||||
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
|
closeAfterResponseCompletion :+= close
|
||||||
case (x: ParseError, _) ⇒ Left(x)
|
Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol))
|
||||||
}
|
case (x: ParseError, _) ⇒ Left(x)
|
||||||
.map { x ⇒
|
|
||||||
Source {
|
|
||||||
x match {
|
|
||||||
case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x)))
|
|
||||||
case Left(error) ⇒ Future.successful(Left(error))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.flatten(FlattenStrategy.concat)
|
|
||||||
.map(strictEqualify)
|
|
||||||
.grouped(1000).runWith(Sink.head)
|
|
||||||
Await.result(future, 250.millis)
|
|
||||||
}
|
}
|
||||||
|
.map { x ⇒
|
||||||
|
Source {
|
||||||
|
x match {
|
||||||
|
case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x)))
|
||||||
|
case Left(error) ⇒ Future.successful(Left(error))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.flatten(FlattenStrategy.concat)
|
||||||
|
.map(strictEqualify)
|
||||||
|
.collectAll
|
||||||
|
.awaitResult(awaitAtMost)
|
||||||
|
|
||||||
protected def parserSettings: ParserSettings = ParserSettings(system)
|
protected def parserSettings: ParserSettings = ParserSettings(system)
|
||||||
protected def newParser = new HttpRequestParser(parserSettings, false)()
|
protected def newParser = new HttpRequestParser(parserSettings, false)()
|
||||||
|
|
@ -447,11 +463,11 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|
||||||
private def compactEntity(entity: RequestEntity): Future[RequestEntity] =
|
private def compactEntity(entity: RequestEntity): Future[RequestEntity] =
|
||||||
entity match {
|
entity match {
|
||||||
case x: Chunked ⇒ compactEntityChunks(x.chunks).fast.map(compacted ⇒ x.copy(chunks = source(compacted: _*)))
|
case x: Chunked ⇒ compactEntityChunks(x.chunks).fast.map(compacted ⇒ x.copy(chunks = source(compacted: _*)))
|
||||||
case _ ⇒ entity.toStrict(250.millis)
|
case _ ⇒ entity.toStrict(awaitAtMost)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
|
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
|
||||||
data.grouped(1000).runWith(Sink.head)
|
data.collectAll
|
||||||
.fast.recover { case _: NoSuchElementException ⇒ Nil }
|
.fast.recover { case _: NoSuchElementException ⇒ Nil }
|
||||||
|
|
||||||
def prep(response: String) = response.stripMarginWithNewline("\r\n")
|
def prep(response: String) = response.stripMarginWithNewline("\r\n")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue