Migrate BodyPartParser to GraphStage #20288.

This commit is contained in:
Daniel Moran 2016-04-18 07:43:54 -07:00
parent 9af39f97ba
commit c8352c04c4
4 changed files with 245 additions and 226 deletions

View file

@ -5,6 +5,7 @@
package akka.http.impl.engine.parsing
import akka.NotUsed
import scala.annotation.tailrec
import akka.event.LoggingAdapter
import akka.parboiled2.CharPredicate
@ -13,7 +14,9 @@ import akka.stream.stage._
import akka.util.ByteString
import akka.http.scaladsl.model._
import akka.http.impl.util._
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
import headers._
import scala.collection.mutable.ListBuffer
import akka.stream.impl.fusing.SubSource
@ -26,7 +29,7 @@ private[http] final class BodyPartParser(defaultContentType: ContentType,
boundary: String,
log: LoggingAdapter,
settings: BodyPartParser.Settings)
extends PushPullStage[ByteString, BodyPartParser.Output] {
extends GraphStage[FlowShape[ByteString, BodyPartParser.Output]] {
import BodyPartParser._
import settings._
@ -57,202 +60,214 @@ private[http] final class BodyPartParser(defaultContentType: ContentType,
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
}
private[this] var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful
private[this] var state: ByteString StateResult = tryParseInitialBoundary
private[this] var terminated = false
val in = Inlet[ByteString]("BodyPartParser.in")
val out = Outlet[BodyPartParser.Output]("BodyPartParser.out")
def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit =
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
override val shape = FlowShape(in, out)
override def onPush(input: ByteString, ctx: Context[Output]): SyncDirective =
if (!terminated) {
try state(input)
catch {
case e: ParsingException fail(e.info)
case NotEnoughDataException
// we are missing a try/catch{continue} wrapper somewhere
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful
private var state: ByteString => StateResult = tryParseInitialBoundary
private var shouldTerminate = false
/**
 * Handles one chunk arriving from upstream.
 *
 * If termination was already requested the chunk is not grabbed and the stage completes.
 * Otherwise the chunk is handed to the current parser `state`; afterwards the stage pushes
 * the oldest queued output if there is one, completes if parsing requested termination and
 * nothing is queued, or pulls the next chunk.
 */
override def onPush(): Unit =
  if (shouldTerminate) completeStage()
  else {
    val chunk = grab(in)
    try state(chunk)
    catch {
      case e: ParsingException => fail(e.info)
      case NotEnoughDataException =>
        // we are missing a try/catch{continue} wrapper somewhere
        throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
    }
    if (output.nonEmpty) push(out, dequeue())
    else if (shouldTerminate) completeStage()
    else pull(in)
  }
if (output.nonEmpty) ctx.push(dequeue())
else if (!terminated) ctx.pull()
else ctx.finish()
} else ctx.finish()
override def onPull(ctx: Context[Output]): SyncDirective = {
if (output.nonEmpty)
ctx.push(dequeue())
else if (ctx.isFinishing) {
if (terminated) ctx.finish()
else ctx.pushAndFinish(ParseError(ErrorInfo("Unexpected end of multipart entity")))
} else
ctx.pull()
}
/**
 * Serves downstream demand.
 *
 * Queued output is delivered first. With an empty queue the stage asks upstream for more
 * data while the inlet is still open; once the inlet is closed it completes, first pushing
 * a `ParseError` if the parser never reached a terminating state.
 */
override def onPull(): Unit =
  if (output.nonEmpty) push(out, dequeue())
  else if (!isClosed(in)) pull(in)
  else {
    if (!shouldTerminate) push(out, ParseError(ErrorInfo("Unexpected end of multipart entity")))
    completeStage()
  }
override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective = ctx.absorbTermination()
/**
 * Upstream has finished. The stage is not completed here: if downstream demand is already
 * pending it is served right away through `onPull`, otherwise the next `onPull` will see
 * the closed inlet and finish the job.
 */
override def onUpstreamFinish(): Unit =
  if (isAvailable(out)) onPull()
def tryParseInitialBoundary(input: ByteString): StateResult =
// we don't use boyerMoore here because we are testing for the boundary *without* a
// preceding CRLF and at a known location (the very beginning of the entity)
try {
if (boundary(input, 0)) {
val ix = boundaryLength
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
else if (doubleDash(input, ix)) terminate()
else parsePreamble(input, 0)
} else parsePreamble(input, 0)
} catch {
case NotEnoughDataException continue(input, 0)((newInput, _) tryParseInitialBoundary(newInput))
setHandlers(in, out, this)
/** Logs a warning about an illegal multipart header, but only if `illegalHeaderWarnings` is enabled. */
def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit = {
  if (illegalHeaderWarnings) {
    val message = errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty
    log.warning(message)
  }
}
/**
 * Initial parser state: checks whether the entity starts directly with the boundary.
 *
 * - boundary followed by CRLF: continue with the headers of the first part
 * - boundary followed by `--`: the entity is already finished, request termination
 * - anything else: treat the data as preamble and search for the first boundary
 *
 * If the input is too short to decide, the same check is retried once more data arrived.
 */
def tryParseInitialBoundary(input: ByteString): StateResult =
  // we don't use boyerMoore here because we are testing for the boundary *without* a
  // preceding CRLF and at a known location (the very beginning of the entity)
  try {
    if (boundary(input, 0)) {
      val ix = boundaryLength
      if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
      else if (doubleDash(input, ix)) setShouldTerminate()
      else parsePreamble(input, 0)
    } else parsePreamble(input, 0)
  } catch {
    // the case and lambda arrows were missing here, leaving the clause unparseable
    case NotEnoughDataException => continue(input, 0)((newInput, _) => tryParseInitialBoundary(newInput))
  }
/**
 * Skips the preamble by searching for the first occurrence of the needle at or after
 * `offset`.
 *
 * A needle followed by CRLF starts the headers of the first part, a needle followed by
 * `--` requests termination, and any other match is skipped. When the input runs out, only
 * the last `needle.length + 2` bytes are kept for the next attempt, since they may contain
 * the beginning of a boundary.
 */
def parsePreamble(input: ByteString, offset: Int): StateResult =
  try {
    @tailrec def rec(index: Int): StateResult = {
      val needleEnd = boyerMoore.nextIndex(input, index) + needle.length
      if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2)
      else if (doubleDash(input, needleEnd)) setShouldTerminate()
      else rec(needleEnd)
    }
    rec(offset)
  } catch {
    // the case arrow was missing here, leaving the clause unparseable
    case NotEnoughDataException => continue(input.takeRight(needle.length + 2), 0)(parsePreamble)
  }
/**
 * Parses the header lines of one body part, starting at `lineStart`.
 *
 * @param input       the bytes currently available
 * @param lineStart   index of the first byte of the line to parse
 * @param headers     headers collected so far for this part (appended to in place)
 * @param headerCount number of headers accepted so far, checked against `maxHeaderCount`
 * @param cth         the `Content-Type` header seen so far for this part, if any
 *
 * Outcomes per line:
 *  - not enough data: suspend and resume at the same line with the accumulated state
 *  - a boundary instead of a header: the part has no entity, emit it as empty and move on
 *  - the empty line: headers are done, switch to entity parsing
 *  - `Content-Type`: remembered separately; an identical repetition is ignored, a
 *    conflicting one fails the parse
 *  - any other header: collected until `maxHeaderCount` is reached, then the parse fails
 *
 * All `case`/lambda arrows are restored here; they were missing from the block.
 */
@tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](),
                              headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = {
  def contentType =
    cth match {
      case Some(x) => x.contentType
      case None    => defaultContentType
    }

  var lineEnd = 0
  val resultHeader =
    try {
      if (!boundary(input, lineStart)) {
        lineEnd = headerParser.parseHeaderLine(input, lineStart)()
        headerParser.resultHeader
      } else BoundaryHeader
    } catch {
      // null is the in-band marker for "need more input", handled first in the match below
      case NotEnoughDataException => null
    }
  resultHeader match {
    case null => continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth))

    case BoundaryHeader =>
      emit(BodyPartStart(headers.toList, _ => HttpEntity.empty(contentType)))
      val ix = lineStart + boundaryLength
      if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
      else if (doubleDash(input, ix)) setShouldTerminate()
      else fail("Illegal multipart boundary in message content")

    case EmptyHeader => parseEntity(headers.toList, contentType)(input, lineEnd)

    case h: `Content-Type` =>
      if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h))
      else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth)
      else fail("multipart part must not contain more than one Content-Type header")

    case h if headerCount < maxHeaderCount =>
      parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth)

    case _ => fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers")
  }
}
// work-around for compiler complaining about non-tail-recursion if we inline this method
// Curried adapter around `parseHeaderLines`: the first parameter list fixes the header state
// accumulated so far, the second one has the (input, lineStart) shape that `continue` expects
// for the next parser step.
def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int,
cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult =
parseHeaderLines(input, lineStart, headers, headerCount, cth)
/**
 * Parses the entity bytes of one body part, starting at `offset`, up to the next boundary.
 *
 * @param headers            headers of the current part
 * @param contentType        content type of the current part
 * @param emitPartChunk      called with a non-final chunk of entity data; by default it
 *                           emits a `BodyPartStart` with an `IndefiniteLength` entity built
 *                           from the following `EntityPart`s, then the chunk itself
 * @param emitFinalPartChunk called with the last chunk of entity data; by default it emits
 *                           a `BodyPartStart` with a `Strict` entity and kills the sub-source
 * @param input              the bytes currently available
 * @param offset             index of the first entity byte not yet emitted
 *
 * All function-type, lambda and `case` arrows are restored here; they were missing from
 * the block.
 */
def parseEntity(headers: List[HttpHeader], contentType: ContentType,
                emitPartChunk: (List[HttpHeader], ContentType, ByteString) => Unit = {
                  (headers, ct, bytes) =>
                    emit(BodyPartStart(headers, entityParts => HttpEntity.IndefiniteLength(ct,
                      entityParts.collect { case EntityPart(data) => data })))
                    emit(bytes)
                },
                emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) => Unit = {
                  (headers, ct, bytes) =>
                    emit(BodyPartStart(headers, { rest =>
                      SubSource.kill(rest)
                      HttpEntity.Strict(ct, bytes)
                    }))
                })(input: ByteString, offset: Int): StateResult =
  try {
    @tailrec def rec(index: Int): StateResult = {
      val currentPartEnd = boyerMoore.nextIndex(input, index)
      def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd))
      val needleEnd = currentPartEnd + needle.length
      if (crlf(input, needleEnd)) {
        emitFinalChunk()
        parseHeaderLines(input, needleEnd + 2)
      } else if (doubleDash(input, needleEnd)) {
        emitFinalChunk()
        setShouldTerminate()
      } else rec(needleEnd)
    }
    rec(offset)
  } catch {
    case NotEnoughDataException =>
      // we cannot emit all input bytes since the end of the input might be the start of the next boundary
      val emitEnd = input.length - needle.length - 2
      if (emitEnd > offset) {
        emitPartChunk(headers, contentType, input.slice(offset, emitEnd))
        // the part start has been emitted by now, so from here on only raw data chunks follow;
        // `simpleEmit` ignores headers and content type, which is why null is passed for them
        val simpleEmit: (List[HttpHeader], ContentType, ByteString) => Unit = (_, _, bytes) => emit(bytes)
        continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit))
      } else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk))
  }
/** Queues `bytes` as an `EntityPart`; empty chunks are dropped. */
def emit(bytes: ByteString): Unit = {
  if (bytes.nonEmpty) emit(EntityPart(bytes))
}

/** Appends `element` to the queue of pending outputs. */
def emit(element: Output): Unit = {
  output = output.enqueue(element)
}
/** Removes and returns the oldest pending output; callers check `output.nonEmpty` first. */
def dequeue(): Output = {
  val oldest = output.head
  val remaining = output.tail
  output = remaining
  oldest
}
/**
 * Suspends parsing until more data arrives and installs `next` as the step to resume with.
 *
 * If unconsumed bytes remain (`offset < input.length`) they are prepended to the next chunk
 * and parsing resumes at the same `offset`. If the input was consumed exactly, the next
 * chunk is parsed from index 0. An `offset` beyond the input is a programming error.
 *
 * The function-type, `case` and lambda arrows are restored here; they were missing from
 * the block.
 */
def continue(input: ByteString, offset: Int)(next: (ByteString, Int) => StateResult): StateResult = {
  state =
    math.signum(offset - input.length) match {
      case -1 => more => next(input ++ more, offset)
      case 0  => next(_, 0)
      case 1  => throw new IllegalStateException
    }
  done()
}

/** Suspends parsing with nothing left over: the next chunk is handed to `next` from index 0. */
def continue(next: (ByteString, Int) => StateResult): StateResult = {
  state = next(_, 0)
  done()
}
/** Convenience overload: wraps `summary` in an `ErrorInfo` and fails the parse with it. */
def fail(summary: String): StateResult = {
  val info = ErrorInfo(summary)
  fail(info)
}

/** Queues a `ParseError` carrying `info` and requests termination of the stage. */
def fail(info: ErrorInfo): StateResult = {
  val error = ParseError(info)
  emit(error)
  setShouldTerminate()
}
// Marks the parser as finished. The flag is only recorded here; `onPush` and `onPull`
// read it to decide when to complete the stage.
def setShouldTerminate(): StateResult = {
shouldTerminate = true
done()
}
// Common return value of every parser step; the value itself is never inspected.
def done(): StateResult = null // StateResult is a phantom type
// the length of the needle without the preceding CRLF
// (the needle's first two bytes are that CRLF, hence the `- 2`)
def boundaryLength = needle.length - 2
/**
 * Tests whether the boundary (the needle without its first two bytes) is present in `input`
 * at `offset`. `ix` is the needle index compared next; it starts at 2 so the needle's first
 * two bytes are skipped.
 */
@tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean =
  if (ix == needle.length) true
  else if (byteAt(input, offset + ix - 2) != needle(ix)) false
  else boundary(input, offset, ix + 1)
/** True if `input` contains `\r\n` at `offset`; the second byte is read only if the first matches. */
def crlf(input: ByteString, offset: Int): Boolean = {
  val startsWithCr = byteChar(input, offset) == '\r'
  startsWithCr && byteChar(input, offset + 1) == '\n'
}
/** True if `input` contains `--` at `offset`; the second byte is read only if the first matches. */
def doubleDash(input: ByteString, offset: Int): Boolean = {
  val firstIsDash = byteChar(input, offset) == '-'
  firstIsDash && byteChar(input, offset + 1) == '-'
}
}
def parsePreamble(input: ByteString, offset: Int): StateResult =
try {
@tailrec def rec(index: Int): StateResult = {
val needleEnd = boyerMoore.nextIndex(input, index) + needle.length
if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2)
else if (doubleDash(input, needleEnd)) terminate()
else rec(needleEnd)
}
rec(offset)
} catch {
case NotEnoughDataException continue(input.takeRight(needle.length + 2), 0)(parsePreamble)
}
@tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](),
headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = {
def contentType =
cth match {
case Some(x) x.contentType
case None defaultContentType
}
var lineEnd = 0
val resultHeader =
try {
if (!boundary(input, lineStart)) {
lineEnd = headerParser.parseHeaderLine(input, lineStart)()
headerParser.resultHeader
} else BoundaryHeader
} catch {
case NotEnoughDataException null
}
resultHeader match {
case null continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth))
case BoundaryHeader
emit(BodyPartStart(headers.toList, _ HttpEntity.empty(contentType)))
val ix = lineStart + boundaryLength
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
else if (doubleDash(input, ix)) terminate()
else fail("Illegal multipart boundary in message content")
case EmptyHeader parseEntity(headers.toList, contentType)(input, lineEnd)
case h: `Content-Type`
if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h))
else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth)
else fail("multipart part must not contain more than one Content-Type header")
case h if headerCount < maxHeaderCount
parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth)
case _ fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers")
}
}
// work-around for compiler complaining about non-tail-recursion if we inline this method
def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int,
cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult =
parseHeaderLines(input, lineStart, headers, headerCount, cth)
def parseEntity(headers: List[HttpHeader], contentType: ContentType,
emitPartChunk: (List[HttpHeader], ContentType, ByteString) Unit = {
(headers, ct, bytes)
emit(BodyPartStart(headers, entityParts HttpEntity.IndefiniteLength(ct,
entityParts.collect { case EntityPart(data) data })))
emit(bytes)
},
emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) Unit = {
(headers, ct, bytes)
emit(BodyPartStart(headers, { rest
SubSource.kill(rest)
HttpEntity.Strict(ct, bytes)
}))
})(input: ByteString, offset: Int): StateResult =
try {
@tailrec def rec(index: Int): StateResult = {
val currentPartEnd = boyerMoore.nextIndex(input, index)
def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd))
val needleEnd = currentPartEnd + needle.length
if (crlf(input, needleEnd)) {
emitFinalChunk()
parseHeaderLines(input, needleEnd + 2)
} else if (doubleDash(input, needleEnd)) {
emitFinalChunk()
terminate()
} else rec(needleEnd)
}
rec(offset)
} catch {
case NotEnoughDataException
// we cannot emit all input bytes since the end of the input might be the start of the next boundary
val emitEnd = input.length - needle.length - 2
if (emitEnd > offset) {
emitPartChunk(headers, contentType, input.slice(offset, emitEnd))
val simpleEmit: (List[HttpHeader], ContentType, ByteString) Unit = (_, _, bytes) emit(bytes)
continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit))
} else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk))
}
def emit(bytes: ByteString): Unit = if (bytes.nonEmpty) emit(EntityPart(bytes))
def emit(element: Output): Unit = output = output.enqueue(element)
def dequeue(): Output = {
val head = output.head
output = output.tail
head
}
def continue(input: ByteString, offset: Int)(next: (ByteString, Int) StateResult): StateResult = {
state =
math.signum(offset - input.length) match {
case -1 more next(input ++ more, offset)
case 0 next(_, 0)
case 1 throw new IllegalStateException
}
done()
}
def continue(next: (ByteString, Int) StateResult): StateResult = {
state = next(_, 0)
done()
}
def fail(summary: String): StateResult = fail(ErrorInfo(summary))
def fail(info: ErrorInfo): StateResult = {
emit(ParseError(info))
terminate()
}
def terminate(): StateResult = {
terminated = true
done()
}
def done(): StateResult = null // StateResult is a phantom type
// the length of the needle without the preceding CRLF
def boundaryLength = needle.length - 2
@tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean =
(ix == needle.length) || (byteAt(input, offset + ix - 2) == needle(ix)) && boundary(input, offset, ix + 1)
def crlf(input: ByteString, offset: Int): Boolean =
byteChar(input, offset) == '\r' && byteChar(input, offset + 1) == '\n'
def doubleDash(input: ByteString, offset: Int): Boolean =
byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-'
}
private[http] object BodyPartParser {

View file

@ -97,7 +97,7 @@ trait MultipartUnmarshallers {
createStrict(mediaType, builder.result())
case _
val bodyParts = entity.dataBytes
.transform(() parser)
.via(parser)
.splitWhen(_.isInstanceOf[PartStart])
.buffer(100, OverflowStrategy.backpressure) // FIXME remove (#19240)
.prefixAndTail(1)

View file

@ -8,21 +8,21 @@ import akka.util.ByteString
import akka.stream.stage._
import akka.stream.Supervision
class IteratorInterpreterSpec extends AkkaSpec {
class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
import Supervision.stoppingDecider
"IteratorInterpreter" must {
"work in the happy case" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
Map((x: Int) x + 1, stoppingDecider))).iterator
Map((x: Int) x + 1, stoppingDecider).toGS)).iterator
itr.toSeq should be(2 to 11)
}
"hasNext should not affect elements" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
Map((x: Int) x, stoppingDecider))).iterator
Map((x: Int) x, stoppingDecider).toGS)).iterator
itr.hasNext should be(true)
itr.hasNext should be(true)
@ -34,27 +34,27 @@ class IteratorInterpreterSpec extends AkkaSpec {
}
"work with ops that need extra pull for complete" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator
itr.toSeq should be(Seq(1))
}
"throw exceptions on empty iterator" in {
val itr = new IteratorInterpreter[Int, Int](List(1).iterator, Seq(
Map((x: Int) x, stoppingDecider))).iterator
Map((x: Int) x, stoppingDecider).toGS)).iterator
itr.next() should be(1)
a[NoSuchElementException] should be thrownBy { itr.next() }
}
"throw exceptions when chain fails" in {
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) ctx.fail(new ArithmeticException())
else ctx.push(elem)
}
})).iterator
val stage = new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) ctx.fail(new ArithmeticException())
else ctx.push(elem)
}
}
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
itr.next() should be(1)
itr.hasNext should be(true)
@ -63,13 +63,13 @@ class IteratorInterpreterSpec extends AkkaSpec {
}
"throw exceptions when op in chain throws" in {
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) throw new ArithmeticException()
else ctx.push(elem)
}
})).iterator
val stage = new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) throw new ArithmeticException()
else ctx.push(elem)
}
}
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
itr.next() should be(1)
itr.hasNext should be(true)
@ -79,7 +79,7 @@ class IteratorInterpreterSpec extends AkkaSpec {
"work with an empty iterator" in {
val itr = new IteratorInterpreter[Int, Int](Iterator.empty, Seq(
Map((x: Int) x + 1, stoppingDecider))).iterator
Map((x: Int) x + 1, stoppingDecider).toGS)).iterator
itr.hasNext should be(false)
a[NoSuchElementException] should be thrownBy { itr.next() }
@ -89,7 +89,8 @@ class IteratorInterpreterSpec extends AkkaSpec {
val testBytes = (1 to 10).map(ByteString(_))
def newItr(threshold: Int) =
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(ByteStringBatcher(threshold))).iterator
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(
ByteStringBatcher(threshold).toGS)).iterator
val itr1 = newItr(20)
itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
@ -108,7 +109,8 @@ class IteratorInterpreterSpec extends AkkaSpec {
itr3.hasNext should be(false)
val itr4 =
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(ByteStringBatcher(10))).iterator
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(
ByteStringBatcher(10).toGS)).iterator
itr4.hasNext should be(false)
}

View file

@ -99,7 +99,10 @@ private[akka] object IteratorInterpreter {
/**
* INTERNAL API
*/
private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[PushPullStage[_, _]]) {
private[akka] class IteratorInterpreter[I, O](
val input: Iterator[I],
val stages: Seq[GraphStageWithMaterializedValue[FlowShape[_, _], Any]]) {
import akka.stream.impl.fusing.IteratorInterpreter._
private val upstream = IteratorUpstream(input)
@ -109,31 +112,30 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S
import GraphInterpreter.Boundary
var i = 0
val length = ops.length
val attributes = Array.fill[Attributes](ops.length)(Attributes.none)
val length = stages.length
val attributes = Array.fill[Attributes](length)(Attributes.none)
val ins = Array.ofDim[Inlet[_]](length + 1)
val inOwners = Array.ofDim[Int](length + 1)
val outs = Array.ofDim[Outlet[_]](length + 1)
val outOwners = Array.ofDim[Int](length + 1)
val stages = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length)
val stagesArray = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length)
ins(ops.length) = null
inOwners(ops.length) = Boundary
ins(length) = null
inOwners(length) = Boundary
outs(0) = null
outOwners(0) = Boundary
val opsIterator = ops.iterator
while (opsIterator.hasNext) {
val op = opsIterator.next().asInstanceOf[Stage[Any, Any]]
val stage = new PushPullGraphStage((_) op, Attributes.none)
stages(i) = stage
val stagesIterator = stages.iterator
while (stagesIterator.hasNext) {
val stage = stagesIterator.next()
stagesArray(i) = stage
ins(i) = stage.shape.in
inOwners(i) = i
outs(i + 1) = stage.shape.out
outOwners(i + 1) = i
i += 1
}
val assembly = new GraphAssembly(stages, attributes, ins, inOwners, outs, outOwners)
val assembly = new GraphAssembly(stagesArray, attributes, ins, inOwners, outs, outOwners)
val (inHandlers, outHandlers, logics) =
assembly.materialize(Attributes.none, assembly.stages.map(_.module), new ju.HashMap, _ ())
@ -148,7 +150,7 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S
fuzzingMode = false,
null)
interpreter.attachUpstreamBoundary(0, upstream)
interpreter.attachDownstreamBoundary(ops.length, downstream)
interpreter.attachDownstreamBoundary(length, downstream)
interpreter.init(null)
}