+htp #16516 rewrite Deflate/GzipDecompressor as StatefulStage to defuse zip bomb
Also, the tests have been DRY'd up.
This commit is contained in:
parent
bfad068a70
commit
29d7a041f6
15 changed files with 510 additions and 396 deletions
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.util
|
||||
|
||||
import akka.util.ByteString
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
/**
|
||||
* A helper class to read from a ByteString statefully.
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] class ByteReader(input: ByteString) {
|
||||
import ByteReader.NeedMoreData
|
||||
|
||||
private[this] var off = 0
|
||||
|
||||
def currentOffset: Int = off
|
||||
def remainingData: ByteString = input.drop(off)
|
||||
def fromStartToHere: ByteString = input.take(currentOffset)
|
||||
|
||||
def readByte(): Int =
|
||||
if (off < input.length) {
|
||||
val x = input(off)
|
||||
off += 1
|
||||
x.toInt & 0xFF
|
||||
} else throw NeedMoreData
|
||||
def readShort(): Int = readByte() | (readByte() << 8)
|
||||
def readInt(): Int = readShort() | (readShort() << 16)
|
||||
def skip(numBytes: Int): Unit =
|
||||
if (off + numBytes <= input.length) off += numBytes
|
||||
else throw NeedMoreData
|
||||
def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
|
||||
}
|
||||
|
||||
/*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] object ByteReader {
|
||||
val NeedMoreData = new Exception with NoStackTrace
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.util
|
||||
|
||||
import akka.stream.stage.{ Directive, Context, StatefulStage }
|
||||
import akka.util.ByteString
|
||||
|
||||
/**
|
||||
* A helper class for writing parsers from ByteStrings.
|
||||
*
|
||||
* FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
|
||||
protected def onTruncation(ctx: Context[Out]): Directive
|
||||
|
||||
/**
|
||||
* Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of
|
||||
* `ctx.pull()` to have truncation errors reported.
|
||||
*/
|
||||
abstract class IntermediateState extends State {
|
||||
override def onPull(ctx: Context[Out]): Directive = pull(ctx)
|
||||
def pull(ctx: Context[Out]): Directive =
|
||||
if (ctx.isFinishing) onTruncation(ctx)
|
||||
else ctx.pull()
|
||||
}
|
||||
|
||||
/**
|
||||
* A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
|
||||
* occurs the previous data is saved and the reading process is restarted from the beginning
|
||||
* once more data was received.
|
||||
*
|
||||
* As [[read]] may be called several times for the same prefix of data, make sure not to
|
||||
* manipulate any state during reading from the ByteReader.
|
||||
*/
|
||||
trait ByteReadingState extends IntermediateState {
|
||||
def read(reader: ByteReader, ctx: Context[Out]): Directive
|
||||
|
||||
def onPush(data: ByteString, ctx: Context[Out]): Directive =
|
||||
try {
|
||||
val reader = new ByteReader(data)
|
||||
read(reader, ctx)
|
||||
} catch {
|
||||
case ByteReader.NeedMoreData ⇒
|
||||
become(TryAgain(data, this))
|
||||
pull(ctx)
|
||||
}
|
||||
}
|
||||
case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
|
||||
def onPush(data: ByteString, ctx: Context[Out]): Directive = {
|
||||
become(byteReadingState)
|
||||
byteReadingState.onPush(previousData ++ data, ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -34,8 +34,8 @@ private[http] object StreamUtils {
|
|||
* Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete
|
||||
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
|
||||
*/
|
||||
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Flow[ByteString, ByteString] = {
|
||||
val transformer = new PushPullStage[ByteString, ByteString] {
|
||||
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = {
|
||||
new PushPullStage[ByteString, ByteString] {
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): Directive =
|
||||
ctx.push(f(element))
|
||||
|
||||
|
|
@ -45,7 +45,6 @@ private[http] object StreamUtils {
|
|||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
||||
}
|
||||
Flow[ByteString].section(name("transformBytes"))(_.transform(() ⇒ transformer))
|
||||
}
|
||||
|
||||
def failedPublisher[T](ex: Throwable): Publisher[T] =
|
||||
|
|
@ -94,6 +93,41 @@ private[http] object StreamUtils {
|
|||
Flow[ByteString].section(name("sliceBytes"))(_.transform(() ⇒ transformer))
|
||||
}
|
||||
|
||||
def limitByteChunksStage(maxBytesPerChunk: Int): Stage[ByteString, ByteString] =
|
||||
new StatefulStage[ByteString, ByteString] {
|
||||
def initial = WaitingForData
|
||||
case object WaitingForData extends State {
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
|
||||
if (elem.size <= maxBytesPerChunk) ctx.push(elem)
|
||||
else {
|
||||
become(DeliveringData(elem.drop(maxBytesPerChunk)))
|
||||
ctx.push(elem.take(maxBytesPerChunk))
|
||||
}
|
||||
}
|
||||
case class DeliveringData(remaining: ByteString) extends State {
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
|
||||
throw new IllegalStateException("Not expecting data")
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): Directive = {
|
||||
val toPush = remaining.take(maxBytesPerChunk)
|
||||
val toKeep = remaining.drop(maxBytesPerChunk)
|
||||
|
||||
become {
|
||||
if (toKeep.isEmpty) WaitingForData
|
||||
else DeliveringData(toKeep)
|
||||
}
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(toPush)
|
||||
else ctx.push(toPush)
|
||||
}
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
|
||||
current match {
|
||||
case WaitingForData ⇒ ctx.finish()
|
||||
case _: DeliveringData ⇒ ctx.absorbTermination()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies a sequence of transformers on one source and returns a sequence of sources with the result. The input source
|
||||
* will only be traversed once.
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll {
|
|||
}
|
||||
|
||||
def duplicateBytesTransformer(): Flow[ByteString, ByteString] =
|
||||
StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer)
|
||||
Flow[ByteString].transform(() ⇒ StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer))
|
||||
|
||||
def trailer: ByteString = ByteString("--dup")
|
||||
def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*)
|
||||
|
|
|
|||
154
akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala
Normal file
154
akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
package akka.http.coding
|
||||
|
||||
import akka.stream.stage.{ Directive, Context, PushStage, Stage }
|
||||
import akka.util.ByteString
|
||||
import org.scalatest.WordSpec
|
||||
import akka.http.model._
|
||||
|
|
@ -26,15 +27,15 @@ class DecoderSpec extends WordSpec with CodecSpecSupport {
|
|||
}
|
||||
|
||||
def dummyDecompress(s: String): String = dummyDecompress(ByteString(s, "UTF8")).decodeString("UTF8")
|
||||
def dummyDecompress(bytes: ByteString): ByteString = DummyDecompressor.decompress(bytes)
|
||||
def dummyDecompress(bytes: ByteString): ByteString = DummyDecoder.decode(bytes)
|
||||
|
||||
case object DummyDecoder extends Decoder {
|
||||
case object DummyDecoder extends StreamDecoder {
|
||||
val encoding = HttpEncodings.compress
|
||||
def newDecompressor = DummyDecompressor
|
||||
}
|
||||
|
||||
case object DummyDecompressor extends Decompressor {
|
||||
def decompress(buffer: ByteString): ByteString = buffer ++ ByteString("compressed")
|
||||
def finish(): ByteString = ByteString.empty
|
||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
||||
() ⇒ new PushStage[ByteString, ByteString] {
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
|
||||
ctx.push(elem ++ ByteString("compressed"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,78 +5,25 @@
|
|||
package akka.http.coding
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.http.util._
|
||||
import org.scalatest.WordSpec
|
||||
import akka.http.model._
|
||||
import HttpMethods.POST
|
||||
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.util.zip.{ DeflaterOutputStream, InflaterOutputStream }
|
||||
import java.io.{ InputStream, OutputStream }
|
||||
import java.util.zip._
|
||||
|
||||
class DeflateSpec extends WordSpec with CodecSpecSupport {
|
||||
class DeflateSpec extends CoderSpec {
|
||||
protected def Coder: Coder with StreamDecoder = Deflate
|
||||
|
||||
"The Deflate codec" should {
|
||||
"properly encode a small string" in {
|
||||
streamInflate(ourDeflate(smallTextBytes)) should readAs(smallText)
|
||||
}
|
||||
"properly decode a small string" in {
|
||||
ourInflate(streamDeflate(smallTextBytes)) should readAs(smallText)
|
||||
}
|
||||
"properly round-trip encode/decode a small string" in {
|
||||
ourInflate(ourDeflate(smallTextBytes)) should readAs(smallText)
|
||||
}
|
||||
"properly encode a large string" in {
|
||||
streamInflate(ourDeflate(largeTextBytes)) should readAs(largeText)
|
||||
}
|
||||
"properly decode a large string" in {
|
||||
ourInflate(streamDeflate(largeTextBytes)) should readAs(largeText)
|
||||
}
|
||||
"properly round-trip encode/decode a large string" in {
|
||||
ourInflate(ourDeflate(largeTextBytes)) should readAs(largeText)
|
||||
}
|
||||
"properly round-trip encode/decode an HttpRequest" in {
|
||||
val request = HttpRequest(POST, entity = HttpEntity(largeText))
|
||||
Deflate.decode(Deflate.encode(request)) should equal(request)
|
||||
}
|
||||
"provide a better compression ratio than the standard Deflater/Inflater streams" in {
|
||||
ourDeflate(largeTextBytes).length should be < streamDeflate(largeTextBytes).length
|
||||
}
|
||||
"support chunked round-trip encoding/decoding" in {
|
||||
val chunks = largeTextBytes.grouped(512).toVector
|
||||
val comp = Deflate.newCompressor
|
||||
val decomp = Deflate.newDecompressor
|
||||
val chunks2 =
|
||||
chunks.map { chunk ⇒
|
||||
decomp.decompress(comp.compressAndFlush(chunk))
|
||||
} :+
|
||||
decomp.decompress(comp.finish())
|
||||
chunks2.join should readAs(largeText)
|
||||
}
|
||||
"works for any split in prefix + suffix" in {
|
||||
val compressed = streamDeflate(smallTextBytes)
|
||||
def tryWithPrefixOfSize(prefixSize: Int): Unit = {
|
||||
val decomp = Deflate.newDecompressor
|
||||
val prefix = compressed.take(prefixSize)
|
||||
val suffix = compressed.drop(prefixSize)
|
||||
protected def newDecodedInputStream(underlying: InputStream): InputStream =
|
||||
new InflaterInputStream(underlying)
|
||||
|
||||
decomp.decompress(prefix) ++ decomp.decompress(suffix) should readAs(smallText)
|
||||
}
|
||||
(0 to compressed.size).foreach(tryWithPrefixOfSize)
|
||||
}
|
||||
}
|
||||
protected def newEncodedOutputStream(underlying: OutputStream): OutputStream =
|
||||
new DeflaterOutputStream(underlying)
|
||||
|
||||
def ourDeflate(bytes: ByteString): ByteString = Deflate.encode(bytes)
|
||||
def ourInflate(bytes: ByteString): ByteString = Deflate.decode(bytes)
|
||||
protected def corruptInputMessage: Option[String] = Some("invalid code -- missing end-of-block")
|
||||
|
||||
def streamDeflate(bytes: ByteString) = {
|
||||
val output = new ByteArrayOutputStream()
|
||||
val dos = new DeflaterOutputStream(output); dos.write(bytes.toArray); dos.close()
|
||||
ByteString(output.toByteArray)
|
||||
override def extraTests(): Unit = {
|
||||
"throw early if header is corrupt" in {
|
||||
val ex = the[DataFormatException] thrownBy ourDecode(ByteString(0, 1, 2, 3, 4))
|
||||
ex.getMessage should equal("incorrect header check")
|
||||
}
|
||||
|
||||
def streamInflate(bytes: ByteString) = {
|
||||
val output = new ByteArrayOutputStream()
|
||||
val ios = new InflaterOutputStream(output); ios.write(bytes.toArray); ios.close()
|
||||
ByteString(output.toByteArray)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.coding
|
||||
|
||||
import java.io.{ OutputStream, InputStream }
|
||||
|
||||
class NoCodingSpec extends CoderSpec {
|
||||
protected def Coder: Coder with StreamDecoder = NoCoding
|
||||
|
||||
protected def corruptInputMessage: Option[String] = None // all input data is valid
|
||||
|
||||
protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = underlying
|
||||
protected def newDecodedInputStream(underlying: InputStream): InputStream = underlying
|
||||
}
|
||||
|
|
@ -166,7 +166,7 @@ class CodingDirectivesSpec extends RoutingSpec {
|
|||
response should haveContentEncoding(gzip)
|
||||
chunks.size shouldEqual (11 + 1) // 11 regular + the last one
|
||||
val bytes = chunks.foldLeft(ByteString.empty)(_ ++ _.data)
|
||||
Gzip.newDecompressor.decompress(bytes) should readAs(text)
|
||||
Gzip.decode(bytes) should readAs(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package akka.http.coding
|
|||
|
||||
import akka.http.model._
|
||||
import akka.http.util.StreamUtils
|
||||
import akka.stream.stage.Stage
|
||||
import akka.util.ByteString
|
||||
import headers.HttpEncoding
|
||||
import akka.stream.scaladsl.Flow
|
||||
|
|
@ -18,31 +19,38 @@ trait Decoder {
|
|||
decodeData(message).withHeaders(message.headers filterNot Encoder.isContentEncodingHeader)
|
||||
else message.self
|
||||
|
||||
def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
|
||||
mapper.transformDataBytes(t, newDecodeTransfomer)
|
||||
def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = mapper.transformDataBytes(t, decoderFlow)
|
||||
|
||||
def decode(input: ByteString): ByteString = newDecompressor.decompressAndFinish(input)
|
||||
def maxBytesPerChunk: Int
|
||||
def withMaxBytesPerChunk(maxBytesPerChunk: Int): Decoder
|
||||
|
||||
def newDecompressor: Decompressor
|
||||
def decoderFlow: Flow[ByteString, ByteString]
|
||||
def decode(input: ByteString): ByteString
|
||||
}
|
||||
object Decoder {
|
||||
val MaxBytesPerChunkDefault: Int = 65536
|
||||
}
|
||||
|
||||
def newDecodeTransfomer(): Flow[ByteString, ByteString] = {
|
||||
val decompressor = newDecompressor
|
||||
/** A decoder that is implemented in terms of a [[Stage]] */
|
||||
trait StreamDecoder extends Decoder { outer ⇒
|
||||
protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString]
|
||||
|
||||
def decodeChunk(bytes: ByteString): ByteString = decompressor.decompress(bytes)
|
||||
def finish(): ByteString = decompressor.finish()
|
||||
def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
|
||||
def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
|
||||
new StreamDecoder {
|
||||
def encoding: HttpEncoding = outer.encoding
|
||||
override def maxBytesPerChunk: Int = newMaxBytesPerChunk
|
||||
|
||||
StreamUtils.byteStringTransformer(decodeChunk, finish)
|
||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
||||
outer.newDecompressorStage(maxBytesPerChunk)
|
||||
}
|
||||
}
|
||||
|
||||
/** A stateful object representing ongoing decompression. */
|
||||
abstract class Decompressor {
|
||||
/** Decompress the buffer and return decompressed data. */
|
||||
def decompress(input: ByteString): ByteString
|
||||
|
||||
/** Flushes potential remaining data from any internal buffers and may report on truncation errors */
|
||||
def finish(): ByteString
|
||||
|
||||
/** Combines decompress and finish */
|
||||
def decompressAndFinish(input: ByteString): ByteString = decompress(input) ++ finish()
|
||||
|
||||
def decoderFlow: Flow[ByteString, ByteString] =
|
||||
Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk))
|
||||
|
||||
def decode(input: ByteString): ByteString = decodeWithLimits(input)
|
||||
def decodeWithLimits(input: ByteString, maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString =
|
||||
StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get
|
||||
def decodeFromIterator(input: Iterator[ByteString], maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString =
|
||||
StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,27 +4,21 @@
|
|||
|
||||
package akka.http.coding
|
||||
|
||||
import java.io.OutputStream
|
||||
import java.util.zip.{ DataFormatException, ZipException, Inflater, Deflater }
|
||||
import java.util.zip.{ Inflater, Deflater }
|
||||
import akka.stream.stage._
|
||||
import akka.util.{ ByteStringBuilder, ByteString }
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import akka.http.util._
|
||||
import akka.http.model._
|
||||
import headers.HttpEncodings
|
||||
import akka.http.model.headers.HttpEncodings
|
||||
|
||||
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder {
|
||||
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
||||
val encoding = HttpEncodings.deflate
|
||||
def newCompressor = new DeflateCompressor
|
||||
def newDecompressor = new DeflateDecompressor
|
||||
}
|
||||
|
||||
/**
|
||||
* An encoder and decoder for the HTTP 'deflate' encoding.
|
||||
*/
|
||||
object Deflate extends Deflate(Encoder.DefaultFilter) {
|
||||
def apply(messageFilter: HttpMessage ⇒ Boolean) = new Deflate(messageFilter)
|
||||
def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new DeflateDecompressor(maxBytesPerChunk)
|
||||
}
|
||||
object Deflate extends Deflate(Encoder.DefaultFilter)
|
||||
|
||||
class DeflateCompressor extends Compressor {
|
||||
protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false)
|
||||
|
|
@ -88,27 +82,57 @@ class DeflateCompressor extends Compressor {
|
|||
new Array[Byte](size)
|
||||
}
|
||||
|
||||
class DeflateDecompressor extends Decompressor {
|
||||
protected lazy val inflater = new Inflater()
|
||||
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||
protected def createInflater() = new Inflater()
|
||||
|
||||
def decompress(input: ByteString): ByteString =
|
||||
try {
|
||||
inflater.setInput(input.toArray)
|
||||
drain(new Array[Byte](input.length * 2))
|
||||
} catch {
|
||||
case e: DataFormatException ⇒
|
||||
throw new ZipException(e.getMessage.toOption getOrElse "Invalid ZLIB data format")
|
||||
def initial: State = StartInflate
|
||||
def afterInflate: State = StartInflate
|
||||
|
||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
|
||||
protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.finish()
|
||||
}
|
||||
|
||||
abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] {
|
||||
protected def createInflater(): Inflater
|
||||
val inflater = createInflater()
|
||||
|
||||
protected def afterInflate: State
|
||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
|
||||
|
||||
/** Start inflating */
|
||||
case object StartInflate extends State {
|
||||
def onPush(data: ByteString, ctx: Context[ByteString]): Directive = {
|
||||
require(inflater.needsInput())
|
||||
inflater.setInput(data.toArray)
|
||||
|
||||
becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec protected final def drain(buffer: Array[Byte], result: ByteString = ByteString.empty): ByteString = {
|
||||
val len = inflater.inflate(buffer)
|
||||
if (len > 0) drain(buffer, result ++ ByteString.fromArray(buffer, 0, len))
|
||||
else if (inflater.needsDictionary) throw new ZipException("ZLIB dictionary missing")
|
||||
else result
|
||||
/** Inflate */
|
||||
case class Inflate()(data: ByteString) extends IntermediateState {
|
||||
override def onPull(ctx: Context[ByteString]): Directive = {
|
||||
val buffer = new Array[Byte](maxBytesPerChunk)
|
||||
val read = inflater.inflate(buffer)
|
||||
if (read > 0) {
|
||||
afterBytesRead(buffer, 0, read)
|
||||
ctx.push(ByteString.fromArray(buffer, 0, read))
|
||||
} else {
|
||||
val remaining = data.takeRight(inflater.getRemaining)
|
||||
val next =
|
||||
if (inflater.finished()) afterInflate
|
||||
else StartInflate
|
||||
|
||||
becomeWithRemaining(next, remaining, ctx)
|
||||
}
|
||||
}
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): Directive =
|
||||
throw new IllegalStateException("Don't expect a new Element")
|
||||
}
|
||||
|
||||
def finish(): ByteString = {
|
||||
inflater.end()
|
||||
ByteString.empty
|
||||
def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
|
||||
become(next)
|
||||
if (remaining.isEmpty) current.onPull(ctx)
|
||||
else current.onPush(remaining, ctx)
|
||||
}
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ package akka.http.coding
|
|||
|
||||
import akka.http.model._
|
||||
import akka.http.util.StreamUtils
|
||||
import akka.stream.stage.Stage
|
||||
import akka.util.ByteString
|
||||
import headers._
|
||||
import akka.stream.scaladsl.Flow
|
||||
|
|
@ -21,13 +22,13 @@ trait Encoder {
|
|||
else message.self
|
||||
|
||||
def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
|
||||
mapper.transformDataBytes(t, newEncodeTransformer)
|
||||
mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer))
|
||||
|
||||
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
|
||||
|
||||
def newCompressor: Compressor
|
||||
|
||||
def newEncodeTransformer(): Flow[ByteString, ByteString] = {
|
||||
def newEncodeTransformer(): Stage[ByteString, ByteString] = {
|
||||
val compressor = newCompressor
|
||||
|
||||
def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes)
|
||||
|
|
|
|||
|
|
@ -4,19 +4,19 @@
|
|||
|
||||
package akka.http.coding
|
||||
|
||||
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
|
||||
import akka.util.ByteString
|
||||
import akka.stream.stage._
|
||||
|
||||
import akka.http.util.ByteReader
|
||||
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import akka.http.model._
|
||||
import headers.HttpEncodings
|
||||
|
||||
import scala.util.control.{ NonFatal, NoStackTrace }
|
||||
|
||||
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder {
|
||||
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
||||
val encoding = HttpEncodings.gzip
|
||||
def newCompressor = new GzipCompressor
|
||||
def newDecompressor = new GzipDecompressor
|
||||
def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new GzipDecompressor(maxBytesPerChunk)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -59,24 +59,29 @@ class GzipCompressor extends DeflateCompressor {
|
|||
}
|
||||
}
|
||||
|
||||
/** A suspendable gzip decompressor */
|
||||
class GzipDecompressor extends DeflateDecompressor {
|
||||
override protected lazy val inflater = new Inflater(true) // disable ZLIB headers
|
||||
override def decompress(input: ByteString): ByteString = DecompressionStateMachine.run(input)
|
||||
override def finish(): ByteString =
|
||||
if (DecompressionStateMachine.isFinished) ByteString.empty
|
||||
else fail("Truncated GZIP stream")
|
||||
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||
protected def createInflater(): Inflater = new Inflater(true)
|
||||
|
||||
import GzipDecompressor._
|
||||
def initial: State = Initial
|
||||
|
||||
object DecompressionStateMachine extends StateMachine {
|
||||
def isFinished: Boolean = currentState == finished
|
||||
/** No bytes were received yet */
|
||||
case object Initial extends State {
|
||||
def onPush(data: ByteString, ctx: Context[ByteString]): Directive =
|
||||
if (data.isEmpty) ctx.pull()
|
||||
else becomeWithRemaining(ReadHeaders, data, ctx)
|
||||
|
||||
def initialState = finished
|
||||
override def onPull(ctx: Context[ByteString]): Directive =
|
||||
if (ctx.isFinishing) {
|
||||
ctx.finish()
|
||||
} else super.onPull(ctx)
|
||||
}
|
||||
|
||||
private def readHeaders(data: ByteString): Action =
|
||||
try {
|
||||
val reader = new ByteReader(data)
|
||||
var crc32: CRC32 = new CRC32
|
||||
protected def afterInflate: State = ReadTrailer
|
||||
|
||||
/** Reading the header bytes */
|
||||
case object ReadHeaders extends ByteReadingState {
|
||||
def read(reader: ByteReader, ctx: Context[ByteString]): Directive = {
|
||||
import reader._
|
||||
|
||||
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
|
||||
|
|
@ -84,59 +89,50 @@ class GzipDecompressor extends DeflateDecompressor {
|
|||
val flags = readByte()
|
||||
skip(6) // skip MTIME, XFL and OS fields
|
||||
if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields
|
||||
if ((flags & 8) > 0) while (readByte() != 0) {} // skip optional file name
|
||||
if ((flags & 16) > 0) while (readByte() != 0) {} // skip optional file comment
|
||||
if ((flags & 2) > 0 && crc16(data.take(currentOffset)) != readShort()) fail("Corrupt GZIP header")
|
||||
|
||||
ContinueWith(deflate(new CRC32), remainingData)
|
||||
} catch {
|
||||
case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData
|
||||
}
|
||||
|
||||
private def deflate(checkSum: CRC32)(data: ByteString): Action = {
|
||||
assert(inflater.needsInput())
|
||||
inflater.setInput(data.toArray)
|
||||
val output = drain(new Array[Byte](data.length * 2))
|
||||
checkSum.update(output.toArray)
|
||||
if (inflater.finished()) EmitAndContinueWith(output, readTrailer(checkSum), data.takeRight(inflater.getRemaining))
|
||||
else EmitAndSuspend(output)
|
||||
}
|
||||
|
||||
private def readTrailer(checkSum: CRC32)(data: ByteString): Action =
|
||||
try {
|
||||
val reader = new ByteReader(data)
|
||||
import reader._
|
||||
|
||||
if (readInt() != checkSum.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
||||
if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
|
||||
if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
|
||||
if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
|
||||
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShort()) fail("Corrupt GZIP header")
|
||||
|
||||
inflater.reset()
|
||||
checkSum.reset()
|
||||
ContinueWith(finished, remainingData) // start over to support multiple concatenated gzip streams
|
||||
} catch {
|
||||
case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData
|
||||
crc32.reset()
|
||||
becomeWithRemaining(StartInflate, remainingData, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
lazy val finished: ByteString ⇒ Action =
|
||||
data ⇒ if (data.nonEmpty) ContinueWith(readHeaders, data) else SuspendAndRetryWithMoreData
|
||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
|
||||
crc32.update(buffer, offset, length)
|
||||
|
||||
/** Reading the trailer */
|
||||
case object ReadTrailer extends ByteReadingState {
|
||||
def read(reader: ByteReader, ctx: Context[ByteString]): Directive = {
|
||||
import reader._
|
||||
|
||||
if (readInt() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
||||
if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
|
||||
|
||||
becomeWithRemaining(Initial, remainingData, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
||||
|
||||
private def crc16(data: ByteString) = {
|
||||
val crc = new CRC32
|
||||
crc.update(data.toArray)
|
||||
crc.getValue.toInt & 0xFFFF
|
||||
}
|
||||
}
|
||||
|
||||
override protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.fail(new ZipException("Truncated GZIP stream"))
|
||||
|
||||
private def fail(msg: String) = throw new ZipException(msg)
|
||||
|
||||
}
|
||||
|
||||
/** INTERNAL API */
|
||||
private[http] object GzipDecompressor {
|
||||
// RFC 1952: http://tools.ietf.org/html/rfc1952 section 2.2
|
||||
val Header = ByteString(
|
||||
31, // ID1
|
||||
-117, // ID2
|
||||
0x1F, // ID1
|
||||
0x8B, // ID2
|
||||
8, // CM = Deflate
|
||||
0, // FLG
|
||||
0, // MTIME 1
|
||||
|
|
@ -146,76 +142,4 @@ private[http] object GzipDecompressor {
|
|||
0, // XFL
|
||||
0 // OS
|
||||
)
|
||||
|
||||
class ByteReader(input: ByteString) {
|
||||
import ByteReader.NeedMoreData
|
||||
|
||||
private[this] var off = 0
|
||||
|
||||
def readByte(): Int =
|
||||
if (off < input.length) {
|
||||
val x = input(off)
|
||||
off += 1
|
||||
x.toInt & 0xFF
|
||||
} else throw NeedMoreData
|
||||
def readShort(): Int = readByte() | (readByte() << 8)
|
||||
def readInt(): Int = readShort() | (readShort() << 16)
|
||||
def skip(numBytes: Int): Unit =
|
||||
if (off + numBytes <= input.length) off += numBytes
|
||||
else throw NeedMoreData
|
||||
def currentOffset: Int = off
|
||||
def remainingData: ByteString = input.drop(off)
|
||||
}
|
||||
object ByteReader {
|
||||
val NeedMoreData = new Exception with NoStackTrace
|
||||
}
|
||||
|
||||
/** A simple state machine implementation for suspendable parsing */
|
||||
trait StateMachine {
|
||||
sealed trait Action
|
||||
/** Cache the current input and suspend to wait for more data */
|
||||
case object SuspendAndRetryWithMoreData extends Action
|
||||
/** Emit some output and suspend in the current state and wait for more data */
|
||||
case class EmitAndSuspend(output: ByteString) extends Action
|
||||
/** Proceed to the nextState and immediately run it with the remainingInput */
|
||||
case class ContinueWith(nextState: State, remainingInput: ByteString) extends Action
|
||||
/** Emit some output and then proceed to the nextState and immediately run it with the remainingInput */
|
||||
case class EmitAndContinueWith(output: ByteString, nextState: State, remainingInput: ByteString) extends Action
|
||||
|
||||
type State = ByteString ⇒ Action
|
||||
def initialState: State
|
||||
|
||||
private[this] var state: State = initialState
|
||||
def currentState: State = state
|
||||
|
||||
/** Run the state machine with the current input */
|
||||
final def run(input: ByteString): ByteString = {
|
||||
@tailrec def rec(input: ByteString, result: ByteString = ByteString.empty): ByteString =
|
||||
state(input) match {
|
||||
case SuspendAndRetryWithMoreData ⇒
|
||||
val oldState = state
|
||||
state = { newData ⇒
|
||||
state = oldState
|
||||
oldState(input ++ newData)
|
||||
}
|
||||
result
|
||||
case EmitAndSuspend(output) ⇒ result ++ output
|
||||
case ContinueWith(next, remainingInput) ⇒
|
||||
state = next
|
||||
if (remainingInput.nonEmpty) rec(remainingInput, result)
|
||||
else result
|
||||
case EmitAndContinueWith(output, next, remainingInput) ⇒
|
||||
state = next
|
||||
rec(remainingInput, result ++ output)
|
||||
}
|
||||
try rec(input)
|
||||
catch {
|
||||
case NonFatal(e) ⇒
|
||||
state = failState
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
private def failState: State = _ ⇒ throw new IllegalStateException("Trying to reuse failed decompressor.")
|
||||
}
|
||||
}
|
||||
|
|
@ -5,13 +5,15 @@
|
|||
package akka.http.coding
|
||||
|
||||
import akka.http.model._
|
||||
import akka.http.util.StreamUtils
|
||||
import akka.stream.stage.Stage
|
||||
import akka.util.ByteString
|
||||
import headers.HttpEncodings
|
||||
|
||||
/**
|
||||
* An encoder and decoder for the HTTP 'identity' encoding.
|
||||
*/
|
||||
object NoCoding extends Coder {
|
||||
object NoCoding extends Coder with StreamDecoder {
|
||||
val encoding = HttpEncodings.identity
|
||||
|
||||
override def encode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = message.self
|
||||
|
|
@ -22,7 +24,9 @@ object NoCoding extends Coder {
|
|||
val messageFilter: HttpMessage ⇒ Boolean = _ ⇒ false
|
||||
|
||||
def newCompressor = NoCodingCompressor
|
||||
def newDecompressor = NoCodingDecompressor
|
||||
|
||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
||||
() ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
|
||||
}
|
||||
|
||||
object NoCodingCompressor extends Compressor {
|
||||
|
|
@ -33,7 +37,3 @@ object NoCodingCompressor extends Compressor {
|
|||
def compressAndFlush(input: ByteString): ByteString = input
|
||||
def compressAndFinish(input: ByteString): ByteString = input
|
||||
}
|
||||
object NoCodingDecompressor extends Decompressor {
|
||||
def decompress(input: ByteString): ByteString = input
|
||||
def finish(): ByteString = ByteString.empty
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue