diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala index 65f2106d71..342e5183e8 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/BodyPartParser.scala @@ -5,6 +5,7 @@ package akka.http.impl.engine.parsing import akka.NotUsed + import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.parboiled2.CharPredicate @@ -13,7 +14,9 @@ import akka.stream.stage._ import akka.util.ByteString import akka.http.scaladsl.model._ import akka.http.impl.util._ +import akka.stream.{ Attributes, FlowShape, Inlet, Outlet } import headers._ + import scala.collection.mutable.ListBuffer import akka.stream.impl.fusing.SubSource @@ -26,7 +29,7 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, boundary: String, log: LoggingAdapter, settings: BodyPartParser.Settings) - extends PushPullStage[ByteString, BodyPartParser.Output] { + extends GraphStage[FlowShape[ByteString, BodyPartParser.Output]] { import BodyPartParser._ import settings._ @@ -57,202 +60,214 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty) } - private[this] var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful - private[this] var state: ByteString ⇒ StateResult = tryParseInitialBoundary - private[this] var terminated = false + val in = Inlet[ByteString]("BodyPartParser.in") + val out = Outlet[BodyPartParser.Output]("BodyPartParser.out") - def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit = - if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty) + override val shape = FlowShape(in, out) - override def onPush(input: ByteString, ctx: Context[Output]): SyncDirective = - if (!terminated) { - try state(input) - catch { - case e: ParsingException ⇒ fail(e.info) - case NotEnoughDataException ⇒ - // we are missing a try/catch{continue} wrapper somewhere - throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) + override def createLogic(attributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private var output = collection.immutable.Queue.empty[Output] // FIXME this probably is too wasteful + private var state: ByteString => StateResult = tryParseInitialBoundary + private var shouldTerminate = false + + override def onPush(): Unit = { + if (!shouldTerminate) { + val elem = grab(in) + try state(elem) + catch { + case e: ParsingException => fail(e.info) + case NotEnoughDataException => + // we are missing a try/catch{continue} wrapper somewhere + throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) + } + if (output.nonEmpty) push(out, dequeue()) + else if (!shouldTerminate) pull(in) + else completeStage() + } else completeStage() } - if (output.nonEmpty) ctx.push(dequeue()) - else if (!terminated) ctx.pull() - else ctx.finish() - } else ctx.finish() - override def onPull(ctx: Context[Output]): SyncDirective = { - if (output.nonEmpty) - ctx.push(dequeue()) - else if (ctx.isFinishing) { - if (terminated) ctx.finish() - else ctx.pushAndFinish(ParseError(ErrorInfo("Unexpected end of multipart entity"))) - } else - ctx.pull() - } + override def onPull(): Unit = { + if (output.nonEmpty) push(out, dequeue()) + else if (isClosed(in)) { + if (!shouldTerminate) push(out, ParseError(ErrorInfo("Unexpected end of multipart entity"))) + completeStage() + } else pull(in) + } - override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective = ctx.absorbTermination() + override def onUpstreamFinish(): Unit = { + if (isAvailable(out)) onPull() + } - def tryParseInitialBoundary(input: ByteString): StateResult = - // we don't use boyerMoore here because we are testing for the boundary *without* a - // preceding CRLF and at a known location (the very beginning of the entity) - try { - if (boundary(input, 0)) { - val ix = boundaryLength - if (crlf(input, ix)) parseHeaderLines(input, ix + 2) - else if (doubleDash(input, ix)) terminate() - else parsePreamble(input, 0) - } else parsePreamble(input, 0) - } catch { - case NotEnoughDataException ⇒ continue(input, 0)((newInput, _) ⇒ tryParseInitialBoundary(newInput)) + setHandlers(in, out, this) + + def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit = + if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty) + + def tryParseInitialBoundary(input: ByteString): StateResult = + // we don't use boyerMoore here because we are testing for the boundary *without* a + // preceding CRLF and at a known location (the very beginning of the entity) + try { + if (boundary(input, 0)) { + val ix = boundaryLength + if (crlf(input, ix)) parseHeaderLines(input, ix + 2) + else if (doubleDash(input, ix)) setShouldTerminate() + else parsePreamble(input, 0) + } else parsePreamble(input, 0) + } catch { + case NotEnoughDataException ⇒ continue(input, 0)((newInput, _) ⇒ tryParseInitialBoundary(newInput)) + } + + def parsePreamble(input: ByteString, offset: Int): StateResult = + try { + @tailrec def rec(index: Int): StateResult = { + val needleEnd = boyerMoore.nextIndex(input, index) + needle.length + if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2) + else if (doubleDash(input, needleEnd)) setShouldTerminate() + else rec(needleEnd) + } + rec(offset) + } catch { + case NotEnoughDataException ⇒ continue(input.takeRight(needle.length + 2), 0)(parsePreamble) + } + + @tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](), + headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = { + def contentType = + cth match { + case Some(x) ⇒ x.contentType + case None ⇒ defaultContentType + } + + var lineEnd = 0 + val resultHeader = + try { + if (!boundary(input, lineStart)) { + lineEnd = headerParser.parseHeaderLine(input, lineStart)() + headerParser.resultHeader + } else BoundaryHeader + } catch { + case NotEnoughDataException ⇒ null + } + resultHeader match { + case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth)) + + case BoundaryHeader ⇒ + emit(BodyPartStart(headers.toList, _ ⇒ HttpEntity.empty(contentType))) + val ix = lineStart + boundaryLength + if (crlf(input, ix)) parseHeaderLines(input, ix + 2) + else if (doubleDash(input, ix)) setShouldTerminate() + else fail("Illegal multipart boundary in message content") + + case EmptyHeader ⇒ parseEntity(headers.toList, contentType)(input, lineEnd) + + case h: `Content-Type` ⇒ + if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h)) + else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth) + else fail("multipart part must not contain more than one Content-Type header") + + case h if headerCount < maxHeaderCount ⇒ + parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth) + + case _ ⇒ fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers") + } + } + + // work-around for compiler complaining about non-tail-recursion if we inline this method + def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int, + cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult = + parseHeaderLines(input, lineStart, headers, headerCount, cth) + + def parseEntity(headers: List[HttpHeader], contentType: ContentType, + emitPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = { + (headers, ct, bytes) ⇒ + emit(BodyPartStart(headers, entityParts ⇒ HttpEntity.IndefiniteLength(ct, + entityParts.collect { case EntityPart(data) ⇒ data }))) + emit(bytes) + }, + emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = { + (headers, ct, bytes) ⇒ + emit(BodyPartStart(headers, { rest ⇒ + SubSource.kill(rest) + HttpEntity.Strict(ct, bytes) + })) + })(input: ByteString, offset: Int): StateResult = + try { + @tailrec def rec(index: Int): StateResult = { + val currentPartEnd = boyerMoore.nextIndex(input, index) + def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd)) + val needleEnd = currentPartEnd + needle.length + if (crlf(input, needleEnd)) { + emitFinalChunk() + parseHeaderLines(input, needleEnd + 2) + } else if (doubleDash(input, needleEnd)) { + emitFinalChunk() + setShouldTerminate() + } else rec(needleEnd) + } + rec(offset) + } catch { + case NotEnoughDataException ⇒ + // we cannot emit all input bytes since the end of the input might be the start of the next boundary + val emitEnd = input.length - needle.length - 2 + if (emitEnd > offset) { + emitPartChunk(headers, contentType, input.slice(offset, emitEnd)) + val simpleEmit: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = (_, _, bytes) ⇒ emit(bytes) + continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit)) + } else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk)) + } + + def emit(bytes: ByteString): Unit = if (bytes.nonEmpty) emit(EntityPart(bytes)) + + def emit(element: Output): Unit = output = output.enqueue(element) + + def dequeue(): Output = { + val head = output.head + output = output.tail + head + } + + def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = { + state = + math.signum(offset - input.length) match { + case -1 ⇒ more ⇒ next(input ++ more, offset) + case 0 ⇒ next(_, 0) + case 1 ⇒ throw new IllegalStateException + } + done() + } + + def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = { + state = next(_, 0) + done() + } + + def fail(summary: String): StateResult = fail(ErrorInfo(summary)) + + def fail(info: ErrorInfo): StateResult = { + emit(ParseError(info)) + setShouldTerminate() + } + + def setShouldTerminate(): StateResult = { + shouldTerminate = true + done() + } + + def done(): StateResult = null // StateResult is a phantom type + + // the length of the needle without the preceding CRLF + def boundaryLength = needle.length - 2 + + @tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean = + (ix == needle.length) || (byteAt(input, offset + ix - 2) == needle(ix)) && boundary(input, offset, ix + 1) + + def crlf(input: ByteString, offset: Int): Boolean = + byteChar(input, offset) == '\r' && byteChar(input, offset + 1) == '\n' + + def doubleDash(input: ByteString, offset: Int): Boolean = + byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-' } - - def parsePreamble(input: ByteString, offset: Int): StateResult = - try { - @tailrec def rec(index: Int): StateResult = { - val needleEnd = boyerMoore.nextIndex(input, index) + needle.length - if (crlf(input, needleEnd)) parseHeaderLines(input, needleEnd + 2) - else if (doubleDash(input, needleEnd)) terminate() - else rec(needleEnd) - } - rec(offset) - } catch { - case NotEnoughDataException ⇒ continue(input.takeRight(needle.length + 2), 0)(parsePreamble) - } - - @tailrec def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](), - headerCount: Int = 0, cth: Option[`Content-Type`] = None): StateResult = { - def contentType = - cth match { - case Some(x) ⇒ x.contentType - case None ⇒ defaultContentType - } - - var lineEnd = 0 - val resultHeader = - try { - if (!boundary(input, lineStart)) { - lineEnd = headerParser.parseHeaderLine(input, lineStart)() - headerParser.resultHeader - } else BoundaryHeader - } catch { - case NotEnoughDataException ⇒ null - } - resultHeader match { - case null ⇒ continue(input, lineStart)(parseHeaderLinesAux(headers, headerCount, cth)) - - case BoundaryHeader ⇒ - emit(BodyPartStart(headers.toList, _ ⇒ HttpEntity.empty(contentType))) - val ix = lineStart + boundaryLength - if (crlf(input, ix)) parseHeaderLines(input, ix + 2) - else if (doubleDash(input, ix)) terminate() - else fail("Illegal multipart boundary in message content") - - case EmptyHeader ⇒ parseEntity(headers.toList, contentType)(input, lineEnd) - - case h: `Content-Type` ⇒ - if (cth.isEmpty) parseHeaderLines(input, lineEnd, headers, headerCount + 1, Some(h)) - else if (cth.get == h) parseHeaderLines(input, lineEnd, headers, headerCount, cth) - else fail("multipart part must not contain more than one Content-Type header") - - case h if headerCount < maxHeaderCount ⇒ - parseHeaderLines(input, lineEnd, headers += h, headerCount + 1, cth) - - case _ ⇒ fail(s"multipart part contains more than the configured limit of $maxHeaderCount headers") - } - } - - // work-around for compiler complaining about non-tail-recursion if we inline this method - def parseHeaderLinesAux(headers: ListBuffer[HttpHeader], headerCount: Int, - cth: Option[`Content-Type`])(input: ByteString, lineStart: Int): StateResult = - parseHeaderLines(input, lineStart, headers, headerCount, cth) - - def parseEntity(headers: List[HttpHeader], contentType: ContentType, - emitPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = { - (headers, ct, bytes) ⇒ - emit(BodyPartStart(headers, entityParts ⇒ HttpEntity.IndefiniteLength(ct, - entityParts.collect { case EntityPart(data) ⇒ data }))) - emit(bytes) - }, - emitFinalPartChunk: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = { - (headers, ct, bytes) ⇒ - emit(BodyPartStart(headers, { rest ⇒ - SubSource.kill(rest) - HttpEntity.Strict(ct, bytes) - })) - })(input: ByteString, offset: Int): StateResult = - try { - @tailrec def rec(index: Int): StateResult = { - val currentPartEnd = boyerMoore.nextIndex(input, index) - def emitFinalChunk() = emitFinalPartChunk(headers, contentType, input.slice(offset, currentPartEnd)) - val needleEnd = currentPartEnd + needle.length - if (crlf(input, needleEnd)) { - emitFinalChunk() - parseHeaderLines(input, needleEnd + 2) - } else if (doubleDash(input, needleEnd)) { - emitFinalChunk() - terminate() - } else rec(needleEnd) - } - rec(offset) - } catch { - case NotEnoughDataException ⇒ - // we cannot emit all input bytes since the end of the input might be the start of the next boundary - val emitEnd = input.length - needle.length - 2 - if (emitEnd > offset) { - emitPartChunk(headers, contentType, input.slice(offset, emitEnd)) - val simpleEmit: (List[HttpHeader], ContentType, ByteString) ⇒ Unit = (_, _, bytes) ⇒ emit(bytes) - continue(input drop emitEnd, 0)(parseEntity(null, null, simpleEmit, simpleEmit)) - } else continue(input, offset)(parseEntity(headers, contentType, emitPartChunk, emitFinalPartChunk)) - } - - def emit(bytes: ByteString): Unit = if (bytes.nonEmpty) emit(EntityPart(bytes)) - - def emit(element: Output): Unit = output = output.enqueue(element) - - def dequeue(): Output = { - val head = output.head - output = output.tail - head - } - - def continue(input: ByteString, offset: Int)(next: (ByteString, Int) ⇒ StateResult): StateResult = { - state = - math.signum(offset - input.length) match { - case -1 ⇒ more ⇒ next(input ++ more, offset) - case 0 ⇒ next(_, 0) - case 1 ⇒ throw new IllegalStateException - } - done() - } - - def continue(next: (ByteString, Int) ⇒ StateResult): StateResult = { - state = next(_, 0) - done() - } - - def fail(summary: String): StateResult = fail(ErrorInfo(summary)) - def fail(info: ErrorInfo): StateResult = { - emit(ParseError(info)) - terminate() - } - - def terminate(): StateResult = { - terminated = true - done() - } - - def done(): StateResult = null // StateResult is a phantom type - - // the length of the needle without the preceding CRLF - def boundaryLength = needle.length - 2 - - @tailrec def boundary(input: ByteString, offset: Int, ix: Int = 2): Boolean = - (ix == needle.length) || (byteAt(input, offset + ix - 2) == needle(ix)) && boundary(input, offset, ix + 1) - - def crlf(input: ByteString, offset: Int): Boolean = - byteChar(input, offset) == '\r' && byteChar(input, offset + 1) == '\n' - - def doubleDash(input: ByteString, offset: Int): Boolean = - byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-' - } private[http] object BodyPartParser { diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala index 38d072a91a..21429636ba 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala @@ -97,7 +97,7 @@ trait MultipartUnmarshallers { createStrict(mediaType, builder.result()) case _ ⇒ val bodyParts = entity.dataBytes - .transform(() ⇒ parser) + .via(parser) .splitWhen(_.isInstanceOf[PartStart]) .buffer(100, OverflowStrategy.backpressure) // FIXME remove (#19240) .prefixAndTail(1) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala index 6fccb27488..6363c34c37 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala @@ -8,21 +8,21 @@ import akka.util.ByteString import akka.stream.stage._ import akka.stream.Supervision -class IteratorInterpreterSpec extends AkkaSpec { +class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider "IteratorInterpreter" must { "work in the happy case" in { val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq( - Map((x: Int) ⇒ x + 1, stoppingDecider))).iterator + Map((x: Int) ⇒ x + 1, stoppingDecider).toGS)).iterator itr.toSeq should be(2 to 11) } "hasNext should not affect elements" in { val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq( - Map((x: Int) ⇒ x, stoppingDecider))).iterator + Map((x: Int) ⇒ x, stoppingDecider).toGS)).iterator itr.hasNext should be(true) itr.hasNext should be(true) @@ -34,27 +34,27 @@ class IteratorInterpreterSpec extends AkkaSpec { } "work with ops that need extra pull for complete" in { - val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator + val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator itr.toSeq should be(Seq(1)) } "throw exceptions on empty iterator" in { val itr = new IteratorInterpreter[Int, Int](List(1).iterator, Seq( - Map((x: Int) ⇒ x, stoppingDecider))).iterator + Map((x: Int) ⇒ x, stoppingDecider).toGS)).iterator itr.next() should be(1) a[NoSuchElementException] should be thrownBy { itr.next() } } "throw exceptions when chain fails" in { - val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( - new PushStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - if (elem == 2) ctx.fail(new ArithmeticException()) - else ctx.push(elem) - } - })).iterator + val stage = new PushStage[Int, Int] { + override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { + if (elem == 2) ctx.fail(new ArithmeticException()) + else ctx.push(elem) + } + } + val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator itr.next() should be(1) itr.hasNext should be(true) @@ -63,13 +63,13 @@ class IteratorInterpreterSpec extends AkkaSpec { } "throw exceptions when op in chain throws" in { - val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( - new PushStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - if (elem == 2) throw new ArithmeticException() - else ctx.push(elem) - } - })).iterator + val stage = new PushStage[Int, Int] { + override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { + if (elem == 2) throw new ArithmeticException() + else ctx.push(elem) + } + } + val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator itr.next() should be(1) itr.hasNext should be(true) @@ -79,7 +79,7 @@ class IteratorInterpreterSpec extends AkkaSpec { "work with an empty iterator" in { val itr = new IteratorInterpreter[Int, Int](Iterator.empty, Seq( - Map((x: Int) ⇒ x + 1, stoppingDecider))).iterator + Map((x: Int) ⇒ x + 1, stoppingDecider).toGS)).iterator itr.hasNext should be(false) a[NoSuchElementException] should be thrownBy { itr.next() } @@ -89,7 +89,8 @@ class IteratorInterpreterSpec extends AkkaSpec { val testBytes = (1 to 10).map(ByteString(_)) def newItr(threshold: Int) = - new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(ByteStringBatcher(threshold))).iterator + new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq( + ByteStringBatcher(threshold).toGS)).iterator val itr1 = newItr(20) itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) @@ -108,7 +109,8 @@ class IteratorInterpreterSpec extends AkkaSpec { itr3.hasNext should be(false) val itr4 = - new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(ByteStringBatcher(10))).iterator + new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq( + ByteStringBatcher(10).toGS)).iterator itr4.hasNext should be(false) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala index acd43f87c1..31be808ca3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -99,7 +99,10 @@ private[akka] object IteratorInterpreter { /** * INTERNAL API */ -private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[PushPullStage[_, _]]) { +private[akka] class IteratorInterpreter[I, O]( + val input: Iterator[I], + val stages: Seq[GraphStageWithMaterializedValue[FlowShape[_, _], Any]]) { + import akka.stream.impl.fusing.IteratorInterpreter._ private val upstream = IteratorUpstream(input) @@ -109,31 +112,30 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S import GraphInterpreter.Boundary var i = 0 - val length = ops.length - val attributes = Array.fill[Attributes](ops.length)(Attributes.none) + val length = stages.length + val attributes = Array.fill[Attributes](length)(Attributes.none) val ins = Array.ofDim[Inlet[_]](length + 1) val inOwners = Array.ofDim[Int](length + 1) val outs = Array.ofDim[Outlet[_]](length + 1) val outOwners = Array.ofDim[Int](length + 1) - val stages = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length) + val stagesArray = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length) - ins(ops.length) = null - inOwners(ops.length) = Boundary + ins(length) = null + inOwners(length) = Boundary outs(0) = null outOwners(0) = Boundary - val opsIterator = ops.iterator - while (opsIterator.hasNext) { - val op = opsIterator.next().asInstanceOf[Stage[Any, Any]] - val stage = new PushPullGraphStage((_) ⇒ op, Attributes.none) - stages(i) = stage + val stagesIterator = stages.iterator + while (stagesIterator.hasNext) { + val stage = stagesIterator.next() + stagesArray(i) = stage ins(i) = stage.shape.in inOwners(i) = i outs(i + 1) = stage.shape.out outOwners(i + 1) = i i += 1 } - val assembly = new GraphAssembly(stages, attributes, ins, inOwners, outs, outOwners) + val assembly = new GraphAssembly(stagesArray, attributes, ins, inOwners, outs, outOwners) val (inHandlers, outHandlers, logics) = assembly.materialize(Attributes.none, assembly.stages.map(_.module), new ju.HashMap, _ ⇒ ()) @@ -148,7 +150,7 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S fuzzingMode = false, null) interpreter.attachUpstreamBoundary(0, upstream) - interpreter.attachDownstreamBoundary(ops.length, downstream) + interpreter.attachDownstreamBoundary(length, downstream) interpreter.init(null) }