Migrate BodyPartParser to GraphStage #20288.
This commit is contained in:
parent
9af39f97ba
commit
c8352c04c4
4 changed files with 245 additions and 226 deletions
|
|
@ -5,6 +5,7 @@
|
|||
package akka.http.impl.engine.parsing
|
||||
|
||||
import akka.NotUsed
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import akka.event.LoggingAdapter
|
||||
import akka.parboiled2.CharPredicate
|
||||
|
|
@ -13,7 +14,9 @@ import akka.stream.stage._
|
|||
import akka.util.ByteString
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.impl.util._
|
||||
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
|
||||
import headers._
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import akka.stream.impl.fusing.SubSource
|
||||
|
||||
|
|
@ -26,7 +29,7 @@ private[http] final class BodyPartParser(defaultContentType: ContentType,
|
|||
boundary: String,
|
||||
log: LoggingAdapter,
|
||||
settings: BodyPartParser.Settings)
|
||||
extends PushPullStage[ByteString, BodyPartParser.Output] {
|
||||
extends GraphStage[FlowShape[ByteString, BodyPartParser.Output]] {
|
||||
import BodyPartParser._
|
||||
import settings._
|
||||
|
||||
|
|
@ -57,202 +60,214 @@ private[http] final class BodyPartParser(defaultContentType: ContentType,
|
|||
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
|
||||
}
|
||||
|
||||
private[this] var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful
|
||||
private[this] var state: ByteString ⇒ StateResult = tryParseInitialBoundary
|
||||
private[this] var terminated = false
|
||||
val in = Inlet[ByteString]("BodyPartParser.in")
|
||||
val out = Outlet[BodyPartParser.Output]("BodyPartParser.out")
|
||||
|
||||
def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit =
|
||||
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def onPush(input: ByteString, ctx: Context[Output]): SyncDirective =
|
||||
if (!terminated) {
|
||||
try state(input)
|
||||
catch {
|
||||
case e: ParsingException ⇒ fail(e.info)
|
||||
case NotEnoughDataException ⇒
|
||||
// we are missing a try/catch{continue} wrapper somewhere
|
||||
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
|
||||
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful
|
||||
private var state: ByteString => StateResult = tryParseInitialBoundary
|
||||
private var shouldTerminate = false
|
||||
|
||||
override def onPush(): Unit = {
|
||||
if (!shouldTerminate) {
|
||||
val elem = grab(in)
|
||||
try state(elem)
|
||||
catch {
|
||||
case e: ParsingException => fail(e.info)
|
||||
case NotEnoughDataException =>
|
||||
// we are missing a try/catch{continue} wrapper somewhere
|
||||
throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException)
|
||||
}
|
||||
if (output.nonEmpty) push(out, dequeue())
|
||||
else if (!shouldTerminate) pull(in)
|
||||
else completeStage()
|
||||
} else completeStage()
|
||||
}
|
||||
if (output.nonEmpty) ctx.push(dequeue())
|
||||
else if (!terminated) ctx.pull()
|
||||
else ctx.finish()
|
||||
} else ctx.finish()
|
||||
|
||||
override def onPull(ctx: Context[Output]): SyncDirective = {
|
||||
if (output.nonEmpty)
|
||||
ctx.push(dequeue())
|
||||
else if (ctx.isFinishing) {
|
||||
if (terminated) ctx.finish()
|
||||
else ctx.pushAndFinish(ParseError(ErrorInfo("Unexpected end of multipart entity")))
|
||||
} else
|
||||
ctx.pull()
|
||||
}
|
||||
override def onPull(): Unit = {
|
||||
if (output.nonEmpty) push(out, dequeue())
|
||||
else if (isClosed(in)) {
|
||||
if (!shouldTerminate) push(out, ParseError(ErrorInfo("Unexpected end of multipart entity")))
|
||||
completeStage()
|
||||
} else pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective = ctx.absorbTermination()
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (isAvailable(out)) onPull()
|
||||
}
|
||||
|
||||
def tryParseInitialBoundary(input: ByteString): StateResult =
|
||||
// we don't use boyerMoore here because we are testing for the boundary *without* a
|
||||
// preceding CRLF and at a known location (the very beginning of the entity)
|
||||
try {
|
||||
if (boundary(input, 0)) {
|
||||
val ix = boundaryLength
|
||||
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
|
||||
else if (doubleDash(input, ix)) terminate()
|
||||
else parsePreamble(input, 0)
|
||||
} else parsePreamble(input, 0)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ continue(input, 0)((newInput, _) ⇒ tryParseInitialBoundary(newInput))
|
||||
setHandlers(in, out, this)
|
||||
|
||||
def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit =
|
||||
if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty)
|
||||
|
||||
def tryParseInitialBoundary(input: ByteString): StateResult =
|
||||
// we don't use boyerMoore here because we are testing for the boundary *without* a
|
||||
// preceding CRLF and at a known location (the very beginning of the entity)
|
||||
try {
|
||||
if (boundary(input, 0)) {
|
||||
val ix = boundaryLength
|
||||
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
|
||||
else if (doubleDash(input, ix)) setShouldTerminate()
|
||||
else parsePreamble(input, 0)
|
||||
} else parsePreamble(input, 0)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ continue(input, 0)((newInput, _) ⇒ tryParseInitialBoundary(newInput))
|
||||
}
|
||||
|
||||
def parsePreamble(input: ByteString, offset: Int): StateResult =
|
||||
try {
|
||||
@tailrec def rec(index: Int): StateResult = {
|
||||
val needleEnd = boyerMoore.nextIndex(input, index) + needle.length
|
||||
if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2)
|
||||
else if (doubleDash(input, needleEnd)) setShouldTerminate()
|
||||
else rec(needleEnd)
|
||||
}
|
||||
rec(offset)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ continue(input.takeRight(needle.length + 2), 0)(parsePreamble)
|
||||
}
|
||||
|
||||
@tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](),
|
||||
headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = {
|
||||
def contentType =
|
||||
cth match {
|
||||
case Some(x) ⇒ x.contentType
|
||||
case None ⇒ defaultContentType
|
||||
}
|
||||
|
||||
var lineEnd = 0
|
||||
val resultHeader =
|
||||
try {
|
||||
if (!boundary(input, lineStart)) {
|
||||
lineEnd = headerParser.parseHeaderLine(input, lineStart)()
|
||||
headerParser.resultHeader
|
||||
} else BoundaryHeader
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ null
|
||||
}
|
||||
resultHeader match {
|
||||
case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth))
|
||||
|
||||
case BoundaryHeader ⇒
|
||||
emit(BodyPartStart(headers.toList, _ ⇒ HttpEntity.empty(contentType)))
|
||||
val ix = lineStart + boundaryLength
|
||||
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
|
||||
else if (doubleDash(input, ix)) setShouldTerminate()
|
||||
else fail("Illegal multipart boundary in message content")
|
||||
|
||||
case EmptyHeader ⇒ parseEntity(headers.toList, contentType)(input, lineEnd)
|
||||
|
||||
case h: `Content-Type` ⇒
|
||||
if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h))
|
||||
else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth)
|
||||
else fail("multipart part must not contain more than one Content-Type header")
|
||||
|
||||
case h if headerCount < maxHeaderCount ⇒
|
||||
parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth)
|
||||
|
||||
case _ ⇒ fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers")
|
||||
}
|
||||
}
|
||||
|
||||
// work-around for compiler complaining about non-tail-recursion if we inline this method
|
||||
def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int,
|
||||
cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult =
|
||||
parseHeaderLines(input, lineStart, headers, headerCount, cth)
|
||||
|
||||
def parseEntity(headers: List[HttpHeader], contentType: ContentType,
|
||||
emitPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = {
|
||||
(headers, ct, bytes) ⇒
|
||||
emit(BodyPartStart(headers, entityParts ⇒ HttpEntity.IndefiniteLength(ct,
|
||||
entityParts.collect { case EntityPart(data) ⇒ data })))
|
||||
emit(bytes)
|
||||
},
|
||||
emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = {
|
||||
(headers, ct, bytes) ⇒
|
||||
emit(BodyPartStart(headers, { rest ⇒
|
||||
SubSource.kill(rest)
|
||||
HttpEntity.Strict(ct, bytes)
|
||||
}))
|
||||
})(input: ByteString, offset: Int): StateResult =
|
||||
try {
|
||||
@tailrec def rec(index: Int): StateResult = {
|
||||
val currentPartEnd = boyerMoore.nextIndex(input, index)
|
||||
def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd))
|
||||
val needleEnd = currentPartEnd + needle.length
|
||||
if (crlf(input, needleEnd)) {
|
||||
emitFinalChunk()
|
||||
parseHeaderLines(input, needleEnd + 2)
|
||||
} else if (doubleDash(input, needleEnd)) {
|
||||
emitFinalChunk()
|
||||
setShouldTerminate()
|
||||
} else rec(needleEnd)
|
||||
}
|
||||
rec(offset)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒
|
||||
// we cannot emit all input bytes since the end of the input might be the start of the next boundary
|
||||
val emitEnd = input.length - needle.length - 2
|
||||
if (emitEnd > offset) {
|
||||
emitPartChunk(headers, contentType, input.slice(offset, emitEnd))
|
||||
val simpleEmit: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = (_, _, bytes) ⇒ emit(bytes)
|
||||
continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit))
|
||||
} else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk))
|
||||
}
|
||||
|
||||
def emit(bytes: ByteString): Unit = if (bytes.nonEmpty) emit(EntityPart(bytes))
|
||||
|
||||
def emit(element: Output): Unit = output = output.enqueue(element)
|
||||
|
||||
def dequeue(): Output = {
|
||||
val head = output.head
|
||||
output = output.tail
|
||||
head
|
||||
}
|
||||
|
||||
def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = {
|
||||
state =
|
||||
math.signum(offset - input.length) match {
|
||||
case -1 ⇒ more ⇒ next(input ++ more, offset)
|
||||
case 0 ⇒ next(_, 0)
|
||||
case 1 ⇒ throw new IllegalStateException
|
||||
}
|
||||
done()
|
||||
}
|
||||
|
||||
def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = {
|
||||
state = next(_, 0)
|
||||
done()
|
||||
}
|
||||
|
||||
def fail(summary: String): StateResult = fail(ErrorInfo(summary))
|
||||
|
||||
def fail(info: ErrorInfo): StateResult = {
|
||||
emit(ParseError(info))
|
||||
setShouldTerminate()
|
||||
}
|
||||
|
||||
def setShouldTerminate(): StateResult = {
|
||||
shouldTerminate = true
|
||||
done()
|
||||
}
|
||||
|
||||
def done(): StateResult = null // StateResult is a phantom type
|
||||
|
||||
// the length of the needle without the preceding CRLF
|
||||
def boundaryLength = needle.length - 2
|
||||
|
||||
@tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean =
|
||||
(ix == needle.length) || (byteAt(input, offset + ix - 2) == needle(ix)) && boundary(input, offset, ix + 1)
|
||||
|
||||
def crlf(input: ByteString, offset: Int): Boolean =
|
||||
byteChar(input, offset) == '\r' && byteChar(input, offset + 1) == '\n'
|
||||
|
||||
def doubleDash(input: ByteString, offset: Int): Boolean =
|
||||
byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-'
|
||||
}
|
||||
|
||||
def parsePreamble(input: ByteString, offset: Int): StateResult =
|
||||
try {
|
||||
@tailrec def rec(index: Int): StateResult = {
|
||||
val needleEnd = boyerMoore.nextIndex(input, index) + needle.length
|
||||
if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2)
|
||||
else if (doubleDash(input, needleEnd)) terminate()
|
||||
else rec(needleEnd)
|
||||
}
|
||||
rec(offset)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ continue(input.takeRight(needle.length + 2), 0)(parsePreamble)
|
||||
}
|
||||
|
||||
@tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](),
|
||||
headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = {
|
||||
def contentType =
|
||||
cth match {
|
||||
case Some(x) ⇒ x.contentType
|
||||
case None ⇒ defaultContentType
|
||||
}
|
||||
|
||||
var lineEnd = 0
|
||||
val resultHeader =
|
||||
try {
|
||||
if (!boundary(input, lineStart)) {
|
||||
lineEnd = headerParser.parseHeaderLine(input, lineStart)()
|
||||
headerParser.resultHeader
|
||||
} else BoundaryHeader
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒ null
|
||||
}
|
||||
resultHeader match {
|
||||
case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth))
|
||||
|
||||
case BoundaryHeader ⇒
|
||||
emit(BodyPartStart(headers.toList, _ ⇒ HttpEntity.empty(contentType)))
|
||||
val ix = lineStart + boundaryLength
|
||||
if (crlf(input, ix)) parseHeaderLines(input, ix + 2)
|
||||
else if (doubleDash(input, ix)) terminate()
|
||||
else fail("Illegal multipart boundary in message content")
|
||||
|
||||
case EmptyHeader ⇒ parseEntity(headers.toList, contentType)(input, lineEnd)
|
||||
|
||||
case h: `Content-Type` ⇒
|
||||
if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h))
|
||||
else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth)
|
||||
else fail("multipart part must not contain more than one Content-Type header")
|
||||
|
||||
case h if headerCount < maxHeaderCount ⇒
|
||||
parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth)
|
||||
|
||||
case _ ⇒ fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers")
|
||||
}
|
||||
}
|
||||
|
||||
// work-around for compiler complaining about non-tail-recursion if we inline this method
|
||||
def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int,
|
||||
cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult =
|
||||
parseHeaderLines(input, lineStart, headers, headerCount, cth)
|
||||
|
||||
def parseEntity(headers: List[HttpHeader], contentType: ContentType,
|
||||
emitPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = {
|
||||
(headers, ct, bytes) ⇒
|
||||
emit(BodyPartStart(headers, entityParts ⇒ HttpEntity.IndefiniteLength(ct,
|
||||
entityParts.collect { case EntityPart(data) ⇒ data })))
|
||||
emit(bytes)
|
||||
},
|
||||
emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = {
|
||||
(headers, ct, bytes) ⇒
|
||||
emit(BodyPartStart(headers, { rest ⇒
|
||||
SubSource.kill(rest)
|
||||
HttpEntity.Strict(ct, bytes)
|
||||
}))
|
||||
})(input: ByteString, offset: Int): StateResult =
|
||||
try {
|
||||
@tailrec def rec(index: Int): StateResult = {
|
||||
val currentPartEnd = boyerMoore.nextIndex(input, index)
|
||||
def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd))
|
||||
val needleEnd = currentPartEnd + needle.length
|
||||
if (crlf(input, needleEnd)) {
|
||||
emitFinalChunk()
|
||||
parseHeaderLines(input, needleEnd + 2)
|
||||
} else if (doubleDash(input, needleEnd)) {
|
||||
emitFinalChunk()
|
||||
terminate()
|
||||
} else rec(needleEnd)
|
||||
}
|
||||
rec(offset)
|
||||
} catch {
|
||||
case NotEnoughDataException ⇒
|
||||
// we cannot emit all input bytes since the end of the input might be the start of the next boundary
|
||||
val emitEnd = input.length - needle.length - 2
|
||||
if (emitEnd > offset) {
|
||||
emitPartChunk(headers, contentType, input.slice(offset, emitEnd))
|
||||
val simpleEmit: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = (_, _, bytes) ⇒ emit(bytes)
|
||||
continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit))
|
||||
} else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk))
|
||||
}
|
||||
|
||||
def emit(bytes: ByteString): Unit = if (bytes.nonEmpty) emit(EntityPart(bytes))
|
||||
|
||||
def emit(element: Output): Unit = output = output.enqueue(element)
|
||||
|
||||
def dequeue(): Output = {
|
||||
val head = output.head
|
||||
output = output.tail
|
||||
head
|
||||
}
|
||||
|
||||
def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = {
|
||||
state =
|
||||
math.signum(offset - input.length) match {
|
||||
case -1 ⇒ more ⇒ next(input ++ more, offset)
|
||||
case 0 ⇒ next(_, 0)
|
||||
case 1 ⇒ throw new IllegalStateException
|
||||
}
|
||||
done()
|
||||
}
|
||||
|
||||
def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = {
|
||||
state = next(_, 0)
|
||||
done()
|
||||
}
|
||||
|
||||
def fail(summary: String): StateResult = fail(ErrorInfo(summary))
|
||||
def fail(info: ErrorInfo): StateResult = {
|
||||
emit(ParseError(info))
|
||||
terminate()
|
||||
}
|
||||
|
||||
def terminate(): StateResult = {
|
||||
terminated = true
|
||||
done()
|
||||
}
|
||||
|
||||
def done(): StateResult = null // StateResult is a phantom type
|
||||
|
||||
// the length of the needle without the preceding CRLF
|
||||
def boundaryLength = needle.length - 2
|
||||
|
||||
@tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean =
|
||||
(ix == needle.length) || (byteAt(input, offset + ix - 2) == needle(ix)) && boundary(input, offset, ix + 1)
|
||||
|
||||
def crlf(input: ByteString, offset: Int): Boolean =
|
||||
byteChar(input, offset) == '\r' && byteChar(input, offset + 1) == '\n'
|
||||
|
||||
def doubleDash(input: ByteString, offset: Int): Boolean =
|
||||
byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-'
|
||||
|
||||
}
|
||||
|
||||
private[http] object BodyPartParser {
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ trait MultipartUnmarshallers {
|
|||
createStrict(mediaType, builder.result())
|
||||
case _ ⇒
|
||||
val bodyParts = entity.dataBytes
|
||||
.transform(() ⇒ parser)
|
||||
.via(parser)
|
||||
.splitWhen(_.isInstanceOf[PartStart])
|
||||
.buffer(100, OverflowStrategy.backpressure) // FIXME remove (#19240)
|
||||
.prefixAndTail(1)
|
||||
|
|
|
|||
|
|
@ -8,21 +8,21 @@ import akka.util.ByteString
|
|||
import akka.stream.stage._
|
||||
import akka.stream.Supervision
|
||||
|
||||
class IteratorInterpreterSpec extends AkkaSpec {
|
||||
class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
||||
import Supervision.stoppingDecider
|
||||
|
||||
"IteratorInterpreter" must {
|
||||
|
||||
"work in the happy case" in {
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
|
||||
Map((x: Int) ⇒ x + 1, stoppingDecider))).iterator
|
||||
Map((x: Int) ⇒ x + 1, stoppingDecider).toGS)).iterator
|
||||
|
||||
itr.toSeq should be(2 to 11)
|
||||
}
|
||||
|
||||
"hasNext should not affect elements" in {
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
|
||||
Map((x: Int) ⇒ x, stoppingDecider))).iterator
|
||||
Map((x: Int) ⇒ x, stoppingDecider).toGS)).iterator
|
||||
|
||||
itr.hasNext should be(true)
|
||||
itr.hasNext should be(true)
|
||||
|
|
@ -34,27 +34,27 @@ class IteratorInterpreterSpec extends AkkaSpec {
|
|||
}
|
||||
|
||||
"work with ops that need extra pull for complete" in {
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator
|
||||
|
||||
itr.toSeq should be(Seq(1))
|
||||
}
|
||||
|
||||
"throw exceptions on empty iterator" in {
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1).iterator, Seq(
|
||||
Map((x: Int) ⇒ x, stoppingDecider))).iterator
|
||||
Map((x: Int) ⇒ x, stoppingDecider).toGS)).iterator
|
||||
|
||||
itr.next() should be(1)
|
||||
a[NoSuchElementException] should be thrownBy { itr.next() }
|
||||
}
|
||||
|
||||
"throw exceptions when chain fails" in {
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
|
||||
new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) ctx.fail(new ArithmeticException())
|
||||
else ctx.push(elem)
|
||||
}
|
||||
})).iterator
|
||||
val stage = new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) ctx.fail(new ArithmeticException())
|
||||
else ctx.push(elem)
|
||||
}
|
||||
}
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
|
||||
|
||||
itr.next() should be(1)
|
||||
itr.hasNext should be(true)
|
||||
|
|
@ -63,13 +63,13 @@ class IteratorInterpreterSpec extends AkkaSpec {
|
|||
}
|
||||
|
||||
"throw exceptions when op in chain throws" in {
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
|
||||
new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) throw new ArithmeticException()
|
||||
else ctx.push(elem)
|
||||
}
|
||||
})).iterator
|
||||
val stage = new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) throw new ArithmeticException()
|
||||
else ctx.push(elem)
|
||||
}
|
||||
}
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
|
||||
|
||||
itr.next() should be(1)
|
||||
itr.hasNext should be(true)
|
||||
|
|
@ -79,7 +79,7 @@ class IteratorInterpreterSpec extends AkkaSpec {
|
|||
|
||||
"work with an empty iterator" in {
|
||||
val itr = new IteratorInterpreter[Int, Int](Iterator.empty, Seq(
|
||||
Map((x: Int) ⇒ x + 1, stoppingDecider))).iterator
|
||||
Map((x: Int) ⇒ x + 1, stoppingDecider).toGS)).iterator
|
||||
|
||||
itr.hasNext should be(false)
|
||||
a[NoSuchElementException] should be thrownBy { itr.next() }
|
||||
|
|
@ -89,7 +89,8 @@ class IteratorInterpreterSpec extends AkkaSpec {
|
|||
val testBytes = (1 to 10).map(ByteString(_))
|
||||
|
||||
def newItr(threshold: Int) =
|
||||
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(ByteStringBatcher(threshold))).iterator
|
||||
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(
|
||||
ByteStringBatcher(threshold).toGS)).iterator
|
||||
|
||||
val itr1 = newItr(20)
|
||||
itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
|
||||
|
|
@ -108,7 +109,8 @@ class IteratorInterpreterSpec extends AkkaSpec {
|
|||
itr3.hasNext should be(false)
|
||||
|
||||
val itr4 =
|
||||
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(ByteStringBatcher(10))).iterator
|
||||
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(
|
||||
ByteStringBatcher(10).toGS)).iterator
|
||||
|
||||
itr4.hasNext should be(false)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,7 +99,10 @@ private[akka] object IteratorInterpreter {
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[PushPullStage[_, _]]) {
|
||||
private[akka] class IteratorInterpreter[I, O](
|
||||
val input: Iterator[I],
|
||||
val stages: Seq[GraphStageWithMaterializedValue[FlowShape[_, _], Any]]) {
|
||||
|
||||
import akka.stream.impl.fusing.IteratorInterpreter._
|
||||
|
||||
private val upstream = IteratorUpstream(input)
|
||||
|
|
@ -109,31 +112,30 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S
|
|||
import GraphInterpreter.Boundary
|
||||
|
||||
var i = 0
|
||||
val length = ops.length
|
||||
val attributes = Array.fill[Attributes](ops.length)(Attributes.none)
|
||||
val length = stages.length
|
||||
val attributes = Array.fill[Attributes](length)(Attributes.none)
|
||||
val ins = Array.ofDim[Inlet[_]](length + 1)
|
||||
val inOwners = Array.ofDim[Int](length + 1)
|
||||
val outs = Array.ofDim[Outlet[_]](length + 1)
|
||||
val outOwners = Array.ofDim[Int](length + 1)
|
||||
val stages = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length)
|
||||
val stagesArray = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length)
|
||||
|
||||
ins(ops.length) = null
|
||||
inOwners(ops.length) = Boundary
|
||||
ins(length) = null
|
||||
inOwners(length) = Boundary
|
||||
outs(0) = null
|
||||
outOwners(0) = Boundary
|
||||
|
||||
val opsIterator = ops.iterator
|
||||
while (opsIterator.hasNext) {
|
||||
val op = opsIterator.next().asInstanceOf[Stage[Any, Any]]
|
||||
val stage = new PushPullGraphStage((_) ⇒ op, Attributes.none)
|
||||
stages(i) = stage
|
||||
val stagesIterator = stages.iterator
|
||||
while (stagesIterator.hasNext) {
|
||||
val stage = stagesIterator.next()
|
||||
stagesArray(i) = stage
|
||||
ins(i) = stage.shape.in
|
||||
inOwners(i) = i
|
||||
outs(i + 1) = stage.shape.out
|
||||
outOwners(i + 1) = i
|
||||
i += 1
|
||||
}
|
||||
val assembly = new GraphAssembly(stages, attributes, ins, inOwners, outs, outOwners)
|
||||
val assembly = new GraphAssembly(stagesArray, attributes, ins, inOwners, outs, outOwners)
|
||||
|
||||
val (inHandlers, outHandlers, logics) =
|
||||
assembly.materialize(Attributes.none, assembly.stages.map(_.module), new ju.HashMap, _ ⇒ ())
|
||||
|
|
@ -148,7 +150,7 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S
|
|||
fuzzingMode = false,
|
||||
null)
|
||||
interpreter.attachUpstreamBoundary(0, upstream)
|
||||
interpreter.attachDownstreamBoundary(ops.length, downstream)
|
||||
interpreter.attachDownstreamBoundary(length, downstream)
|
||||
interpreter.init(null)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue