=str #19361 migrating ByteStringParserStage to graph stage

This commit is contained in:
Alexander Golubev 2016-01-15 18:18:17 -05:00
parent 3fc332d2c9
commit 07c0da36f2
14 changed files with 184 additions and 267 deletions

View file

@ -46,7 +46,7 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
} }
object ReadFrameHeader extends Step { object ReadFrameHeader extends Step {
override def parse(reader: ByteReader): (FrameEvent, Step) = { override def parse(reader: ByteReader): ParseResult[FrameEvent] = {
import Protocol._ import Protocol._
val flagsAndOp = reader.readByte() val flagsAndOp = reader.readByte()
@ -83,23 +83,25 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
val takeNow = (header.length min reader.remainingSize).toInt val takeNow = (header.length min reader.remainingSize).toInt
val thisFrameData = reader.take(takeNow) val thisFrameData = reader.take(takeNow)
val noMoreData = thisFrameData.length == length
val nextState = val nextState =
if (thisFrameData.length == length) ReadFrameHeader if (noMoreData) ReadFrameHeader
else new ReadData(length - thisFrameData.length) else new ReadData(length - thisFrameData.length)
(FrameStart(header, thisFrameData.compact), nextState) ParseResult(Some(FrameStart(header, thisFrameData.compact)), nextState, true)
} }
} }
class ReadData(_remaining: Long) extends Step { class ReadData(_remaining: Long) extends Step {
override def canWorkWithPartialData = true
var remaining = _remaining var remaining = _remaining
override def parse(reader: ByteReader): (FrameEvent, Step) = override def parse(reader: ByteReader): ParseResult[FrameEvent] =
if (reader.remainingSize < remaining) { if (reader.remainingSize < remaining) {
remaining -= reader.remainingSize remaining -= reader.remainingSize
(FrameData(reader.takeAll(), lastPart = false), this) ParseResult(Some(FrameData(reader.takeAll(), lastPart = false)), this, true)
} else { } else {
(FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader) ParseResult(Some(FrameData(reader.take(remaining.toInt), lastPart = true)), ReadFrameHeader, true)
} }
} }
} }

View file

@ -1,51 +0,0 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.util
import scala.util.control.NoStackTrace
import akka.util.ByteString
/**
* A helper class to read from a ByteString statefully.
*
* INTERNAL API
*/
private[akka] class ByteReader(input: ByteString) {
import ByteReader.NeedMoreData
private[this] var off = 0
def hasRemaining: Boolean = off < input.size
def currentOffset: Int = off
def remainingData: ByteString = input.drop(off)
def fromStartToHere: ByteString = input.take(off)
def readByte(): Int =
if (off < input.length) {
val x = input(off)
off += 1
x & 0xFF
} else throw NeedMoreData
def readShortLE(): Int = readByte() | (readByte() << 8)
def readIntLE(): Int = readShortLE() | (readShortLE() << 16)
def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32)
def readShortBE(): Int = (readByte() << 8) | readByte()
def readIntBE(): Int = (readShortBE() << 16) | readShortBE()
def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL)
def skip(numBytes: Int): Unit =
if (off + numBytes <= input.length) off += numBytes
else throw NeedMoreData
def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
}
/*
* INTERNAL API
*/
private[akka] object ByteReader {
val NeedMoreData = new Exception with NoStackTrace
}

View file

@ -1,59 +0,0 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.util
import akka.stream.stage.{ Context, StatefulStage }
import akka.util.ByteString
import akka.stream.stage.SyncDirective
/**
* A helper class for writing parsers from ByteStrings.
*
* FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
*
* INTERNAL API
*/
private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
protected def onTruncation(ctx: Context[Out]): SyncDirective
/**
* Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of
* `ctx.pull()` to have truncation errors reported.
*/
abstract class IntermediateState extends State {
override def onPull(ctx: Context[Out]): SyncDirective = pull(ctx)
def pull(ctx: Context[Out]): SyncDirective =
if (ctx.isFinishing) onTruncation(ctx)
else ctx.pull()
}
/**
* A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
* occurs the previous data is saved and the reading process is restarted from the beginning
* once more data was received.
*
* As [[read]] may be called several times for the same prefix of data, make sure not to
* manipulate any state during reading from the ByteReader.
*/
private[akka] trait ByteReadingState extends IntermediateState {
def read(reader: ByteReader, ctx: Context[Out]): SyncDirective
def onPush(data: ByteString, ctx: Context[Out]): SyncDirective =
try {
val reader = new ByteReader(data)
read(reader, ctx)
} catch {
case ByteReader.NeedMoreData
become(TryAgain(data, this))
pull(ctx)
}
}
private case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
def onPush(data: ByteString, ctx: Context[Out]): SyncDirective = {
become(byteReadingState)
byteReadingState.onPush(previousData ++ data, ctx)
}
}
}

View file

@ -10,6 +10,7 @@ import akka.NotUsed
import akka.http.scaladsl.model.RequestEntity import akka.http.scaladsl.model.RequestEntity
import akka.stream._ import akka.stream._
import akka.stream.impl.StreamLayout.Module import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule } import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.stage._ import akka.stream.stage._
@ -114,40 +115,46 @@ private[http] object StreamUtils {
Flow[ByteString].transform(() transformer).named("sliceBytes") Flow[ByteString].transform(() transformer).named("sliceBytes")
} }
def limitByteChunksStage(maxBytesPerChunk: Int): PushPullStage[ByteString, ByteString] = def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
new StatefulStage[ByteString, ByteString] { new SimpleLinearGraphStage[ByteString] {
def initial = WaitingForData override def initialAttributes = Attributes.name("limitByteChunksStage")
var remaining = ByteString.empty
case object WaitingForData extends State { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
if (elem.size <= maxBytesPerChunk) ctx.push(elem)
else {
become(DeliveringData(elem.drop(maxBytesPerChunk)))
ctx.push(elem.take(maxBytesPerChunk))
}
}
case class DeliveringData(remaining: ByteString) extends State { def splitAndPush(elem: ByteString): Unit = {
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
throw new IllegalStateException("Not expecting data")
override def onPull(ctx: Context[ByteString]): SyncDirective = {
val toPush = remaining.take(maxBytesPerChunk) val toPush = remaining.take(maxBytesPerChunk)
val toKeep = remaining.drop(maxBytesPerChunk) val toKeep = remaining.drop(maxBytesPerChunk)
push(out, toPush)
remaining = toKeep
}
setHandlers(in, out, WaitingForData)
become { case object WaitingForData extends InHandler with OutHandler {
if (toKeep.isEmpty) WaitingForData override def onPush(): Unit = {
else DeliveringData(toKeep) val elem = grab(in)
if (elem.size <= maxBytesPerChunk) push(out, elem)
else {
splitAndPush(elem)
setHandlers(in, out, DeliveringData)
} }
if (ctx.isFinishing) ctx.pushAndFinish(toPush)
else ctx.push(toPush)
} }
override def onPull(): Unit = pull(in)
} }
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = case object DeliveringData extends InHandler() with OutHandler {
current match { var finishing = false
case WaitingForData ctx.finish() override def onPush(): Unit = throw new IllegalStateException("Not expecting data")
case _: DeliveringData ctx.absorbTermination() override def onPull(): Unit = {
splitAndPush(remaining)
if (remaining.isEmpty) {
if (finishing) completeStage() else setHandlers(in, out, WaitingForData)
}
}
override def onUpstreamFinish(): Unit = if (remaining.isEmpty) completeStage() else finishing = true
}
override def toString = "limitByteChunksStage"
} }
} }

View file

@ -4,10 +4,13 @@
package akka.http.scaladsl.coding package akka.http.scaladsl.coding
import akka.stream.{ Attributes, FlowShape }
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import scala.concurrent.duration._ import scala.concurrent.duration._
import org.scalatest.WordSpec import org.scalatest.WordSpec
import akka.util.ByteString import akka.util.ByteString
import akka.stream.stage.{ SyncDirective, Context, PushStage, Stage } import akka.stream.stage._
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.http.impl.util._ import akka.http.impl.util._
import headers._ import headers._
@ -34,10 +37,17 @@ class DecoderSpec extends WordSpec with CodecSpecSupport {
case object DummyDecoder extends StreamDecoder { case object DummyDecoder extends StreamDecoder {
val encoding = HttpEncodings.compress val encoding = HttpEncodings.compress
def newDecompressorStage(maxBytesPerChunk: Int): () Stage[ByteString, ByteString] = override def newDecompressorStage(maxBytesPerChunk: Int): () GraphStage[FlowShape[ByteString, ByteString]] =
() new PushStage[ByteString, ByteString] { () new SimpleLinearGraphStage[ByteString] {
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
ctx.push(elem ++ ByteString("compressed")) setHandler(in, new InHandler {
override def onPush(): Unit = push(out, grab(in) ++ ByteString("compressed"))
})
setHandler(out, new OutHandler {
override def onPull(): Unit = pull(in)
})
} }
} }
} }
}

View file

@ -33,7 +33,6 @@ class GzipSpec extends CoderSpec {
} }
"throw an error if compressed data is just missing the trailer at the end" in { "throw an error if compressed data is just missing the trailer at the end" in {
def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8")) def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8"))
val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl")) val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl"))
ex.getCause.getMessage should equal("Truncated GZIP stream") ex.getCause.getMessage should equal("Truncated GZIP stream")
} }

View file

@ -6,8 +6,8 @@ package akka.http.scaladsl.coding
import akka.NotUsed import akka.NotUsed
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.stream.Materializer import akka.stream.{ FlowShape, Materializer }
import akka.stream.stage.Stage import akka.stream.stage.{ GraphStage, Stage }
import akka.util.ByteString import akka.util.ByteString
import headers.HttpEncoding import headers.HttpEncoding
import akka.stream.scaladsl.{ Sink, Source, Flow } import akka.stream.scaladsl.{ Sink, Source, Flow }
@ -37,7 +37,7 @@ object Decoder {
/** A decoder that is implemented in terms of a [[Stage]] */ /** A decoder that is implemented in terms of a [[Stage]] */
trait StreamDecoder extends Decoder { outer trait StreamDecoder extends Decoder { outer
protected def newDecompressorStage(maxBytesPerChunk: Int): () Stage[ByteString, ByteString] protected def newDecompressorStage(maxBytesPerChunk: Int): () GraphStage[FlowShape[ByteString, ByteString]]
def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder = def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
@ -45,11 +45,11 @@ trait StreamDecoder extends Decoder { outer ⇒
def encoding: HttpEncoding = outer.encoding def encoding: HttpEncoding = outer.encoding
override def maxBytesPerChunk: Int = newMaxBytesPerChunk override def maxBytesPerChunk: Int = newMaxBytesPerChunk
def newDecompressorStage(maxBytesPerChunk: Int): () Stage[ByteString, ByteString] = def newDecompressorStage(maxBytesPerChunk: Int): () GraphStage[FlowShape[ByteString, ByteString]] =
outer.newDecompressorStage(maxBytesPerChunk) outer.newDecompressorStage(maxBytesPerChunk)
} }
def decoderFlow: Flow[ByteString, ByteString, NotUsed] = def decoderFlow: Flow[ByteString, ByteString, NotUsed] =
Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk)) Flow.fromGraph(newDecompressorStage(maxBytesPerChunk)())
} }

View file

@ -5,11 +5,12 @@
package akka.http.scaladsl.coding package akka.http.scaladsl.coding
import java.util.zip.{ Inflater, Deflater } import java.util.zip.{ Inflater, Deflater }
import akka.stream.stage._ import akka.stream.Attributes
import akka.stream.io.ByteStringParser
import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
import akka.util.{ ByteStringBuilder, ByteString } import akka.util.{ ByteStringBuilder, ByteString }
import scala.annotation.tailrec import scala.annotation.tailrec
import akka.http.impl.util._
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.HttpEncodings import akka.http.scaladsl.model.headers.HttpEncodings
@ -86,56 +87,49 @@ private[http] object DeflateCompressor {
} }
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
protected def createInflater() = new Inflater()
def initial: State = StartInflate override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
def afterInflate: State = StartInflate override val inflater: Inflater = new Inflater()
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {} override val inflateState = new Inflate(true) {
protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.finish() override def onTruncation(): Unit = completeStage()
} }
abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] { override def afterInflate = inflateState
protected def createInflater(): Inflater override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
val inflater = createInflater()
protected def afterInflate: State startWith(inflateState)
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
/** Start inflating */
case object StartInflate extends IntermediateState {
def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = {
require(inflater.needsInput())
inflater.setInput(data.toArray)
becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
} }
} }
/** Inflate */ abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
case class Inflate()(data: ByteString) extends IntermediateState { extends ByteStringParser[ByteString] {
override def onPull(ctx: Context[ByteString]): SyncDirective = {
abstract class DecompressorParsingLogic extends ParsingLogic {
val inflater: Inflater
def afterInflate: ParseStep[ByteString]
def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
val inflateState: Inflate
abstract class Inflate(noPostProcessing: Boolean) extends ParseStep[ByteString] {
override def canWorkWithPartialData = true
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
inflater.setInput(reader.remainingData.toArray)
val buffer = new Array[Byte](maxBytesPerChunk) val buffer = new Array[Byte](maxBytesPerChunk)
val read = inflater.inflate(buffer) val read = inflater.inflate(buffer)
reader.skip(reader.remainingSize - inflater.getRemaining)
if (read > 0) { if (read > 0) {
afterBytesRead(buffer, 0, read) afterBytesRead(buffer, 0, read)
ctx.push(ByteString.fromArray(buffer, 0, read)) val next = if (inflater.finished()) afterInflate else this
ParseResult(Some(ByteString.fromArray(buffer, 0, read)), next, noPostProcessing)
} else { } else {
val remaining = data.takeRight(inflater.getRemaining) if (inflater.finished()) ParseResult(None, afterInflate, noPostProcessing)
val next = else throw ByteStringParser.NeedMoreData
if (inflater.finished()) afterInflate }
else StartInflate
becomeWithRemaining(next, remaining, ctx)
} }
} }
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
throw new IllegalStateException("Don't expect a new Element")
}
def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
become(next)
if (remaining.isEmpty) current.onPull(ctx)
else current.onPush(remaining, ctx)
} }
} }

View file

@ -4,14 +4,15 @@
package akka.http.scaladsl.coding package akka.http.scaladsl.coding
import akka.util.ByteString import java.util.zip.{ CRC32, Deflater, Inflater, ZipException }
import akka.stream.stage._
import akka.http.impl.util.ByteReader
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
import akka.http.impl.engine.ws.{ ProtocolException, FrameEvent }
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import headers.HttpEncodings import akka.http.scaladsl.model.headers.HttpEncodings
import akka.stream.Attributes
import akka.stream.io.ByteStringParser
import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
import akka.util.ByteString
class Gzip(val messageFilter: HttpMessage Boolean) extends Coder with StreamDecoder { class Gzip(val messageFilter: HttpMessage Boolean) extends Coder with StreamDecoder {
val encoding = HttpEncodings.gzip val encoding = HttpEncodings.gzip
@ -60,30 +61,22 @@ class GzipCompressor extends DeflateCompressor {
} }
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
protected def createInflater(): Inflater = new Inflater(true) override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
override val inflater: Inflater = new Inflater(true)
override def afterInflate: ParseStep[ByteString] = ReadTrailer
override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
crc32.update(buffer, offset, length)
def initial: State = Initial trait Step extends ParseStep[ByteString] {
override def onTruncation(): Unit = failStage(new ZipException("Truncated GZIP stream"))
/** No bytes were received yet */
case object Initial extends State {
def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective =
if (data.isEmpty) ctx.pull()
else becomeWithRemaining(ReadHeaders, data, ctx)
override def onPull(ctx: Context[ByteString]): SyncDirective =
if (ctx.isFinishing) {
ctx.finish()
} else super.onPull(ctx)
} }
override val inflateState = new Inflate(false) with Step
var crc32: CRC32 = new CRC32 startWith(ReadHeaders)
protected def afterInflate: State = ReadTrailer
/** Reading the header bytes */ /** Reading the header bytes */
case object ReadHeaders extends ByteReadingState { case object ReadHeaders extends Step {
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = { override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
import reader._ import reader._
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
val flags = readByte() val flags = readByte()
@ -95,36 +88,28 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
inflater.reset() inflater.reset()
crc32.reset() crc32.reset()
becomeWithRemaining(StartInflate, remainingData, ctx) ParseResult(None, inflateState, false)
} }
} }
var crc32: CRC32 = new CRC32
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = private def fail(msg: String) = throw new ZipException(msg)
crc32.update(buffer, offset, length)
/** Reading the trailer */ /** Reading the trailer */
case object ReadTrailer extends ByteReadingState { case object ReadTrailer extends Step {
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = { override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
import reader._ import reader._
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ )
fail("Corrupt GZIP trailer ISIZE")
becomeWithRemaining(Initial, remainingData, ctx) ParseResult(None, ReadHeaders, true)
}
} }
} }
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
private def crc16(data: ByteString) = { private def crc16(data: ByteString) = {
val crc = new CRC32 val crc = new CRC32
crc.update(data.toArray) crc.update(data.toArray)
crc.getValue.toInt & 0xFFFF crc.getValue.toInt & 0xFFFF
} }
override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream"))
private def fail(msg: String) = throw new ZipException(msg)
} }
/** INTERNAL API */ /** INTERNAL API */

View file

@ -6,7 +6,8 @@ package akka.http.scaladsl.coding
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.http.impl.util.StreamUtils import akka.http.impl.util.StreamUtils
import akka.stream.stage.Stage import akka.stream.FlowShape
import akka.stream.stage.{ GraphStage, Stage }
import akka.util.ByteString import akka.util.ByteString
import headers.HttpEncodings import headers.HttpEncodings
@ -25,7 +26,7 @@ object NoCoding extends Coder with StreamDecoder {
def newCompressor = NoCodingCompressor def newCompressor = NoCodingCompressor
def newDecompressorStage(maxBytesPerChunk: Int): () Stage[ByteString, ByteString] = def newDecompressorStage(maxBytesPerChunk: Int): () GraphStage[FlowShape[ByteString, ByteString]] =
() StreamUtils.limitByteChunksStage(maxBytesPerChunk) () StreamUtils.limitByteChunksStage(maxBytesPerChunk)
} }

View file

@ -43,7 +43,7 @@ object GraphStages {
/** /**
* INERNAL API * INERNAL API
*/ */
private[stream] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] { private[akka] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T](Logging.simpleName(this) + ".in") val in = Inlet[T](Logging.simpleName(this) + ".in")
val out = Outlet[T](Logging.simpleName(this) + ".out") val out = Outlet[T](Logging.simpleName(this) + ".out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)

View file

@ -19,30 +19,35 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
final override val shape = FlowShape(bytesIn, objOut) final override val shape = FlowShape(bytesIn, objOut)
class ParsingLogic extends GraphStageLogic(shape) { class ParsingLogic extends GraphStageLogic(shape) {
var pullOnParserRequest = false
override def preStart(): Unit = pull(bytesIn) override def preStart(): Unit = pull(bytesIn)
setHandler(objOut, eagerTerminateOutput) setHandler(objOut, eagerTerminateOutput)
private var buffer = ByteString.empty private var buffer = ByteString.empty
private var current: ParseStep[T] = FinishedParser private var current: ParseStep[T] = FinishedParser
private var acceptUpstreamFinish: Boolean = true
final protected def startWith(step: ParseStep[T]): Unit = current = step final protected def startWith(step: ParseStep[T]): Unit = current = step
@tailrec private def doParse(): Unit = @tailrec private def doParse(): Unit =
if (buffer.nonEmpty) { if (buffer.nonEmpty) {
val cont = try {
val reader = new ByteReader(buffer) val reader = new ByteReader(buffer)
val (elem, next) = current.parse(reader) val cont = try {
emit(objOut, elem) val parseResult = current.parse(reader)
if (next == FinishedParser) { acceptUpstreamFinish = parseResult.acceptUpstreamFinish
parseResult.result.map(emit(objOut, _))
if (parseResult.nextStep == FinishedParser) {
completeStage() completeStage()
false false
} else { } else {
buffer = reader.remainingData buffer = reader.remainingData
current = next current = parseResult.nextStep
true true
} }
} catch { } catch {
case NeedMoreData case NeedMoreData
acceptUpstreamFinish = false
if (current.canWorkWithPartialData) buffer = reader.remainingData
pull(bytesIn) pull(bytesIn)
false false
} }
@ -51,11 +56,12 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
setHandler(bytesIn, new InHandler { setHandler(bytesIn, new InHandler {
override def onPush(): Unit = { override def onPush(): Unit = {
pullOnParserRequest = false
buffer ++= grab(bytesIn) buffer ++= grab(bytesIn)
doParse() doParse()
} }
override def onUpstreamFinish(): Unit = override def onUpstreamFinish(): Unit =
if (buffer.isEmpty) completeStage() if (buffer.isEmpty && acceptUpstreamFinish) completeStage()
else current.onTruncation() else current.onTruncation()
}) })
} }
@ -63,13 +69,28 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
object ByteStringParser { object ByteStringParser {
/**
* @param result - parser can return some element for downstream or return None if no element was generated
* @param nextStep - next parser
* @param acceptUpstreamFinish - if true - stream will complete when received `onUpstreamFinish`, if "false"
* - onTruncation will be called
*/
case class ParseResult[+T](result: Option[T],
nextStep: ParseStep[T],
acceptUpstreamFinish: Boolean = true)
trait ParseStep[+T] { trait ParseStep[+T] {
def parse(reader: ByteReader): (T, ParseStep[T]) /**
* Must return true when NeedMoreData will clean buffer. If returns false - next pulled
* data will be appended to existing data in buffer
*/
def canWorkWithPartialData: Boolean = false
def parse(reader: ByteReader): ParseResult[T]
def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser") def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser")
} }
object FinishedParser extends ParseStep[Nothing] { object FinishedParser extends ParseStep[Nothing] {
def parse(reader: ByteReader) = override def parse(reader: ByteReader) =
throw new IllegalStateException("no initial parser installed: you must use startWith(...)") throw new IllegalStateException("no initial parser installed: you must use startWith(...)")
} }
@ -83,6 +104,7 @@ object ByteStringParser {
def remainingSize: Int = input.size - off def remainingSize: Int = input.size - off
def currentOffset: Int = off def currentOffset: Int = off
def remainingData: ByteString = input.drop(off) def remainingData: ByteString = input.drop(off)
def fromStartToHere: ByteString = input.take(off) def fromStartToHere: ByteString = input.take(off)

View file

@ -257,7 +257,6 @@ object Zip {
* '''Cancels when''' any downstream cancels * '''Cancels when''' any downstream cancels
*/ */
object Unzip { object Unzip {
import akka.japi.function.Function
/** /**
* Creates a new `Unzip` stage with the specified output types. * Creates a new `Unzip` stage with the specified output types.

View file

@ -300,6 +300,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
if (_interpreter != null) _interpreter.setHandler(conn(in), handler) if (_interpreter != null) _interpreter.setHandler(conn(in), handler)
} }
/**
* Assign callbacks for linear stage for both [[Inlet]] and [[Outlet]]
*/
final protected def setHandlers(in: Inlet[_], out: Outlet[_], handler: InHandler with OutHandler): Unit = {
setHandler(in, handler)
setHandler(out, handler)
}
/** /**
* Retrieves the current callback for the events on the given [[Inlet]] * Retrieves the current callback for the events on the given [[Inlet]]
*/ */