diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala index a1d5359f64..1d4feb82e1 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala @@ -6,7 +6,6 @@ package akka.http.engine.parsing import scala.annotation.tailrec import scala.collection.mutable.ListBuffer -import scala.collection.immutable import akka.parboiled2.CharUtils import akka.util.ByteString import akka.stream.scaladsl.Source @@ -25,6 +24,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut import settings._ sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup + final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0) @@ -34,13 +34,20 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut override def initial = new State { override def onPush(input: ByteString, ctx: Context[Output]): Directive = { result.clear() - try state(input) - catch { - case e: ParsingException ⇒ fail(e.status, e.info) - case NotEnoughDataException ⇒ - // we are missing a try/catch{continue} wrapper somewhere - throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) - } + + @tailrec def run(next: ByteString ⇒ StateResult): StateResult = + (try next(input) + catch { + case e: ParsingException ⇒ fail(e.status, e.info) + case NotEnoughDataException ⇒ + // we are missing a try/catch{continue} wrapper somewhere + throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) + }) match { + case Trampoline(next) ⇒ run(next) + case x ⇒ x + } + + run(state) val resultIterator = result.iterator if (terminated) emitAndFinish(resultIterator, ctx) else emit(resultIterator, ctx) @@ -162,7 +169,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut val chunkBodyEnd = cursor + chunkSize def result(terminatorLen: Int) = { emit(ParserOutput.EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) - parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage) + trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) } byteChar(input, chunkBodyEnd) match { case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2) @@ -213,6 +220,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut state = next(_, 0) done() } + def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next) def fail(summary: String): StateResult = fail(summary, "") def fail(summary: String, detail: String): StateResult = fail(StatusCodes.BadRequest, summary, detail) diff --git a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala index e00d99ec7e..9a668b0f1d 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala @@ -215,6 +215,21 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { source(LastChunk("nice=true", List(RawHeader("Bar", "xyz"), RawHeader("Foo", "pip apo")))))))) closeAfterResponseCompletion shouldEqual Seq(false) } + + "don't overflow the stack for large buffers of chunks" in new Test { + override val awaitAtMost = 3000.millis + + val x = NotEnoughDataException + val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m` + val oneChunk = "1\r\nz\n" + val manyChunks = (oneChunk * numChunks) + "0\r\n" + + val parser = newParser + val result = multiParse(newParser)(Seq(prep(start + manyChunks))) + val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity + val strictChunks = chunks.collectAll.awaitResult(awaitAtMost) + strictChunks.size shouldEqual numChunks + } } "properly parse a chunked request with additional transfer encodings" in new Test { @@ -374,14 +389,15 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { override def afterAll() = system.shutdown() private class Test { + def awaitAtMost: FiniteDuration = 250.millis var closeAfterResponseCompletion = Seq.empty[Boolean] class StrictEqualHttpRequest(val req: HttpRequest) { override def equals(other: scala.Any): Boolean = other match { case other: StrictEqualHttpRequest ⇒ this.req.copy(entity = HttpEntity.Empty) == other.req.copy(entity = HttpEntity.Empty) && - Await.result(this.req.entity.toStrict(250.millis), 250.millis) == - Await.result(other.req.entity.toStrict(250.millis), 250.millis) + this.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) == + other.req.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) } override def toString = req.toString @@ -414,32 +430,32 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def generalRawMultiParseTo(parser: HttpRequestParser, expected: Either[ParseError, HttpRequest]*): Matcher[Seq[String]] = equal(expected.map(strictEqualify)) - .matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose { input: Seq[String] ⇒ - val future = - Source(input.toList) - .map(ByteString.apply) - .transform("parser", () ⇒ parser) - .splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) - .headAndTail - .collect { - case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒ - closeAfterResponseCompletion :+= close - Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) - case (x: ParseError, _) ⇒ Left(x) - } - .map { x ⇒ - Source { - x match { - case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x))) - case Left(error) ⇒ Future.successful(Left(error)) - } - } - } - .flatten(FlattenStrategy.concat) - .map(strictEqualify) - .grouped(1000).runWith(Sink.head) - Await.result(future, 250.millis) + .matcher[Seq[Either[ParseError, StrictEqualHttpRequest]]] compose multiParse(parser) + + def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[ParseError, StrictEqualHttpRequest]] = + Source(input.toList) + .map(ByteString.apply) + .transform("parser", () ⇒ parser) + .splitWhen(_.isInstanceOf[ParserOutput.MessageStart]) + .headAndTail + .collect { + case (ParserOutput.RequestStart(method, uri, protocol, headers, createEntity, close), entityParts) ⇒ + closeAfterResponseCompletion :+= close + Right(HttpRequest(method, uri, headers, createEntity(entityParts), protocol)) + case (x: ParseError, _) ⇒ Left(x) } + .map { x ⇒ + Source { + x match { + case Right(request) ⇒ compactEntity(request.entity).fast.map(x ⇒ Right(request.withEntity(x))) + case Left(error) ⇒ Future.successful(Left(error)) + } + } + } + .flatten(FlattenStrategy.concat) + .map(strictEqualify) + .collectAll + .awaitResult(awaitAtMost) protected def parserSettings: ParserSettings = ParserSettings(system) protected def newParser = new HttpRequestParser(parserSettings, false)() @@ -447,11 +463,11 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { private def compactEntity(entity: RequestEntity): Future[RequestEntity] = entity match { case x: Chunked ⇒ compactEntityChunks(x.chunks).fast.map(compacted ⇒ x.copy(chunks = source(compacted: _*))) - case _ ⇒ entity.toStrict(250.millis) + case _ ⇒ entity.toStrict(awaitAtMost) } private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] = - data.grouped(1000).runWith(Sink.head) + data.collectAll .fast.recover { case _: NoSuchElementException ⇒ Nil } def prep(response: String) = response.stripMarginWithNewline("\r\n")