diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala
index ceaf4a7afe..eec365cbf8 100644
--- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala
+++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala
@@ -46,7 +46,7 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
   }
 
   object ReadFrameHeader extends Step {
-    override def parse(reader: ByteReader): (FrameEvent, Step) = {
+    override def parse(reader: ByteReader): ParseResult[FrameEvent] = {
       import Protocol._
 
       val flagsAndOp = reader.readByte()
@@ -83,23 +83,25 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
       val takeNow = (header.length min reader.remainingSize).toInt
       val thisFrameData = reader.take(takeNow)
+      val noMoreData = thisFrameData.length == length
       val nextState =
-        if (thisFrameData.length == length) ReadFrameHeader
+        if (noMoreData) ReadFrameHeader
         else new ReadData(length - thisFrameData.length)
 
-      (FrameStart(header, thisFrameData.compact), nextState)
+      ParseResult(Some(FrameStart(header, thisFrameData.compact)), nextState, true)
     }
   }
 
   class ReadData(_remaining: Long) extends Step {
+    override def canWorkWithPartialData = true
     var remaining = _remaining
-    override def parse(reader: ByteReader): (FrameEvent, Step) =
+    override def parse(reader: ByteReader): ParseResult[FrameEvent] =
       if (reader.remainingSize < remaining) {
         remaining -= reader.remainingSize
-        (FrameData(reader.takeAll(), lastPart = false), this)
+        ParseResult(Some(FrameData(reader.takeAll(), lastPart = false)), this, true)
       } else {
-        (FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader)
+        ParseResult(Some(FrameData(reader.take(remaining.toInt), lastPart = true)), ReadFrameHeader, true)
       }
   }
 }
diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala b/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala
deleted file mode 100644
index 3a03d3dea5..0000000000
--- a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2009-2014 Typesafe Inc.
- */
-
-package akka.http.impl.util
-
-import scala.util.control.NoStackTrace
-import akka.util.ByteString
-
-/**
- * A helper class to read from a ByteString statefully.
- *
- * INTERNAL API
- */
-private[akka] class ByteReader(input: ByteString) {
-  import ByteReader.NeedMoreData
-
-  private[this] var off = 0
-
-  def hasRemaining: Boolean = off < input.size
-
-  def currentOffset: Int = off
-  def remainingData: ByteString = input.drop(off)
-  def fromStartToHere: ByteString = input.take(off)
-
-  def readByte(): Int =
-    if (off < input.length) {
-      val x = input(off)
-      off += 1
-      x & 0xFF
-    } else throw NeedMoreData
-  def readShortLE(): Int = readByte() | (readByte() << 8)
-  def readIntLE(): Int = readShortLE() | (readShortLE() << 16)
-  def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32)
-
-  def readShortBE(): Int = (readByte() << 8) | readByte()
-  def readIntBE(): Int = (readShortBE() << 16) | readShortBE()
-  def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL)
-
-  def skip(numBytes: Int): Unit =
-    if (off + numBytes <= input.length) off += numBytes
-    else throw NeedMoreData
-  def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
-}
-
-/*
-* INTERNAL API
-*/
-private[akka] object ByteReader {
-  val NeedMoreData = new Exception with NoStackTrace
-}
\ No newline at end of file
diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/ByteStringParserStage.scala b/akka-http-core/src/main/scala/akka/http/impl/util/ByteStringParserStage.scala
deleted file mode 100644
index 5ae861883a..0000000000
--- a/akka-http-core/src/main/scala/akka/http/impl/util/ByteStringParserStage.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2009-2014 Typesafe Inc.
- */
-
-package akka.http.impl.util
-
-import akka.stream.stage.{ Context, StatefulStage }
-import akka.util.ByteString
-import akka.stream.stage.SyncDirective
-
-/**
- * A helper class for writing parsers from ByteStrings.
- *
- * FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
- *
- * INTERNAL API
- */
-private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
-  protected def onTruncation(ctx: Context[Out]): SyncDirective
-
-  /**
-   * Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of
-   * `ctx.pull()` to have truncation errors reported.
-   */
-  abstract class IntermediateState extends State {
-    override def onPull(ctx: Context[Out]): SyncDirective = pull(ctx)
-    def pull(ctx: Context[Out]): SyncDirective =
-      if (ctx.isFinishing) onTruncation(ctx)
-      else ctx.pull()
-  }
-
-  /**
-   * A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
-   * occurs the previous data is saved and the reading process is restarted from the beginning
-   * once more data was received.
-   *
-   * As [[read]] may be called several times for the same prefix of data, make sure not to
-   * manipulate any state during reading from the ByteReader.
-   */
-  private[akka] trait ByteReadingState extends IntermediateState {
-    def read(reader: ByteReader, ctx: Context[Out]): SyncDirective
-
-    def onPush(data: ByteString, ctx: Context[Out]): SyncDirective =
-      try {
-        val reader = new ByteReader(data)
-        read(reader, ctx)
-      } catch {
-        case ByteReader.NeedMoreData ⇒
-          become(TryAgain(data, this))
-          pull(ctx)
-      }
-  }
-
-  private case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
-    def onPush(data: ByteString, ctx: Context[Out]): SyncDirective = {
-      become(byteReadingState)
-      byteReadingState.onPush(previousData ++ data, ctx)
-    }
-  }
-}
diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala
index 997e9f5e9c..ea4847b9fc 100644
--- a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala
+++ b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala
@@ -10,6 +10,7 @@ import akka.NotUsed
 import akka.http.scaladsl.model.RequestEntity
 import akka.stream._
 import akka.stream.impl.StreamLayout.Module
+import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
 import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
 import akka.stream.scaladsl._
 import akka.stream.stage._
@@ -114,41 +115,47 @@ private[http] object StreamUtils {
     Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes")
   }
 
-  def limitByteChunksStage(maxBytesPerChunk: Int): PushPullStage[ByteString, ByteString] =
-    new StatefulStage[ByteString, ByteString] {
-      def initial = WaitingForData
+  def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
+    new SimpleLinearGraphStage[ByteString] {
+      override def initialAttributes = Attributes.name("limitByteChunksStage")
 
-      case object WaitingForData extends State {
-        def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-          if (elem.size <= maxBytesPerChunk) ctx.push(elem)
-          else {
-            become(DeliveringData(elem.drop(maxBytesPerChunk)))
-            ctx.push(elem.take(maxBytesPerChunk))
-          }
-      }
+      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
+        var remaining = ByteString.empty
 
-      case class DeliveringData(remaining: ByteString) extends State {
-        def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-          throw new IllegalStateException("Not expecting data")
-
-        override def onPull(ctx: Context[ByteString]): SyncDirective = {
-          val toPush = remaining.take(maxBytesPerChunk)
-          val toKeep = remaining.drop(maxBytesPerChunk)
-
-          become {
-            if (toKeep.isEmpty) WaitingForData
-            else DeliveringData(toKeep)
-          }
-          if (ctx.isFinishing) ctx.pushAndFinish(toPush)
-          else ctx.push(toPush)
-        }
-      }
-
-      override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
-        current match {
-          case WaitingForData    ⇒ ctx.finish()
-          case _: DeliveringData ⇒ ctx.absorbTermination()
+        def splitAndPush(elem: ByteString): Unit = {
+          val toPush = elem.take(maxBytesPerChunk)
+          val toKeep = elem.drop(maxBytesPerChunk)
+          push(out, toPush)
+          remaining = toKeep
+        }
+        setHandlers(in, out, WaitingForData)
+
+        case object WaitingForData extends InHandler with OutHandler {
+          override def onPush(): Unit = {
+            val elem = grab(in)
+            if (elem.size <= maxBytesPerChunk) push(out, elem)
+            else {
+              splitAndPush(elem)
+              setHandlers(in, out, DeliveringData)
+            }
+          }
+          override def onPull(): Unit = pull(in)
+        }
+
+        case object DeliveringData extends InHandler with OutHandler {
+          var finishing = false
+          override def onPush(): Unit = throw new IllegalStateException("Not expecting data")
+          override def onPull(): Unit = {
+            splitAndPush(remaining)
+            if (remaining.isEmpty) {
+              if (finishing) completeStage() else setHandlers(in, out, WaitingForData)
+            }
+          }
+          override def onUpstreamFinish(): Unit = if (remaining.isEmpty) completeStage() else finishing = true
         }
+
+        override def toString = "limitByteChunksStage"
+      }
     }
 
   def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity =
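The rewritten stage is a GraphStage[FlowShape[ByteString, ByteString]], so callers wrap it with Flow.fromGraph, as NoCoding.newDecompressorStage does further down. A minimal sketch of the re-chunking behavior; StreamUtils is internal API (private[http]), so this is test-style code assumed to run inside the akka.http package, with an implicit Materializer in scope, and the 5-byte limit is purely illustrative:

  import akka.stream.scaladsl.{ Flow, Sink, Source }
  import akka.util.ByteString

  // Wrap the stage into a Flow; every emitted element is at most 5 bytes long.
  val limit5 = Flow.fromGraph(StreamUtils.limitByteChunksStage(5))

  // A 13-byte input should come out as "01234", "56789", "abc".
  Source.single(ByteString("0123456789abc"))
    .via(limit5)
    .runWith(Sink.foreach(bs ⇒ println(bs.utf8String)))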
data") + override def onPull(): Unit = { + splitAndPush(remaining) + if (remaining.isEmpty) { + if (finishing) completeStage() else setHandlers(in, out, WaitingForData) + } + } + override def onUpstreamFinish(): Unit = if (remaining.isEmpty) completeStage() else finishing = true } + + override def toString = "limitByteChunksStage" + } } def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity = diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala index 8b3dd3d102..0181f79bea 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala @@ -4,10 +4,13 @@ package akka.http.scaladsl.coding +import akka.stream.{ Attributes, FlowShape } +import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage + import scala.concurrent.duration._ import org.scalatest.WordSpec import akka.util.ByteString -import akka.stream.stage.{ SyncDirective, Context, PushStage, Stage } +import akka.stream.stage._ import akka.http.scaladsl.model._ import akka.http.impl.util._ import headers._ @@ -34,10 +37,17 @@ class DecoderSpec extends WordSpec with CodecSpecSupport { case object DummyDecoder extends StreamDecoder { val encoding = HttpEncodings.compress - def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = - () ⇒ new PushStage[ByteString, ByteString] { - def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = - ctx.push(elem ++ ByteString("compressed")) + override def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] = + () ⇒ new SimpleLinearGraphStage[ByteString] { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in) ++ ByteString("compressed")) + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + } } } + } diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala index 914f834f31..77275f2634 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala @@ -33,7 +33,6 @@ class GzipSpec extends CoderSpec { } "throw an error if compressed data is just missing the trailer at the end" in { def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8")) - val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl")) ex.getCause.getMessage should equal("Truncated GZIP stream") } diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala index 423796f13d..90b3557e3c 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala @@ -6,8 +6,8 @@ package akka.http.scaladsl.coding import akka.NotUsed import akka.http.scaladsl.model._ -import akka.stream.Materializer -import akka.stream.stage.Stage +import akka.stream.{ FlowShape, Materializer } +import akka.stream.stage.{ GraphStage, Stage } import akka.util.ByteString import headers.HttpEncoding import akka.stream.scaladsl.{ Sink, Source, Flow } @@ -37,7 
diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala
index e576499b95..cb48863b6a 100644
--- a/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala
+++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala
@@ -5,11 +5,12 @@ package akka.http.scaladsl.coding
 
 import java.util.zip.{ Inflater, Deflater }
 
-import akka.stream.stage._
+import akka.stream.Attributes
+import akka.stream.io.ByteStringParser
+import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
 import akka.util.{ ByteStringBuilder, ByteString }
 
 import scala.annotation.tailrec
-import akka.http.impl.util._
 import akka.http.scaladsl.model._
 import akka.http.scaladsl.model.headers.HttpEncodings
 
@@ -86,56 +87,49 @@ private[http] object DeflateCompressor {
 }
 
 class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
-  protected def createInflater() = new Inflater()
 
-  def initial: State = StartInflate
-  def afterInflate: State = StartInflate
+  override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
+    override val inflater: Inflater = new Inflater()
 
-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
-  protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.finish()
+    override val inflateState = new Inflate(true) {
+      override def onTruncation(): Unit = completeStage()
+    }
+
+    override def afterInflate = inflateState
+    override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
+
+    startWith(inflateState)
+  }
 }
 
-abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] {
-  protected def createInflater(): Inflater
-  val inflater = createInflater()
+abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
+  extends ByteStringParser[ByteString] {
 
-  protected def afterInflate: State
-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
+  abstract class DecompressorParsingLogic extends ParsingLogic {
+    val inflater: Inflater
+    def afterInflate: ParseStep[ByteString]
+    def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
+    val inflateState: Inflate
 
-  /** Start inflating */
-  case object StartInflate extends IntermediateState {
-    def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = {
-      require(inflater.needsInput())
-      inflater.setInput(data.toArray)
+    abstract class Inflate(noPostProcessing: Boolean) extends ParseStep[ByteString] {
+      override def canWorkWithPartialData = true
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
+        inflater.setInput(reader.remainingData.toArray)
 
-      becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
-    }
-  }
+        val buffer = new Array[Byte](maxBytesPerChunk)
+        val read = inflater.inflate(buffer)
 
-  /** Inflate */
-  case class Inflate()(data: ByteString) extends IntermediateState {
-    override def onPull(ctx: Context[ByteString]): SyncDirective = {
-      val buffer = new Array[Byte](maxBytesPerChunk)
-      val read = inflater.inflate(buffer)
-      if (read > 0) {
-        afterBytesRead(buffer, 0, read)
-        ctx.push(ByteString.fromArray(buffer, 0, read))
-      } else {
-        val remaining = data.takeRight(inflater.getRemaining)
-        val next =
-          if (inflater.finished()) afterInflate
-          else StartInflate
+        reader.skip(reader.remainingSize - inflater.getRemaining)
 
-        becomeWithRemaining(next, remaining, ctx)
+        if (read > 0) {
+          afterBytesRead(buffer, 0, read)
+          val next = if (inflater.finished()) afterInflate else this
+          ParseResult(Some(ByteString.fromArray(buffer, 0, read)), next, noPostProcessing)
+        } else {
+          if (inflater.finished()) ParseResult(None, afterInflate, noPostProcessing)
+          else throw ByteStringParser.NeedMoreData
+        }
       }
     }
-    def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-      throw new IllegalStateException("Don't expect a new Element")
-  }
-
-  def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
-    become(next)
-    if (remaining.isEmpty) current.onPull(ctx)
-    else current.onPush(remaining, ctx)
   }
 }
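DeflateDecompressorBase now extends ByteStringParser[ByteString], which is itself a GraphStage[FlowShape[ByteString, ByteString]], so the decompressor can be instantiated and used directly as a flow element. A sketch under the assumption that the input is a zlib-framed payload such as the JDK's Deflater produces with its default settings:

  import java.util.zip.Deflater
  import akka.stream.scaladsl.{ Flow, Source }
  import akka.util.ByteString

  // Produce a small zlib-framed payload with the plain JDK API.
  def deflate(bytes: Array[Byte]): ByteString = {
    val deflater = new Deflater()
    deflater.setInput(bytes)
    deflater.finish()
    val buf = new Array[Byte](1024)
    var out = ByteString.empty
    while (!deflater.finished()) {
      val n = deflater.deflate(buf)
      out ++= ByteString.fromArray(buf, 0, n)
    }
    out
  }

  // The stage is used like any other GraphStage-backed flow.
  val inflate = Flow.fromGraph(new DeflateDecompressor(maxBytesPerChunk = 1024))
  val decoded = Source.single(deflate("hello".getBytes("UTF-8"))).via(inflate)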
diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala
index 5f03dfd7a8..ca0c5a82d3 100644
--- a/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala
+++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala
@@ -4,14 +4,15 @@
 
 package akka.http.scaladsl.coding
 
-import akka.util.ByteString
-import akka.stream.stage._
-
-import akka.http.impl.util.ByteReader
-import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
+import java.util.zip.{ CRC32, Deflater, Inflater, ZipException }
 
+import akka.http.impl.engine.ws.{ ProtocolException, FrameEvent }
 import akka.http.scaladsl.model._
-import headers.HttpEncodings
+import akka.http.scaladsl.model.headers.HttpEncodings
+import akka.stream.Attributes
+import akka.stream.io.ByteStringParser
+import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
+import akka.util.ByteString
 
 class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
   val encoding = HttpEncodings.gzip
@@ -60,71 +61,55 @@ class GzipCompressor extends DeflateCompressor {
 }
 
 class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
-  protected def createInflater(): Inflater = new Inflater(true)
+  override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
+    override val inflater: Inflater = new Inflater(true)
+    override def afterInflate: ParseStep[ByteString] = ReadTrailer
+    override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
+      crc32.update(buffer, offset, length)
 
-  def initial: State = Initial
+    trait Step extends ParseStep[ByteString] {
+      override def onTruncation(): Unit = failStage(new ZipException("Truncated GZIP stream"))
+    }
+    override val inflateState = new Inflate(false) with Step
+    startWith(ReadHeaders)
 
-  /** No bytes were received yet */
-  case object Initial extends State {
-    def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective =
-      if (data.isEmpty) ctx.pull()
-      else becomeWithRemaining(ReadHeaders, data, ctx)
+    /** Reading the header bytes */
+    case object ReadHeaders extends Step {
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
+        import reader._
+        if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
+        if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
+        val flags = readByte()
+        skip(6) // skip MTIME, XFL and OS fields
+        if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields
+        if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
+        if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
+        if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header")
 
-    override def onPull(ctx: Context[ByteString]): SyncDirective =
-      if (ctx.isFinishing) {
-        ctx.finish()
-      } else super.onPull(ctx)
-  }
+        inflater.reset()
+        crc32.reset()
+        ParseResult(None, inflateState, false)
+      }
+    }
+    var crc32: CRC32 = new CRC32
+    private def fail(msg: String) = throw new ZipException(msg)
 
-  var crc32: CRC32 = new CRC32
-  protected def afterInflate: State = ReadTrailer
-
-  /** Reading the header bytes */
-  case object ReadHeaders extends ByteReadingState {
-    def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
-      import reader._
-
-      if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
-      if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
-      val flags = readByte()
-      skip(6) // skip MTIME, XFL and OS fields
-      if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields
-      if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
-      if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
-      if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header")
-
-      inflater.reset()
-      crc32.reset()
-      becomeWithRemaining(StartInflate, remainingData, ctx)
+    /** Reading the trailer */
+    case object ReadTrailer extends Step {
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
+        import reader._
+        if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
+        if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ )
+          fail("Corrupt GZIP trailer ISIZE")
+        ParseResult(None, ReadHeaders, true)
+      }
     }
   }
 
-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
-    crc32.update(buffer, offset, length)
-
-  /** Reading the trailer */
-  case object ReadTrailer extends ByteReadingState {
-    def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
-      import reader._
-
-      if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
-      if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
-
-      becomeWithRemaining(Initial, remainingData, ctx)
-    }
-  }
-
-  override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
-
   private def crc16(data: ByteString) = {
     val crc = new CRC32
     crc.update(data.toArray)
     crc.getValue.toInt & 0xFFFF
   }
-
-  override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream"))
-
-  private def fail(msg: String) = throw new ZipException(msg)
 }
 
 /** INTERNAL API */
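One consequence of ReadTrailer handing control back to ReadHeaders with acceptUpstreamFinish = true: after a trailer the stream may either end cleanly or continue with another gzip member, so concatenated gzip streams should decode into the concatenation of their plaintexts. A hedged sketch of what this state machine implies for callers; gzip(...) stands for any helper producing one complete gzip member (e.g. via Gzip.newCompressor), and an implicit Materializer is assumed:

  import akka.stream.Materializer
  import akka.stream.scaladsl.Source
  import akka.util.ByteString

  // Two complete gzip members back to back; ReadTrailer loops back to
  // ReadHeaders after the first member, so both decode in one pass.
  def decodeConcatenated(gzip: String ⇒ ByteString)(implicit mat: Materializer) =
    Source.single(gzip("hello, ") ++ gzip("world"))
      .via(Gzip.decoderFlow)
      .runFold(ByteString.empty)(_ ++ _) // expected: ByteString("hello, world")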
diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala
index ba59b93fd4..2256944139 100644
--- a/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala
+++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala
@@ -6,7 +6,8 @@ package akka.http.scaladsl.coding
 
 import akka.http.scaladsl.model._
 import akka.http.impl.util.StreamUtils
-import akka.stream.stage.Stage
+import akka.stream.FlowShape
+import akka.stream.stage.{ GraphStage, Stage }
 import akka.util.ByteString
 import headers.HttpEncodings
 
@@ -25,7 +26,7 @@ object NoCoding extends Coder with StreamDecoder {
 
   def newCompressor = NoCodingCompressor
 
-  def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
+  def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
     () ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala
index 501001c740..ab360ad898 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala
@@ -43,7 +43,7 @@ object GraphStages {
   /**
    * INERNAL API
    */
-  private[stream] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
+  private[akka] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
     val in = Inlet[T](Logging.simpleName(this) + ".in")
     val out = Outlet[T](Logging.simpleName(this) + ".out")
     override val shape = FlowShape(in, out)
diff --git a/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala
index 64817557d6..efe9209d6c 100644
--- a/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala
+++ b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala
@@ -19,30 +19,35 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
   final override val shape = FlowShape(bytesIn, objOut)
 
   class ParsingLogic extends GraphStageLogic(shape) {
+    var pullOnParserRequest = false
     override def preStart(): Unit = pull(bytesIn)
     setHandler(objOut, eagerTerminateOutput)
 
     private var buffer = ByteString.empty
    private var current: ParseStep[T] = FinishedParser
+    private var acceptUpstreamFinish: Boolean = true
 
     final protected def startWith(step: ParseStep[T]): Unit = current = step
 
     @tailrec private def doParse(): Unit =
       if (buffer.nonEmpty) {
+        val reader = new ByteReader(buffer)
         val cont = try {
-          val reader = new ByteReader(buffer)
-          val (elem, next) = current.parse(reader)
-          emit(objOut, elem)
-          if (next == FinishedParser) {
+          val parseResult = current.parse(reader)
+          acceptUpstreamFinish = parseResult.acceptUpstreamFinish
+          parseResult.result.map(emit(objOut, _))
+          if (parseResult.nextStep == FinishedParser) {
             completeStage()
             false
           } else {
             buffer = reader.remainingData
-            current = next
+            current = parseResult.nextStep
             true
           }
         } catch {
           case NeedMoreData ⇒
+            acceptUpstreamFinish = false
+            if (current.canWorkWithPartialData) buffer = reader.remainingData
             pull(bytesIn)
             false
         }
@@ -51,11 +56,12 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
 
     setHandler(bytesIn, new InHandler {
       override def onPush(): Unit = {
+        pullOnParserRequest = false
         buffer ++= grab(bytesIn)
         doParse()
       }
       override def onUpstreamFinish(): Unit =
-        if (buffer.isEmpty) completeStage()
+        if (buffer.isEmpty && acceptUpstreamFinish) completeStage()
         else current.onTruncation()
     })
   }
@@ -63,13 +69,28 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
 
 object ByteStringParser {
 
+  /**
+   * @param result               the element this step produced for downstream, or None if it produced none
+   * @param nextStep             the next parser step
+   * @param acceptUpstreamFinish if true, the stage completes normally when `onUpstreamFinish` is received;
+   *                             if false, `onTruncation` is called instead
+   */
+  case class ParseResult[+T](result: Option[T],
+                             nextStep: ParseStep[T],
+                             acceptUpstreamFinish: Boolean = true)
+
   trait ParseStep[+T] {
-    def parse(reader: ByteReader): (T, ParseStep[T])
+    /**
+     * Returns true if this step can resume from partial data: on `NeedMoreData` the bytes it has
+     * already consumed are dropped from the buffer. If false, the buffer is kept as-is and the step
+     * is re-run from the beginning once the next chunk of data has been pulled.
+     */
+    def canWorkWithPartialData: Boolean = false
+    def parse(reader: ByteReader): ParseResult[T]
     def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser")
   }
 
   object FinishedParser extends ParseStep[Nothing] {
-    def parse(reader: ByteReader) =
+    override def parse(reader: ByteReader) =
       throw new IllegalStateException("no initial parser installed: you must use startWith(...)")
   }
 
@@ -83,6 +104,7 @@ object ByteStringParser {
 
   def remainingSize: Int = input.size - off
   def currentOffset: Int = off
+  def remainingData: ByteString = input.drop(off)
   def fromStartToHere: ByteString = input.take(off)
 
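The new ParseStep/ParseResult contract replaces the old (T, ParseStep[T]) tuple and is easiest to see on a toy parser. A minimal sketch for length-prefixed frames (a 16-bit big-endian length followed by that many payload bytes); the FrameChunker name and the framing itself are invented for illustration, and only the ByteStringParser API shown in this patch is relied on:

  import akka.stream.Attributes
  import akka.stream.io.ByteStringParser
  import akka.stream.io.ByteStringParser.{ ByteReader, ParseResult, ParseStep }
  import akka.util.ByteString

  // Emits one ByteString per frame: [length: Int16 BE][payload: `length` bytes].
  class FrameChunker extends ByteStringParser[ByteString] {
    override def createLogic(attr: Attributes) = new ParsingLogic {
      object ReadFrame extends ParseStep[ByteString] {
        override def parse(reader: ByteReader): ParseResult[ByteString] = {
          // readShortBE/take throw NeedMoreData on underrun; since
          // canWorkWithPartialData stays false, the whole frame is simply
          // re-parsed once more bytes have been buffered.
          val length = reader.readShortBE()
          val payload = reader.take(length)
          // Emit the payload, stay in this step for the next frame, and allow
          // the stream to finish cleanly on a frame boundary.
          ParseResult(Some(payload), ReadFrame, acceptUpstreamFinish = true)
        }
      }
      startWith(ReadFrame)
    }
  }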
diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala
index 8466a194a6..9b844235eb 100644
--- a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala
+++ b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala
@@ -257,7 +257,6 @@ object Zip {
  * '''Cancels when''' any downstream cancels
  */
 object Unzip {
-  import akka.japi.function.Function
 
   /**
    * Creates a new `Unzip` stage with the specified output types.
diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala
index ef051e81f1..1f19823b4b 100644
--- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala
+++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala
@@ -300,6 +300,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
     if (_interpreter != null) _interpreter.setHandler(conn(in), handler)
   }
 
+  /**
+   * Assigns the given handler as the callback for both the [[Inlet]] and the [[Outlet]] of a linear stage
+   */
+  final protected def setHandlers(in: Inlet[_], out: Outlet[_], handler: InHandler with OutHandler): Unit = {
+    setHandler(in, handler)
+    setHandler(out, handler)
+  }
+
   /**
    * Retrieves the current callback for the events on the given [[Inlet]]
    */
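The setHandlers helper removes the duplication still visible in the DecoderSpec dummy stage above, where one logic installs an InHandler and an OutHandler separately. A sketch of the pattern it enables; the Uppercase stage is invented for illustration:

  import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
  import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }

  // A linear stage whose single handler object serves both ports.
  class Uppercase extends GraphStage[FlowShape[String, String]] {
    val in = Inlet[String]("Uppercase.in")
    val out = Outlet[String]("Uppercase.out")
    override val shape = FlowShape(in, out)

    override def createLogic(attr: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
      setHandlers(in, out, new InHandler with OutHandler {
        override def onPush(): Unit = push(out, grab(in).toUpperCase)
        override def onPull(): Unit = pull(in)
      })
    }
  }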