diff --git a/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala index e7bb9b8970..fd23cc0dcf 100644 --- a/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala @@ -10,7 +10,7 @@ import akka.http.util._ import HttpMethods.POST import java.io.{ InputStream, OutputStream, ByteArrayInputStream, ByteArrayOutputStream } -import java.util.zip.{ DataFormatException, GZIPInputStream, GZIPOutputStream } +import java.util.zip.{ ZipException, DataFormatException, GZIPInputStream, GZIPOutputStream } import akka.util.ByteString @@ -51,6 +51,10 @@ class GzipSpec extends WordSpec with CodecSpecSupport { val ex = the[DataFormatException] thrownBy ourGunzip(corruptGzipContent) ex.getMessage should equal("invalid literal/length code") } + "throw early if header is corrupt" in { + val ex = the[ZipException] thrownBy ourGunzip(ByteString(0, 1, 2, 3, 4)) + ex.getMessage should equal("Not in GZIP format") + } "not throw an error if a subsequent block is corrupt" in { pending // FIXME: should we read as long as possible and only then report an error, that seems somewhat arbitrary ourGunzip(Seq(gzip("Hello,"), gzip(" dear "), corruptGzipContent).join) should readAs("Hello, dear ") diff --git a/akka-http/src/main/scala/akka/http/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/coding/Gzip.scala index 7cffe6fe62..d61797d77c 100644 --- a/akka-http/src/main/scala/akka/http/coding/Gzip.scala +++ b/akka-http/src/main/scala/akka/http/coding/Gzip.scala @@ -4,7 +4,6 @@ package akka.http.coding -import java.io.OutputStream import java.util.zip.{ Inflater, CRC32, ZipException, Deflater } import akka.util.ByteString @@ -12,7 +11,7 @@ import scala.annotation.tailrec import akka.http.model._ import headers.HttpEncodings -import scala.util.control.NoStackTrace +import scala.util.control.{ NonFatal, NoStackTrace } class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Decoder with Encoder { val encoding = HttpEncodings.gzip @@ -70,9 +69,7 @@ class GzipDecompressor extends DeflateDecompressor { def initialState = readHeaders private def readHeaders(data: ByteString): Action = - // header has at least size 3 - if (data.size < 4) SuspendAndRetryWithMoreData - else try { + try { val reader = new ByteReader(data) import reader._ @@ -114,7 +111,7 @@ class GzipDecompressor extends DeflateDecompressor { case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData } - private def fail(msg: String) = Fail(new ZipException(msg)) + private def fail(msg: String) = throw new ZipException(msg) private def crc16(data: ByteString) = { val crc = new CRC32 @@ -174,8 +171,6 @@ private[http] object GzipDecompressor { case class ContinueWith(nextState: State, remainingInput: ByteString) extends Action /** Emit some output and then proceed to the nextState and immediately run it with the remainingInput */ case class EmitAndContinueWith(output: ByteString, nextState: State, remainingInput: ByteString) extends Action - /** Fail with the given exception and go into the failed state which will throw for any new data */ - case class Fail(cause: Throwable) extends Action type State = ByteString ⇒ Action def initialState: State @@ -183,26 +178,31 @@ private[http] object GzipDecompressor { private[this] var state: State = initialState /** Run the state machine with the current input */ - @tailrec final def run(input: ByteString, result: ByteString = ByteString.empty): ByteString = - state(input) match { - case SuspendAndRetryWithMoreData ⇒ - val oldState = state - state = { newData ⇒ - state = oldState - oldState(input ++ newData) - } - result - case EmitAndSuspend(output) ⇒ result ++ output - case ContinueWith(next, remainingInput) ⇒ - state = next - run(remainingInput, result) - case EmitAndContinueWith(output, next, remainingInput) ⇒ - state = next - run(remainingInput, result ++ output) - case Fail(cause) ⇒ + final def run(input: ByteString): ByteString = { + @tailrec def rec(input: ByteString, result: ByteString = ByteString.empty): ByteString = + state(input) match { + case SuspendAndRetryWithMoreData ⇒ + val oldState = state + state = { newData ⇒ + state = oldState + oldState(input ++ newData) + } + result + case EmitAndSuspend(output) ⇒ result ++ output + case ContinueWith(next, remainingInput) ⇒ + state = next + rec(remainingInput, result) + case EmitAndContinueWith(output, next, remainingInput) ⇒ + state = next + rec(remainingInput, result ++ output) + } + try rec(input) + catch { + case NonFatal(e) ⇒ state = failState - throw cause + throw e } + } private def failState: State = _ ⇒ throw new IllegalStateException("Trying to reuse failed decompressor.") }