diff --git a/akka-stream/src/main/scala/akka/stream/io/Framing.scala b/akka-stream/src/main/scala/akka/stream/io/Framing.scala index 09ce0bc93e..96172bc77c 100644 --- a/akka-stream/src/main/scala/akka/stream/io/Framing.scala +++ b/akka-stream/src/main/scala/akka/stream/io/Framing.scala @@ -81,59 +81,44 @@ object Framing { def simpleFramingProtocol(maximumMessageLength: Int): BidiFlow[ByteString, ByteString, ByteString, ByteString, NotUsed] = { val decoder = lengthField(4, 0, maximumMessageLength + 4, ByteOrder.BIG_ENDIAN).map(_.drop(4)) val encoder = Flow[ByteString].transform(() ⇒ new PushStage[ByteString, ByteString] { - override def onPush(message: ByteString, ctx: Context[ByteString]): SyncDirective = { - if (message.size > maximumMessageLength) - ctx.fail(new FramingException(s"Maximum allowed message size is $maximumMessageLength " + - s"but tried to send ${message.size} bytes")) + val msgSize = message.size + if (msgSize > maximumMessageLength) + ctx.fail(new FramingException(s"Maximum allowed message size is $maximumMessageLength but tried to send $msgSize bytes")) else { - val header = ByteString( - (message.size >> 24) & 0xFF, - (message.size >> 16) & 0xFF, - (message.size >> 8) & 0xFF, - message.size & 0xFF) + val header = ByteString((msgSize >> 24) & 0xFF, (msgSize >> 16) & 0xFF, (msgSize >> 8) & 0xFF, msgSize & 0xFF) ctx.push(header ++ message) } } - }) BidiFlow.fromFlowsMat(encoder, decoder)(Keep.left) } - private trait IntDecoder { - def decode(bs: ByteIterator): Int - } - class FramingException(msg: String) extends RuntimeException(msg) - private class BigEndianCodec(val length: Int) extends IntDecoder { - override def decode(bs: ByteIterator): Int = { - var count = length - var decoded = 0 - while (count > 0) { - decoded <<= 8 - decoded |= bs.next().toInt & 0xFF - count -= 1 - } - decoded + private final val bigEndianDecoder: (ByteIterator, Int) ⇒ Int = (bs, length) ⇒ { + var count = length + var decoded = 0 + while (count > 0) { + decoded <<= 8 + decoded |= bs.next().toInt & 0xFF + count -= 1 } + decoded } - private class LittleEndianCodec(val length: Int) extends IntDecoder { - private val highestOctet = (length - 1) * 8 - private val Mask = ((1L << (length * 8)) - 1).toInt - - override def decode(bs: ByteIterator): Int = { - var count = length - var decoded = 0 - while (count > 0) { - decoded >>>= 8 - decoded += (bs.next().toInt & 0xFF) << highestOctet - count -= 1 - } - decoded & Mask + private final val littleEndianDecoder: (ByteIterator, Int) ⇒ Int = (bs, length) ⇒ { + val highestOctet = (length - 1) << 3 + val Mask = ((1L << (length << 3)) - 1).toInt + var count = length + var decoded = 0 + while (count > 0) { + decoded >>>= 8 + decoded += (bs.next().toInt & 0xFF) << highestOctet + count -= 1 } + decoded & Mask } private class DelimiterFramingStage(val separatorBytes: ByteString, val maximumLineBytes: Int, val allowTruncation: Boolean) @@ -141,21 +126,17 @@ object Framing { private val firstSeparatorByte = separatorBytes.head private var buffer = ByteString.empty private var nextPossibleMatch = 0 - private var finishing = false override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = { buffer ++= chunk doParse(ctx) } - override def onPull(ctx: Context[ByteString]): SyncDirective = { - doParse(ctx) - } + override def onPull(ctx: Context[ByteString]): SyncDirective = doParse(ctx) - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = { + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = if (buffer.nonEmpty) ctx.absorbTermination() else ctx.finish() - } private def tryPull(ctx: Context[ByteString]): SyncDirective = { if (ctx.isFinishing) { @@ -184,8 +165,7 @@ object Framing { nextPossibleMatch = possibleMatchPos tryPull(ctx) } else { - if (buffer.slice(possibleMatchPos, possibleMatchPos + separatorBytes.size) - == separatorBytes) { + if (buffer.slice(possibleMatchPos, possibleMatchPos + separatorBytes.size) == separatorBytes) { // Found a match val parsedFrame = buffer.slice(0, possibleMatchPos).compact buffer = buffer.drop(possibleMatchPos + separatorBytes.size) @@ -203,59 +183,51 @@ object Framing { override def postStop(): Unit = buffer = null } - private class LengthFieldFramingStage( + private final class LengthFieldFramingStage( val lengthFieldLength: Int, val lengthFieldOffset: Int, val maximumFrameLength: Int, val byteOrder: ByteOrder) extends PushPullStage[ByteString, ByteString] { private var buffer = ByteString.empty - private val minimumChunkSize = lengthFieldOffset + lengthFieldLength - private val intDecoder: IntDecoder = byteOrder match { - case ByteOrder.BIG_ENDIAN ⇒ new BigEndianCodec(lengthFieldLength) - case ByteOrder.LITTLE_ENDIAN ⇒ new LittleEndianCodec(lengthFieldLength) - } private var frameSize = Int.MaxValue - - private def parseLength: Int = intDecoder.decode(buffer.iterator.drop(lengthFieldOffset)) - - private def tryPull(ctx: Context[ByteString]): SyncDirective = { - if (ctx.isFinishing) ctx.fail(new FramingException( - "Stream finished but there was a truncated final frame in the buffer")) - else ctx.pull() + private val minimumChunkSize = lengthFieldOffset + lengthFieldLength + private val intDecoder = byteOrder match { + case ByteOrder.BIG_ENDIAN ⇒ bigEndianDecoder + case ByteOrder.LITTLE_ENDIAN ⇒ littleEndianDecoder } + private def tryPull(ctx: Context[ByteString]): SyncDirective = + if (ctx.isFinishing) ctx.fail(new FramingException("Stream finished but there was a truncated final frame in the buffer")) + else ctx.pull() + override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = { buffer ++= chunk doParse(ctx) } - override def onPull(ctx: Context[ByteString]): SyncDirective = { - doParse(ctx) - } + override def onPull(ctx: Context[ByteString]): SyncDirective = doParse(ctx) - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = { + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = if (buffer.nonEmpty) ctx.absorbTermination() else ctx.finish() - } - - private def emitFrame(ctx: Context[ByteString]): SyncDirective = { - val parsedFrame = buffer.take(frameSize).compact - buffer = buffer.drop(frameSize) - frameSize = Int.MaxValue - if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame) - else ctx.push(parsedFrame) - } private def doParse(ctx: Context[ByteString]): SyncDirective = { - if (buffer.size >= frameSize) { - emitFrame(ctx) - } else if (buffer.size >= minimumChunkSize) { - frameSize = parseLength + minimumChunkSize + def emitFrame(ctx: Context[ByteString]): SyncDirective = { + val parsedFrame = buffer.take(frameSize).compact + buffer = buffer.drop(frameSize) + frameSize = Int.MaxValue + if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame) + else ctx.push(parsedFrame) + } + + val bufSize = buffer.size + if (bufSize >= frameSize) emitFrame(ctx) + else if (bufSize >= minimumChunkSize) { + val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength) + frameSize = parsedLength + minimumChunkSize if (frameSize > maximumFrameLength) - ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength " + - s"but decoded frame header reported size $frameSize")) - else if (buffer.size >= frameSize) - emitFrame(ctx) + ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize")) + else if (bufSize >= frameSize) emitFrame(ctx) else tryPull(ctx) } else tryPull(ctx) }