+htp #16516 rewrite Deflate/GzipDecompressor as StatefulStage to defuse zip bomb

Also, the tests have been DRY'd up.
Johannes Rudolph 2014-12-10 16:55:50 +01:00
parent bfad068a70
commit 29d7a041f6
15 changed files with 510 additions and 396 deletions


@@ -0,0 +1,44 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.util
import akka.util.ByteString
import scala.util.control.NoStackTrace
/**
* A helper class to read from a ByteString statefully.
*
* INTERNAL API
*/
private[akka] class ByteReader(input: ByteString) {
import ByteReader.NeedMoreData
private[this] var off = 0
def currentOffset: Int = off
def remainingData: ByteString = input.drop(off)
def fromStartToHere: ByteString = input.take(currentOffset)
def readByte(): Int =
if (off < input.length) {
val x = input(off)
off += 1
x.toInt & 0xFF
} else throw NeedMoreData
def readShort(): Int = readByte() | (readByte() << 8)
def readInt(): Int = readShort() | (readShort() << 16)
def skip(numBytes: Int): Unit =
if (off + numBytes <= input.length) off += numBytes
else throw NeedMoreData
def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
}
/*
* INTERNAL API
*/
private[akka] object ByteReader {
val NeedMoreData = new Exception with NoStackTrace
}
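
A minimal usage sketch (hypothetical caller, not part of this commit): parse one 32-bit-length-prefixed frame, letting NeedMoreData signal that the caller should keep its buffer and retry once more input has arrived.

import akka.util.ByteString
import akka.http.util.ByteReader

// Hypothetical helper: Some((payload, rest)) once `buffer` holds a complete
// frame, None if the caller should accumulate more data and call again.
def tryParseFrame(buffer: ByteString): Option[(ByteString, ByteString)] =
  try {
    val reader = new ByteReader(buffer)
    val length = reader.readInt() // little-endian, as defined above
    if (reader.remainingData.length < length) throw ByteReader.NeedMoreData
    Some((reader.remainingData.take(length), reader.remainingData.drop(length)))
  } catch {
    case ByteReader.NeedMoreData ⇒ None // no partial state to undo
  }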


@@ -0,0 +1,58 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.util
import akka.stream.stage.{ Directive, Context, StatefulStage }
import akka.util.ByteString
/**
* A helper class for writing parsers from ByteStrings.
*
* FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
*
* INTERNAL API
*/
private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
protected def onTruncation(ctx: Context[Out]): Directive
/**
* Derive a state from [[IntermediateState]] and then call `pull(ctx)` instead of
* `ctx.pull()` to have truncation errors reported.
*/
abstract class IntermediateState extends State {
override def onPull(ctx: Context[Out]): Directive = pull(ctx)
def pull(ctx: Context[Out]): Directive =
if (ctx.isFinishing) onTruncation(ctx)
else ctx.pull()
}
/**
* A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
occurs, the previous data is saved and the reading process is restarted from the beginning
once more data has been received.
*
* As [[read]] may be called several times for the same prefix of data, make sure not to
* manipulate any state during reading from the ByteReader.
*/
trait ByteReadingState extends IntermediateState {
def read(reader: ByteReader, ctx: Context[Out]): Directive
def onPush(data: ByteString, ctx: Context[Out]): Directive =
try {
val reader = new ByteReader(data)
read(reader, ctx)
} catch {
case ByteReader.NeedMoreData ⇒
become(TryAgain(data, this))
pull(ctx)
}
}
case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
def onPush(data: ByteString, ctx: Context[Out]): Directive = {
become(byteReadingState)
byteReadingState.onPush(previousData ++ data, ctx)
}
}
}
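
For illustration, a hypothetical concrete parser built on this machinery (FrameParserStage and its two-byte length framing are invented for this sketch): read stays free of side effects until the frame is known to be complete, and unparsed bytes are re-buffered through TryAgain.

import akka.stream.stage.{ Context, Directive }
import akka.util.ByteString

class FrameParserStage extends ByteStringParserStage[ByteString] {
  protected def onTruncation(ctx: Context[ByteString]): Directive =
    ctx.fail(new RuntimeException("Truncated frame"))
  def initial = ReadFrame
  case object ReadFrame extends ByteReadingState {
    def read(reader: ByteReader, ctx: Context[ByteString]): Directive = {
      val length = reader.readShort() // may throw NeedMoreData; nothing mutated yet
      if (reader.remainingData.length < length) throw ByteReader.NeedMoreData
      val payload = reader.remainingData.take(length)
      // stash the unparsed rest; TryAgain prepends it to the next element
      become(TryAgain(reader.remainingData.drop(length), ReadFrame))
      ctx.push(payload)
    }
  }
}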


@@ -34,8 +34,8 @@ private[http] object StreamUtils {
* Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
*/
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Flow[ByteString, ByteString] = {
val transformer = new PushPullStage[ByteString, ByteString] {
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = {
new PushPullStage[ByteString, ByteString] {
override def onPush(element: ByteString, ctx: Context[ByteString]): Directive =
ctx.push(f(element))
@@ -45,7 +45,6 @@
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
}
Flow[ByteString].section(name("transformBytes"))(_.transform(() ⇒ transformer))
}
def failedPublisher[T](ex: Throwable): Publisher[T] =
@@ -94,6 +93,41 @@ private[http] object StreamUtils {
Flow[ByteString].section(name("sliceBytes"))(_.transform(() ⇒ transformer))
}
def limitByteChunksStage(maxBytesPerChunk: Int): Stage[ByteString, ByteString] =
new StatefulStage[ByteString, ByteString] {
def initial = WaitingForData
case object WaitingForData extends State {
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
if (elem.size <= maxBytesPerChunk) ctx.push(elem)
else {
become(DeliveringData(elem.drop(maxBytesPerChunk)))
ctx.push(elem.take(maxBytesPerChunk))
}
}
case class DeliveringData(remaining: ByteString) extends State {
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
throw new IllegalStateException("Not expecting data")
override def onPull(ctx: Context[ByteString]): Directive = {
val toPush = remaining.take(maxBytesPerChunk)
val toKeep = remaining.drop(maxBytesPerChunk)
become {
if (toKeep.isEmpty) WaitingForData
else DeliveringData(toKeep)
}
if (ctx.isFinishing) ctx.pushAndFinish(toPush)
else ctx.push(toPush)
}
}
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
current match {
case WaitingForData ⇒ ctx.finish()
case _: DeliveringData ⇒ ctx.absorbTermination()
}
}
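
A sketch of plugging this stage into a flow (the 16-byte limit is arbitrary):

import akka.stream.scaladsl.Flow
import akka.util.ByteString

// Re-chunk so that no emitted ByteString exceeds 16 bytes: a single 100-byte
// input element is delivered as seven chunks (six of 16 bytes plus one of 4).
val limited: Flow[ByteString, ByteString] =
  Flow[ByteString].transform(() ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk = 16))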
/**
* Applies a sequence of transformers on one source and returns a sequence of sources with the result. The input source
* will only be traversed once.


@@ -121,7 +121,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll {
}
def duplicateBytesTransformer(): Flow[ByteString, ByteString] =
StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer)
Flow[ByteString].transform(() ⇒ StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer))
def trailer: ByteString = ByteString("--dup")
def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*)

File diff suppressed because one or more lines are too long


@@ -4,6 +4,7 @@
package akka.http.coding
import akka.stream.stage.{ Directive, Context, PushStage, Stage }
import akka.util.ByteString
import org.scalatest.WordSpec
import akka.http.model._
@@ -26,15 +27,15 @@ class DecoderSpec extends WordSpec with CodecSpecSupport {
}
def dummyDecompress(s: String): String = dummyDecompress(ByteString(s, "UTF8")).decodeString("UTF8")
def dummyDecompress(bytes: ByteString): ByteString = DummyDecompressor.decompress(bytes)
def dummyDecompress(bytes: ByteString): ByteString = DummyDecoder.decode(bytes)
case object DummyDecoder extends Decoder {
case object DummyDecoder extends StreamDecoder {
val encoding = HttpEncodings.compress
def newDecompressor = DummyDecompressor
}
case object DummyDecompressor extends Decompressor {
def decompress(buffer: ByteString): ByteString = buffer ++ ByteString("compressed")
def finish(): ByteString = ByteString.empty
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
() ⇒ new PushStage[ByteString, ByteString] {
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
ctx.push(elem ++ ByteString("compressed"))
}
}
}


@@ -5,78 +5,25 @@
package akka.http.coding
import akka.util.ByteString
import akka.http.util._
import org.scalatest.WordSpec
import akka.http.model._
import HttpMethods.POST
import java.io.ByteArrayOutputStream
import java.util.zip.{ DeflaterOutputStream, InflaterOutputStream }
import java.io.{ InputStream, OutputStream }
import java.util.zip._
class DeflateSpec extends WordSpec with CodecSpecSupport {
class DeflateSpec extends CoderSpec {
protected def Coder: Coder with StreamDecoder = Deflate
"The Deflate codec" should {
"properly encode a small string" in {
streamInflate(ourDeflate(smallTextBytes)) should readAs(smallText)
}
"properly decode a small string" in {
ourInflate(streamDeflate(smallTextBytes)) should readAs(smallText)
}
"properly round-trip encode/decode a small string" in {
ourInflate(ourDeflate(smallTextBytes)) should readAs(smallText)
}
"properly encode a large string" in {
streamInflate(ourDeflate(largeTextBytes)) should readAs(largeText)
}
"properly decode a large string" in {
ourInflate(streamDeflate(largeTextBytes)) should readAs(largeText)
}
"properly round-trip encode/decode a large string" in {
ourInflate(ourDeflate(largeTextBytes)) should readAs(largeText)
}
"properly round-trip encode/decode an HttpRequest" in {
val request = HttpRequest(POST, entity = HttpEntity(largeText))
Deflate.decode(Deflate.encode(request)) should equal(request)
}
"provide a better compression ratio than the standard Deflater/Inflater streams" in {
ourDeflate(largeTextBytes).length should be < streamDeflate(largeTextBytes).length
}
"support chunked round-trip encoding/decoding" in {
val chunks = largeTextBytes.grouped(512).toVector
val comp = Deflate.newCompressor
val decomp = Deflate.newDecompressor
val chunks2 =
chunks.map { chunk ⇒
decomp.decompress(comp.compressAndFlush(chunk))
} :+
decomp.decompress(comp.finish())
chunks2.join should readAs(largeText)
}
"works for any split in prefix + suffix" in {
val compressed = streamDeflate(smallTextBytes)
def tryWithPrefixOfSize(prefixSize: Int): Unit = {
val decomp = Deflate.newDecompressor
val prefix = compressed.take(prefixSize)
val suffix = compressed.drop(prefixSize)
protected def newDecodedInputStream(underlying: InputStream): InputStream =
new InflaterInputStream(underlying)
decomp.decompress(prefix) ++ decomp.decompress(suffix) should readAs(smallText)
}
(0 to compressed.size).foreach(tryWithPrefixOfSize)
}
}
protected def newEncodedOutputStream(underlying: OutputStream): OutputStream =
new DeflaterOutputStream(underlying)
def ourDeflate(bytes: ByteString): ByteString = Deflate.encode(bytes)
def ourInflate(bytes: ByteString): ByteString = Deflate.decode(bytes)
protected def corruptInputMessage: Option[String] = Some("invalid code -- missing end-of-block")
def streamDeflate(bytes: ByteString) = {
val output = new ByteArrayOutputStream()
val dos = new DeflaterOutputStream(output); dos.write(bytes.toArray); dos.close()
ByteString(output.toByteArray)
override def extraTests(): Unit = {
"throw early if header is corrupt" in {
val ex = the[DataFormatException] thrownBy ourDecode(ByteString(0, 1, 2, 3, 4))
ex.getMessage should equal("incorrect header check")
}
def streamInflate(bytes: ByteString) = {
val output = new ByteArrayOutputStream()
val ios = new InflaterOutputStream(output); ios.write(bytes.toArray); ios.close()
ByteString(output.toByteArray)
}
}

File diff suppressed because one or more lines are too long


@@ -0,0 +1,16 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.coding
import java.io.{ OutputStream, InputStream }
class NoCodingSpec extends CoderSpec {
protected def Coder: Coder with StreamDecoder = NoCoding
protected def corruptInputMessage: Option[String] = None // all input data is valid
protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = underlying
protected def newDecodedInputStream(underlying: InputStream): InputStream = underlying
}


@@ -166,7 +166,7 @@ class CodingDirectivesSpec extends RoutingSpec {
response should haveContentEncoding(gzip)
chunks.size shouldEqual (11 + 1) // 11 regular + the last one
val bytes = chunks.foldLeft(ByteString.empty)(_ ++ _.data)
Gzip.newDecompressor.decompress(bytes) should readAs(text)
Gzip.decode(bytes) should readAs(text)
}
}
}


@@ -6,6 +6,7 @@ package akka.http.coding
import akka.http.model._
import akka.http.util.StreamUtils
import akka.stream.stage.Stage
import akka.util.ByteString
import headers.HttpEncoding
import akka.stream.scaladsl.Flow
@@ -18,31 +19,38 @@ trait Decoder {
decodeData(message).withHeaders(message.headers filterNot Encoder.isContentEncodingHeader)
else message.self
def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
mapper.transformDataBytes(t, newDecodeTransfomer)
def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = mapper.transformDataBytes(t, decoderFlow)
def decode(input: ByteString): ByteString = newDecompressor.decompressAndFinish(input)
def maxBytesPerChunk: Int
def withMaxBytesPerChunk(maxBytesPerChunk: Int): Decoder
def newDecompressor: Decompressor
def decoderFlow: Flow[ByteString, ByteString]
def decode(input: ByteString): ByteString
}
object Decoder {
val MaxBytesPerChunkDefault: Int = 65536
}
def newDecodeTransfomer(): Flow[ByteString, ByteString] = {
val decompressor = newDecompressor
/** A decoder that is implemented in terms of a [[Stage]] */
trait StreamDecoder extends Decoder { outer ⇒
protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString]
def decodeChunk(bytes: ByteString): ByteString = decompressor.decompress(bytes)
def finish(): ByteString = decompressor.finish()
def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
new StreamDecoder {
def encoding: HttpEncoding = outer.encoding
override def maxBytesPerChunk: Int = newMaxBytesPerChunk
StreamUtils.byteStringTransformer(decodeChunk, finish)
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
outer.newDecompressorStage(maxBytesPerChunk)
}
}
/** A stateful object representing ongoing decompression. */
abstract class Decompressor {
/** Decompress the buffer and return decompressed data. */
def decompress(input: ByteString): ByteString
/** Flushes potential remaining data from any internal buffers and may report on truncation errors */
def finish(): ByteString
/** Combines decompress and finish */
def decompressAndFinish(input: ByteString): ByteString = decompress(input) ++ finish()
def decoderFlow: Flow[ByteString, ByteString] =
Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk))
def decode(input: ByteString): ByteString = decodeWithLimits(input)
def decodeWithLimits(input: ByteString, maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString =
StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get
def decodeFromIterator(input: Iterator[ByteString], maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString =
StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get
}
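
A hedged sketch of the resulting bounded decode path (compressedBytes is a placeholder):

import akka.util.ByteString

// Bound the decompressor: each inflate round emits at most 4096 bytes, so a
// tiny but highly compressed body can no longer expand without limit.
val bounded: Decoder = Gzip.withMaxBytesPerChunk(4096)
val compressedBytes: ByteString = ??? // placeholder input
val decoded: ByteString = bounded.decode(compressedBytes)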


@@ -4,27 +4,21 @@
package akka.http.coding
import java.io.OutputStream
import java.util.zip.{ DataFormatException, ZipException, Inflater, Deflater }
import java.util.zip.{ Inflater, Deflater }
import akka.stream.stage._
import akka.util.{ ByteStringBuilder, ByteString }
import scala.annotation.tailrec
import akka.http.util._
import akka.http.model._
import headers.HttpEncodings
import akka.http.model.headers.HttpEncodings
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder {
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
val encoding = HttpEncodings.deflate
def newCompressor = new DeflateCompressor
def newDecompressor = new DeflateDecompressor
}
/**
* An encoder and decoder for the HTTP 'deflate' encoding.
*/
object Deflate extends Deflate(Encoder.DefaultFilter) {
def apply(messageFilter: HttpMessage ⇒ Boolean) = new Deflate(messageFilter)
def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new DeflateDecompressor(maxBytesPerChunk)
}
object Deflate extends Deflate(Encoder.DefaultFilter)
class DeflateCompressor extends Compressor {
protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false)
@@ -88,27 +82,57 @@ class DeflateCompressor {
new Array[Byte](size)
}
class DeflateDecompressor extends Decompressor {
protected lazy val inflater = new Inflater()
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
protected def createInflater() = new Inflater()
def decompress(input: ByteString): ByteString =
try {
inflater.setInput(input.toArray)
drain(new Array[Byte](input.length * 2))
} catch {
case e: DataFormatException ⇒
throw new ZipException(e.getMessage.toOption getOrElse "Invalid ZLIB data format")
def initial: State = StartInflate
def afterInflate: State = StartInflate
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.finish()
}
abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] {
protected def createInflater(): Inflater
val inflater = createInflater()
protected def afterInflate: State
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
/** Start inflating */
case object StartInflate extends State {
def onPush(data: ByteString, ctx: Context[ByteString]): Directive = {
require(inflater.needsInput())
inflater.setInput(data.toArray)
becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
}
}
@tailrec protected final def drain(buffer: Array[Byte], result: ByteString = ByteString.empty): ByteString = {
val len = inflater.inflate(buffer)
if (len > 0) drain(buffer, result ++ ByteString.fromArray(buffer, 0, len))
else if (inflater.needsDictionary) throw new ZipException("ZLIB dictionary missing")
else result
/** Inflate */
case class Inflate()(data: ByteString) extends IntermediateState {
override def onPull(ctx: Context[ByteString]): Directive = {
val buffer = new Array[Byte](maxBytesPerChunk)
val read = inflater.inflate(buffer)
if (read > 0) {
afterBytesRead(buffer, 0, read)
ctx.push(ByteString.fromArray(buffer, 0, read))
} else {
val remaining = data.takeRight(inflater.getRemaining)
val next =
if (inflater.finished()) afterInflate
else StartInflate
becomeWithRemaining(next, remaining, ctx)
}
}
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
throw new IllegalStateException("Don't expect a new Element")
}
def finish(): ByteString = {
inflater.end()
ByteString.empty
def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
become(next)
if (remaining.isEmpty) current.onPull(ctx)
else current.onPush(remaining, ctx)
}
}
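
The behavioral change in a sketch: the old Decompressor inflated eagerly into a buffer sized from the input, whereas this stage inflates at most maxBytesPerChunk per downstream pull, making decompression demand-driven.

import akka.stream.scaladsl.Flow
import akka.util.ByteString

// One inflater round per pull, never more than 64 KiB at a time, regardless
// of how well the (possibly malicious) input compresses.
val inflate: Flow[ByteString, ByteString] =
  Flow[ByteString].transform(() ⇒ new DeflateDecompressor(maxBytesPerChunk = 65536))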


@@ -6,6 +6,7 @@ package akka.http.coding
import akka.http.model._
import akka.http.util.StreamUtils
import akka.stream.stage.Stage
import akka.util.ByteString
import headers._
import akka.stream.scaladsl.Flow
@@ -21,13 +22,13 @@ trait Encoder {
else message.self
def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
mapper.transformDataBytes(t, newEncodeTransformer)
mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer))
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
def newCompressor: Compressor
def newEncodeTransformer(): Flow[ByteString, ByteString] = {
def newEncodeTransformer(): Stage[ByteString, ByteString] = {
val compressor = newCompressor
def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes)


@@ -4,19 +4,19 @@
package akka.http.coding
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
import akka.util.ByteString
import akka.stream.stage._
import akka.http.util.ByteReader
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
import scala.annotation.tailrec
import akka.http.model._
import headers.HttpEncodings
import scala.util.control.{ NonFatal, NoStackTrace }
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder {
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
val encoding = HttpEncodings.gzip
def newCompressor = new GzipCompressor
def newDecompressor = new GzipDecompressor
def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new GzipDecompressor(maxBytesPerChunk)
}
/**
@@ -59,24 +59,29 @@ class GzipCompressor extends DeflateCompressor {
}
}
/** A suspendable gzip decompressor */
class GzipDecompressor extends DeflateDecompressor {
override protected lazy val inflater = new Inflater(true) // disable ZLIB headers
override def decompress(input: ByteString): ByteString = DecompressionStateMachine.run(input)
override def finish(): ByteString =
if (DecompressionStateMachine.isFinished) ByteString.empty
else fail("Truncated GZIP stream")
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
protected def createInflater(): Inflater = new Inflater(true)
import GzipDecompressor._
def initial: State = Initial
object DecompressionStateMachine extends StateMachine {
def isFinished: Boolean = currentState == finished
/** No bytes were received yet */
case object Initial extends State {
def onPush(data: ByteString, ctx: Context[ByteString]): Directive =
if (data.isEmpty) ctx.pull()
else becomeWithRemaining(ReadHeaders, data, ctx)
def initialState = finished
override def onPull(ctx: Context[ByteString]): Directive =
if (ctx.isFinishing) {
ctx.finish()
} else super.onPull(ctx)
}
private def readHeaders(data: ByteString): Action =
try {
val reader = new ByteReader(data)
var crc32: CRC32 = new CRC32
protected def afterInflate: State = ReadTrailer
/** Reading the header bytes */
case object ReadHeaders extends ByteReadingState {
def read(reader: ByteReader, ctx: Context[ByteString]): Directive = {
import reader._
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
@@ -84,59 +89,50 @@ class GzipDecompressor extends DeflateDecompressor {
val flags = readByte()
skip(6) // skip MTIME, XFL and OS fields
if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields
if ((flags & 8) > 0) while (readByte() != 0) {} // skip optional file name
if ((flags & 16) > 0) while (readByte() != 0) {} // skip optional file comment
if ((flags & 2) > 0 && crc16(data.take(currentOffset)) != readShort()) fail("Corrupt GZIP header")
ContinueWith(deflate(new CRC32), remainingData)
} catch {
case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData
}
private def deflate(checkSum: CRC32)(data: ByteString): Action = {
assert(inflater.needsInput())
inflater.setInput(data.toArray)
val output = drain(new Array[Byte](data.length * 2))
checkSum.update(output.toArray)
if (inflater.finished()) EmitAndContinueWith(output, readTrailer(checkSum), data.takeRight(inflater.getRemaining))
else EmitAndSuspend(output)
}
private def readTrailer(checkSum: CRC32)(data: ByteString): Action =
try {
val reader = new ByteReader(data)
import reader._
if (readInt() != checkSum.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShort()) fail("Corrupt GZIP header")
inflater.reset()
checkSum.reset()
ContinueWith(finished, remainingData) // start over to support multiple concatenated gzip streams
} catch {
case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData
crc32.reset()
becomeWithRemaining(StartInflate, remainingData, ctx)
}
}
lazy val finished: ByteString ⇒ Action =
data ⇒ if (data.nonEmpty) ContinueWith(readHeaders, data) else SuspendAndRetryWithMoreData
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
crc32.update(buffer, offset, length)
/** Reading the trailer */
case object ReadTrailer extends ByteReadingState {
def read(reader: ByteReader, ctx: Context[ByteString]): Directive = {
import reader._
if (readInt() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
becomeWithRemaining(Initial, remainingData, ctx)
}
}
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
private def crc16(data: ByteString) = {
val crc = new CRC32
crc.update(data.toArray)
crc.getValue.toInt & 0xFFFF
}
}
override protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.fail(new ZipException("Truncated GZIP stream"))
private def fail(msg: String) = throw new ZipException(msg)
}
/** INTERNAL API */
private[http] object GzipDecompressor {
// RFC 1952: http://tools.ietf.org/html/rfc1952 section 2.2
val Header = ByteString(
31, // ID1
-117, // ID2
0x1F, // ID1
0x8B, // ID2
8, // CM = Deflate
0, // FLG
0, // MTIME 1
@@ -146,76 +142,4 @@ private[http] object GzipDecompressor {
0, // XFL
0 // OS
)
class ByteReader(input: ByteString) {
import ByteReader.NeedMoreData
private[this] var off = 0
def readByte(): Int =
if (off < input.length) {
val x = input(off)
off += 1
x.toInt & 0xFF
} else throw NeedMoreData
def readShort(): Int = readByte() | (readByte() << 8)
def readInt(): Int = readShort() | (readShort() << 16)
def skip(numBytes: Int): Unit =
if (off + numBytes <= input.length) off += numBytes
else throw NeedMoreData
def currentOffset: Int = off
def remainingData: ByteString = input.drop(off)
}
object ByteReader {
val NeedMoreData = new Exception with NoStackTrace
}
/** A simple state machine implementation for suspendable parsing */
trait StateMachine {
sealed trait Action
/** Cache the current input and suspend to wait for more data */
case object SuspendAndRetryWithMoreData extends Action
/** Emit some output and suspend in the current state and wait for more data */
case class EmitAndSuspend(output: ByteString) extends Action
/** Proceed to the nextState and immediately run it with the remainingInput */
case class ContinueWith(nextState: State, remainingInput: ByteString) extends Action
/** Emit some output and then proceed to the nextState and immediately run it with the remainingInput */
case class EmitAndContinueWith(output: ByteString, nextState: State, remainingInput: ByteString) extends Action
type State = ByteString ⇒ Action
def initialState: State
private[this] var state: State = initialState
def currentState: State = state
/** Run the state machine with the current input */
final def run(input: ByteString): ByteString = {
@tailrec def rec(input: ByteString, result: ByteString = ByteString.empty): ByteString =
state(input) match {
case SuspendAndRetryWithMoreData ⇒
val oldState = state
state = { newData ⇒
state = oldState
oldState(input ++ newData)
}
result
case EmitAndSuspend(output) ⇒ result ++ output
case ContinueWith(next, remainingInput) ⇒
state = next
if (remainingInput.nonEmpty) rec(remainingInput, result)
else result
case EmitAndContinueWith(output, next, remainingInput) ⇒
state = next
rec(remainingInput, result ++ output)
}
try rec(input)
catch {
case NonFatal(e) ⇒
state = failState
throw e
}
}
private def failState: State = _ ⇒ throw new IllegalStateException("Trying to reuse failed decompressor.")
}
}
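
Because ReadTrailer hands control back to Initial, multi-member gzip streams (RFC 1952) keep working; a hedged round-trip sketch:

import akka.util.ByteString

// Two concatenated gzip members decode to the concatenation of their payloads.
val member1 = Gzip.encode(ByteString("hello, "))
val member2 = Gzip.encode(ByteString("world"))
assert(Gzip.decode(member1 ++ member2) == ByteString("hello, world"))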


@@ -5,13 +5,15 @@
package akka.http.coding
import akka.http.model._
import akka.http.util.StreamUtils
import akka.stream.stage.Stage
import akka.util.ByteString
import headers.HttpEncodings
/**
* An encoder and decoder for the HTTP 'identity' encoding.
*/
object NoCoding extends Coder {
object NoCoding extends Coder with StreamDecoder {
val encoding = HttpEncodings.identity
override def encode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = message.self
@@ -22,7 +24,9 @@ object NoCoding extends Coder {
val messageFilter: HttpMessage ⇒ Boolean = _ ⇒ false
def newCompressor = NoCodingCompressor
def newDecompressor = NoCodingDecompressor
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
() ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
}
object NoCodingCompressor extends Compressor {
@@ -33,7 +37,3 @@ object NoCodingCompressor extends Compressor {
def compressAndFlush(input: ByteString): ByteString = input
def compressAndFinish(input: ByteString): ByteString = input
}
object NoCodingDecompressor extends Decompressor {
def decompress(input: ByteString): ByteString = input
def finish(): ByteString = ByteString.empty
}