From 6318f3e9726ead3c9a3d6579b9e76bfa23d5b7f0 Mon Sep 17 00:00:00 2001
From: Endre Sándor Varga
Date: Tue, 5 May 2015 15:02:11 +0200
Subject: [PATCH] +str #17310: Basic framing support

---
 akka-docs-dev/rst/java/stream-io.rst          |   5 +-
 .../stream/cookbook/RecipeParseLines.scala    |  62 +---
 .../docs/stream/io/StreamTcpDocSpec.scala     |  11 +-
 akka-docs-dev/rst/scala/stream-cookbook.rst   |  18 +-
 akka-docs-dev/rst/scala/stream-io.rst         |   5 +-
 .../scala/akka/stream/io/FramingSpec.scala    | 229 ++++++++++++++
 .../main/scala/akka/stream/io/Framing.scala   | 285 ++++++++++++++++++
 .../scala/akka/stream/javadsl/BidiFlow.scala  |  31 +-
 .../scala/akka/stream/scaladsl/BidiFlow.scala |  28 ++
 9 files changed, 589 insertions(+), 85 deletions(-)
 create mode 100644 akka-stream-tests/src/test/scala/akka/stream/io/FramingSpec.scala
 create mode 100644 akka-stream/src/main/scala/akka/stream/io/Framing.scala

diff --git a/akka-docs-dev/rst/java/stream-io.rst b/akka-docs-dev/rst/java/stream-io.rst
index 9abf926156..6a43e5746e 100644
--- a/akka-docs-dev/rst/java/stream-io.rst
+++ b/akka-docs-dev/rst/java/stream-io.rst
@@ -23,8 +23,9 @@ which will emit an :class:`IncomingConnection` element for each new connection that the server accepts
 Next, we simply handle *each* incoming connection using a :class:`Flow`
 which will be used as the processing stage to handle and emit ByteStrings from and to the TCP Socket. Since one
 :class:`ByteString` does not have to necessarily
-correspond to exactly one line of text (the client might be sending the line in chunks) we use the ``parseLines``
-recipe from the :ref:`cookbook-parse-lines-java` Akka Streams Cookbook recipe to chunk the inputs up into actual lines of text.
+correspond to exactly one line of text (the client might be sending the line in chunks) we use the ``lines``
+helper Flow from ``akka.stream.io.Framing`` to chunk the inputs up into actual lines of text. The last boolean
+argument indicates that we require an explicit line ending even for the last message before the connection is closed.
 In this example we simply add exclamation marks to each incoming text message and push it through the flow:
 
 .. includecode:: ../../../akka-samples/akka-docs-java-lambda/src/test/java/docs/stream/io/StreamTcpDocTest.java#echo-server-simple-handle

diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeParseLines.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeParseLines.scala
index 826d83912f..f469aebca3 100644
--- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeParseLines.scala
+++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeParseLines.scala
@@ -20,9 +20,10 @@ class RecipeParseLines extends RecipeSpec {
       ByteString("\nHello Akka!\r\nHello Streams!"),
       ByteString("\r\n\r\n")))
 
-    import RecipeParseLines._
-
-    val linesStream = rawData.transform(() => parseLines("\r\n", 100))
+    //#parse-lines
+    import akka.stream.io.Framing
+    val linesStream = rawData.via(Framing.lines("\r\n", maximumLineBytes = 100))
+    //#parse-lines
 
     Await.result(linesStream.grouped(10).runWith(Sink.head), 3.seconds) should be(List(
       "Hello World\r!",
@@ -34,58 +35,3 @@ class RecipeParseLines extends RecipeSpec {
   }
 
 }
-
-object RecipeParseLines {
-
-  import akka.stream.stage._
-
-  //#parse-lines
-  def parseLines(separator: String, maximumLineBytes: Int) =
-    new StatefulStage[ByteString, String] {
-      private val separatorBytes = ByteString(separator)
-      private val firstSeparatorByte = separatorBytes.head
-      private var buffer = ByteString.empty
-      private var nextPossibleMatch = 0
-
-      def initial = new State {
-        override def onPush(chunk: ByteString, ctx: Context[String]): SyncDirective = {
-          buffer ++= chunk
-          if (buffer.size > maximumLineBytes)
-            ctx.fail(new IllegalStateException(s"Read ${buffer.size} bytes " +
-              s"which is more than $maximumLineBytes without seeing a line terminator"))
-          else emit(doParse(Vector.empty).iterator, ctx)
-        }
-
-        @tailrec
-        private def doParse(parsedLinesSoFar: Vector[String]): Vector[String] = {
-          val possibleMatchPos = buffer.indexOf(firstSeparatorByte, from = nextPossibleMatch)
-          if (possibleMatchPos == -1) {
-            // No matching character, we need to accumulate more bytes into the buffer
-            nextPossibleMatch = buffer.size
-            parsedLinesSoFar
-          } else if (possibleMatchPos + separatorBytes.size > buffer.size) {
-            // We have found a possible match (we found the first character of the terminator
-            // sequence) but we don't have yet enough bytes. We remember the position to
-            // retry from next time.
-            nextPossibleMatch = possibleMatchPos
-            parsedLinesSoFar
-          } else {
-            if (buffer.slice(possibleMatchPos, possibleMatchPos + separatorBytes.size)
-              == separatorBytes) {
-              // Found a match
-              val parsedLine = buffer.slice(0, possibleMatchPos).utf8String
-              buffer = buffer.drop(possibleMatchPos + separatorBytes.size)
-              nextPossibleMatch -= possibleMatchPos + separatorBytes.size
-              doParse(parsedLinesSoFar :+ parsedLine)
-            } else {
-              nextPossibleMatch += 1
-              doParse(parsedLinesSoFar)
-            }
-          }
-
-        }
-      }
-    }
-  //#parse-lines
-
-}

diff --git a/akka-docs-dev/rst/scala/code/docs/stream/io/StreamTcpDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/io/StreamTcpDocSpec.scala
index c7de46f93c..34e581b8a0 100644
--- a/akka-docs-dev/rst/scala/code/docs/stream/io/StreamTcpDocSpec.scala
+++ b/akka-docs-dev/rst/scala/code/docs/stream/io/StreamTcpDocSpec.scala
@@ -41,11 +41,13 @@ class StreamTcpDocSpec extends AkkaSpec {
     Tcp().bind(localhost.getHostName, localhost.getPort) // TODO getHostString in Java7
 
   //#echo-server-simple-handle
+  import akka.stream.io.Framing
+
   connections runForeach { connection =>
     println(s"New connection from: ${connection.remoteAddress}")
 
     val echo = Flow[ByteString]
-      .transform(() => RecipeParseLines.parseLines("\n", maximumLineBytes = 256))
+      .via(Framing.lines("\n", maximumLineBytes = 256, allowTruncation = false))
      .map(_ + "!!!\n")
      .map(ByteString(_))
 
@@ -60,7 +62,9 @@ class StreamTcpDocSpec extends AkkaSpec {
     val connections = Tcp().bind(localhost.getHostName, localhost.getPort) // TODO getHostString in Java7
     val serverProbe = TestProbe()
 
+    import akka.stream.io.Framing
     //#welcome-banner-chat-server
+
     connections runForeach { connection =>
 
       val serverLogic = Flow() { implicit b =>
@@ -81,7 +85,7 @@ class StreamTcpDocSpec extends AkkaSpec {
         val welcome = Source.single(ByteString(welcomeMsg))
 
         val echo = b.add(Flow[ByteString]
-          .transform(() => RecipeParseLines.parseLines("\n", maximumLineBytes = 256))
+          .via(Framing.lines("\n", maximumLineBytes = 256, allowTruncation = false))
           //#welcome-banner-chat-server
           .map { command ⇒ serverProbe.ref ! command; command }
           //#welcome-banner-chat-server
@@ -101,6 +105,7 @@ class StreamTcpDocSpec extends AkkaSpec {
       connection.handleWith(serverLogic)
     }
 
+    import akka.stream.io.Framing
     //#welcome-banner-chat-server
 
     val input = new AtomicReference("Hello world" :: "What a lovely day" :: Nil)
@@ -131,7 +136,7 @@ class StreamTcpDocSpec extends AkkaSpec {
     }
 
     val repl = Flow[ByteString]
-      .transform(() => RecipeParseLines.parseLines("\n", maximumLineBytes = 256))
+      .via(Framing.lines("\n", maximumLineBytes = 256, allowTruncation = false))
      .map(text => println("Server: " + text))
      .map(_ => readLine("> "))
      .transform(() ⇒ replParser)

diff --git a/akka-docs-dev/rst/scala/stream-cookbook.rst b/akka-docs-dev/rst/scala/stream-cookbook.rst
index e98d49d75e..2e5c63aa4c 100644
--- a/akka-docs-dev/rst/scala/stream-cookbook.rst
+++ b/akka-docs-dev/rst/scala/stream-cookbook.rst
@@ -96,22 +96,8 @@ Parsing lines from a stream of ByteStrings
 characters (or, alternatively, containing binary frames delimited by a special delimiter byte sequence) which needs
 to be parsed.
 
-We express our solution as a :class:`StatefulStage` because it has support for emitting multiple elements easily
-through its ``emit(iterator, ctx)`` helper method. Since an incoming ByteString chunk might contain multiple lines (frames)
-this feature comes in handy.
-
-To create the parser we only need to hook into the ``onPush`` handler. We maintain a buffer of bytes (expressed as
-a :class:`ByteString`) by simply concatenating incoming chunks with it. Since we don't want to allow unbounded size
-lines (records) we always check if the buffer size is larger than the allowed ``maximumLineBytes`` value, and terminate
-the stream if this invariant is violated.
-
-After we updated the buffer, we try to find the terminator sequence as a subsequence of the current buffer. To be
-efficient, we also maintain a pointer ``nextPossibleMatch`` into the buffer so that we only search that part of the
-buffer where new matches are possible.
-
-The search for a match is done in two steps: first we try to search for the first character of the terminator sequence
-in the buffer. If we find a match, we do a full subsequence check to see if we had a false positive or not. The parsing
-logic is recursive to be able to parse multiple lines (records) contained in the decoding buffer.
+The :class:`Framing` helper object contains convenience methods for parsing messages from a stream of ``ByteStrings``;
+in particular, it has basic support for parsing text lines:
 
 .. includecode:: code/docs/stream/cookbook/RecipeParseLines.scala#parse-lines
 
diff --git a/akka-docs-dev/rst/scala/stream-io.rst b/akka-docs-dev/rst/scala/stream-io.rst
index b1e5a05302..4cc604175a 100644
--- a/akka-docs-dev/rst/scala/stream-io.rst
+++ b/akka-docs-dev/rst/scala/stream-io.rst
@@ -23,8 +23,9 @@ which will emit an :class:`IncomingConnection` element for each new connection that the server accepts
 Next, we simply handle *each* incoming connection using a :class:`Flow`
 which will be used as the processing stage to handle and emit ByteStrings from and to the TCP Socket. Since one
 :class:`ByteString` does not have to necessarily
-correspond to exactly one line of text (the client might be sending the line in chunks) we use the ``parseLines``
-recipe from the :ref:`cookbook-parse-lines-scala` Akka Streams Cookbook recipe to chunk the inputs up into actual lines of text.
+correspond to exactly one line of text (the client might be sending the line in chunks) we use the ``Framing.lines``
+helper Flow to chunk the inputs up into actual lines of text. The last boolean
+argument indicates that we require an explicit line ending even for the last message before the connection is closed.
 In this example we simply add exclamation marks to each incoming text message and push it through the flow:
 
 .. includecode:: code/docs/stream/io/StreamTcpDocSpec.scala#echo-server-simple-handle

diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/FramingSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/FramingSpec.scala
new file mode 100644
index 0000000000..df0f2d6051
--- /dev/null
+++ b/akka-stream-tests/src/test/scala/akka/stream/io/FramingSpec.scala
@@ -0,0 +1,229 @@
+/**
+ * Copyright (C) 2014-2015 Typesafe Inc.
+ */
+package akka.stream.io
+
+import java.nio.ByteOrder
+
+import akka.stream.io.Framing.FramingException
+import akka.stream.{ ActorFlowMaterializer, ActorFlowMaterializerSettings }
+import akka.stream.scaladsl._
+import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, PushPullStage }
+import akka.stream.testkit.AkkaSpec
+import akka.util.{ ByteString, ByteStringBuilder }
+
+import scala.collection.immutable
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.concurrent.forkjoin.ThreadLocalRandom
+import scala.util.Random
+
+class FramingSpec extends AkkaSpec {
+
+  val settings = ActorFlowMaterializerSettings(system)
+  implicit val materializer = ActorFlowMaterializer(settings)
+
+  class Rechunker extends PushPullStage[ByteString, ByteString] {
+    private var rechunkBuffer = ByteString.empty
+
+    override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = {
+      rechunkBuffer ++= chunk
+      rechunk(ctx)
+    }
+
+    override def onPull(ctx: Context[ByteString]): SyncDirective = {
+      rechunk(ctx)
+    }
+
+    override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
+      if (rechunkBuffer.isEmpty) ctx.finish()
+      else ctx.absorbTermination()
+    }
+
+    private def rechunk(ctx: Context[ByteString]): SyncDirective = {
+      if (!ctx.isFinishing && ThreadLocalRandom.current().nextBoolean()) ctx.pull()
+      else {
+        val nextChunkSize =
+          if (rechunkBuffer.isEmpty) 0
+          else ThreadLocalRandom.current().nextInt(0, rechunkBuffer.size + 1)
+        val newChunk = rechunkBuffer.take(nextChunkSize)
+        rechunkBuffer = rechunkBuffer.drop(nextChunkSize)
+        if (ctx.isFinishing && rechunkBuffer.isEmpty) ctx.pushAndFinish(newChunk)
+        else ctx.push(newChunk)
+      }
+    }
+  }
+
+  val rechunk = Flow[ByteString].transform(() ⇒ new Rechunker).named("rechunker")
+
+  "Delimiter bytes based framing" must {
+
+    val delimiterBytes = List("\n", "\r\n", "FOO").map(ByteString(_))
+    val baseTestSequences = List("", "foo", "hello world").map(ByteString(_))
+
+    def completeTestSequences(delimiter: ByteString): immutable.Iterable[ByteString] =
+      for (prefix ← 0 until delimiter.size; s ← baseTestSequences)
+        yield delimiter.take(prefix) ++ s
+
+    "work with various delimiters and test sequences" in {
+      for (delimiter ← delimiterBytes; _ ← 1 to 100) {
+        val f = Source(completeTestSequences(delimiter))
+          .map(_ ++ delimiter)
+          .via(rechunk)
+          .via(Framing.delimiter(delimiter, 256))
+          .grouped(1000)
+          .runWith(Sink.head)
+
+        Await.result(f, 3.seconds) should be(completeTestSequences(delimiter))
+      }
+    }
+
+    "respect maximum line length settings" in {
+      // The buffer will contain more than one byte, but the individual frames stay within the limit
+      Await.result(
+        Source.single(ByteString("a\nb\nc\nd\n")).via(Framing.lines("\n", 1)).grouped(100).runWith(Sink.head),
+        3.seconds) should ===(List("a", "b", "c", "d"))
+
+      an[FramingException] should be thrownBy {
+        Await.result(
+          Source.single(ByteString("ab\n")).via(Framing.lines("\n", 1)).grouped(100).runWith(Sink.head),
+          3.seconds)
+      }
+    }
+
+    "work with empty streams" in {
+      Await.result(
+        Source.empty.via(Framing.lines("\n", 256)).runFold(Vector.empty[String])(_ :+ _),
+        3.seconds) should ===(Vector.empty)
+    }
+
+    "report truncated frames" in {
+      an[FramingException] should be thrownBy {
+        Await.result(
+          Source.single(ByteString("I have no end"))
+            .via(Framing.lines("\n", 256, allowTruncation = false))
+            .grouped(1000)
+            .runWith(Sink.head),
+          3.seconds)
+      }
+    }
+
+    "allow truncated frames if so configured" in {
+      Await.result(
Source.single(ByteString("I have no end")) + .via(Framing.lines("\n", 256, allowTruncation = true)) + .grouped(1000) + .runWith(Sink.head), + 3.seconds) should ===(List("I have no end")) + } + + } + + "Length field based framing" must { + + val referenceChunk = ByteString(scala.util.Random.nextString(0x100001)) + + val byteOrders = List(ByteOrder.BIG_ENDIAN, ByteOrder.LITTLE_ENDIAN) + val frameLengths = List(0, 1, 2, 3, 0xFF, 0x100, 0x101, 0xFFF, 0x1000, 0x1001, 0xFFFF, 0x100001) + val fieldLengths = List(1, 2, 3, 4) + val fieldOffsets = List(0, 1, 2, 3, 15, 16, 31, 32, 44, 107) + + def encode(payload: ByteString, fieldOffset: Int, fieldLength: Int, byteOrder: ByteOrder): ByteString = { + val header = { + val h = (new ByteStringBuilder).putInt(payload.size)(byteOrder).result() + byteOrder match { + case ByteOrder.LITTLE_ENDIAN ⇒ h.take(fieldLength) + case ByteOrder.BIG_ENDIAN ⇒ h.drop(4 - fieldLength) + } + } + + ByteString(Array.ofDim[Byte](fieldOffset)) ++ header ++ payload + } + + "work with various byte orders, frame lengths and offsets" in { + for { + _ ← 1 to 10 + byteOrder ← byteOrders + fieldOffset ← fieldOffsets + fieldLength ← fieldLengths + } { + + val encodedFrames = frameLengths.filter(_ < (1 << (fieldLength * 8))).map { length ⇒ + val payload = referenceChunk.take(length) + encode(payload, fieldOffset, fieldLength, byteOrder) + } + + Await.result( + Source(encodedFrames) + .via(rechunk) + .via(Framing.lengthField(fieldLength, fieldOffset, Int.MaxValue, byteOrder)) + .grouped(10000) + .runWith(Sink.head), + 3.seconds) should ===(encodedFrames) + } + + } + + "work with empty streams" in { + Await.result( + Source.empty.via(Framing.lengthField(4, 0, Int.MaxValue, ByteOrder.BIG_ENDIAN)).runFold(Vector.empty[ByteString])(_ :+ _), + 3.seconds) should ===(Vector.empty) + } + + "report oversized frames" in { + an[FramingException] should be thrownBy { + Await.result( + Source.single(encode(referenceChunk.take(100), 0, 1, ByteOrder.BIG_ENDIAN)) + .via(Framing.lengthField(1, 0, 99, ByteOrder.BIG_ENDIAN)).runFold(Vector.empty[ByteString])(_ :+ _), + 3.seconds) + } + + an[FramingException] should be thrownBy { + Await.result( + Source.single(encode(referenceChunk.take(100), 49, 1, ByteOrder.BIG_ENDIAN)) + .via(Framing.lengthField(1, 0, 100, ByteOrder.BIG_ENDIAN)).runFold(Vector.empty[ByteString])(_ :+ _), + 3.seconds) + } + } + + "report truncated frames" in { + for { + //_ ← 1 to 10 + byteOrder ← byteOrders + fieldOffset ← fieldOffsets + fieldLength ← fieldLengths + frameLength ← frameLengths if frameLength < (1 << (fieldLength * 8)) && (frameLength != 0) + } { + + val fullFrame = encode(referenceChunk.take(frameLength), fieldOffset, fieldLength, byteOrder) + val partialFrame = fullFrame.dropRight(1) + + an[FramingException] should be thrownBy { + Await.result( + Source(List(fullFrame, partialFrame)) + .via(rechunk) + .via(Framing.lengthField(fieldLength, fieldOffset, Int.MaxValue, byteOrder)) + .grouped(10000) + .runWith(Sink.head), + 3.seconds) + } + } + } + + "support simple framing adapter" in { + val rechunkBidi = BidiFlow.wrap(rechunk, rechunk)(Keep.left) + val codecFlow = + Framing.simpleFramingProtocol(1024) + .atop(rechunkBidi) + .atop(Framing.simpleFramingProtocol(1024).reversed) + .join(Flow[ByteString]) // Loopback + + val testMessages = List.fill(100)(referenceChunk.take(Random.nextInt(1024))) + Await.result( + Source(testMessages).via(codecFlow).grouped(1000).runWith(Sink.head), + 3.seconds) should ===(testMessages) + } + + } + +} diff --git 
diff --git a/akka-stream/src/main/scala/akka/stream/io/Framing.scala b/akka-stream/src/main/scala/akka/stream/io/Framing.scala
new file mode 100644
index 0000000000..4331bc13fb
--- /dev/null
+++ b/akka-stream/src/main/scala/akka/stream/io/Framing.scala
@@ -0,0 +1,285 @@
+/**
+ * Copyright (C) 2014-2015 Typesafe Inc.
+ */
+package akka.stream.io
+
+import java.nio.ByteOrder
+
+import akka.stream.scaladsl.{ Keep, BidiFlow, Flow }
+import akka.stream.stage._
+import akka.util.{ ByteIterator, ByteStringBuilder, ByteString }
+
+import scala.annotation.tailrec
+
+object Framing {
+
+  /**
+   * Creates a Flow that decodes a stream of unstructured byte chunks into a stream of frames where the
+   * incoming chunk stream uses a specific byte-sequence to mark frame boundaries.
+   *
+   * The decoded frames will not include the separator sequence.
+   *
+   * If there are buffered bytes (an incomplete frame) when the input stream finishes and ''allowTruncation'' is set to
+   * false then this Flow will fail the stream reporting a truncated frame.
+   *
+   * @param delimiter The byte sequence to be treated as the end of the frame.
+   * @param maximumFrameLength The maximum length of allowed frames while decoding. If the maximum length is
+   *                           exceeded this Flow will fail the stream.
+   * @param allowTruncation If turned on, then when the last frame being decoded contains no valid delimiter the
+   *                        remaining buffered bytes are emitted as a truncated frame instead of failing the stream.
+   * @return a Flow that emits the decoded frames (not including the delimiter)
+   */
+  def delimiter(delimiter: ByteString, maximumFrameLength: Int, allowTruncation: Boolean = false): Flow[ByteString, ByteString, Unit] =
+    Flow[ByteString].transform(() ⇒ new DelimiterFramingStage(delimiter, maximumFrameLength, allowTruncation))
+      .named("delimiterFraming")
+
+  /**
+   * A convenience wrapper on top of [[Framing#delimiter]] using ''String'' as the output and separator sequence types.
+   * Returns a Flow that decodes an unstructured input stream of byte chunks, decoding them to Strings using a separator
+   * String as end-of-line marker.
+   *
+   * This decoder stage treats decoded frames as simple byte sequences, decoding them as UTF-8 only after the frame
+   * boundary has been found. This means that this is not a fully UTF-8 compliant line parser.
+   *
+   * @param delimiter String to be used as a delimiter. Be aware that not all UTF-8 strings are safe to use as a
+   *                  delimiter when the input bytes are UTF-8 encoded.
+   * @param maximumLineBytes The maximum allowed length for decoded strings in bytes (not in characters).
+   * @param allowTruncation If turned on (the default), then when the last string being decoded contains no valid
+   *                        delimiter it is emitted as a truncated string instead of failing the stream.
+   * @return a Flow that emits the decoded lines as Strings
+   */
+  def lines(delimiter: String, maximumLineBytes: Int, allowTruncation: Boolean = true): Flow[ByteString, String, Unit] =
+    Framing.delimiter(ByteString(delimiter), maximumLineBytes, allowTruncation).map(_.utf8String)
+      .named("lineFraming")
+
+  /**
+   * Creates a Flow that decodes an incoming stream of unstructured byte chunks into a stream of frames, assuming that
+   * incoming frames have a field that encodes their length.
+   *
+   * If the input stream finishes before the last frame has been fully decoded this Flow will fail the stream reporting
+   * a truncated frame.
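+   *
+   * As an illustrative sketch (not part of this patch): a 5-byte payload prefixed by a 2-byte big endian
+   * length field at offset 0 decodes as a single 7-byte frame, since the emitted frames include the header bytes:
+   * {{{
+   * Source.single(ByteString(0x00, 0x05, 'H', 'e', 'l', 'l', 'o'))
+   *   .via(Framing.lengthField(2, 0, Int.MaxValue, ByteOrder.BIG_ENDIAN)) // emits the full 7-byte frame
+   * }}}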
+   *
+   * @param fieldLength The length of the "size" field in bytes
+   * @param fieldOffset The offset of the field from the beginning of the frame in bytes
+   * @param maximumFrameLength The maximum length of allowed frames while decoding. If the maximum length is exceeded
+   *                           this Flow will fail the stream. This length *includes* the header (i.e. the offset and
+   *                           the length of the size field)
+   * @param byteOrder The ''ByteOrder'' to be used when decoding the field
+   * @return a Flow that emits the decoded frames, each including its header and offset bytes
+   */
+  def lengthField(fieldLength: Int,
+                  fieldOffset: Int = 0,
+                  maximumFrameLength: Int,
+                  byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN): Flow[ByteString, ByteString, Unit] = {
+    require(fieldLength >= 1 && fieldLength <= 4, "Length field length must be 1, 2, 3 or 4.")
+    Flow[ByteString].transform(() ⇒ new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder))
+      .named("lengthFieldFraming")
+  }
+
+  /**
+   * Returns a BidiFlow that implements a simple framing protocol. This is a convenience wrapper over [[Framing#lengthField]]
+   * and simply attaches a length field header of four bytes (using big endian encoding) to outgoing messages, and decodes
+   * such messages in the inbound direction. The decoded messages do not contain the header.
+   *
+   * This BidiFlow is useful if a simple message framing protocol is needed (for example when TCP is used to send
+   * individual messages) but no compatibility with existing protocols is necessary.
+   *
+   * The encoded frames have the layout
+   * {{{
+   *   [4 bytes length field, Big Endian][User Payload]
+   * }}}
+   * The length field encodes the length of the user payload excluding the header itself.
+   *
+   * @param maximumMessageLength Maximum length of allowed messages. If sent or received messages exceed the configured
+   *                             limit this BidiFlow will fail the stream. The header attached by this BidiFlow is not
+   *                             included in this limit.
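+   *
+   * A minimal usage sketch (``userLogic`` is an assumed application ''Flow'' of ByteStrings that operates on
+   * whole, de-framed messages; joining it on the reversed protocol yields the byte-level flow to connect to
+   * the transport):
+   * {{{
+   * // bytes from the wire ~> decode ~> userLogic ~> encode ~> bytes to the wire
+   * val wireFlow = Framing.simpleFramingProtocol(1024).reversed.join(userLogic)
+   * }}}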
+   * @return a BidiFlow that attaches the length header on the outgoing side and removes it on the incoming side
+   */
+  def simpleFramingProtocol(maximumMessageLength: Int): BidiFlow[ByteString, ByteString, ByteString, ByteString, Unit] = {
+    val decoder = lengthField(4, 0, maximumMessageLength + 4, ByteOrder.BIG_ENDIAN).map(_.drop(4))
+    val encoder = Flow[ByteString].transform(() ⇒ new PushStage[ByteString, ByteString] {
+
+      override def onPush(message: ByteString, ctx: Context[ByteString]): SyncDirective = {
+        if (message.size > maximumMessageLength)
+          ctx.fail(new FramingException(s"Maximum allowed message size is $maximumMessageLength " +
+            s"but tried to send ${message.size} bytes"))
+        else {
+          val header = ByteString(
+            (message.size >> 24) & 0xFF,
+            (message.size >> 16) & 0xFF,
+            (message.size >> 8) & 0xFF,
+            message.size & 0xFF)
+          ctx.push(header ++ message)
+        }
+      }
+
+    })
+
+    BidiFlow.wrap(encoder, decoder)(Keep.left)
+  }
+
+  private trait IntDecoder {
+    def decode(bs: ByteIterator): Int
+  }
+
+  class FramingException(msg: String) extends RuntimeException(msg)
+
+  private class BigEndianCodec(val length: Int) extends IntDecoder {
+    override def decode(bs: ByteIterator): Int = {
+      var count = length
+      var decoded = 0
+      while (count > 0) {
+        decoded <<= 8
+        decoded |= bs.next().toInt & 0xFF
+        count -= 1
+      }
+      decoded
+    }
+  }
+
+  private class LittleEndianCodec(val length: Int) extends IntDecoder {
+    private val highestOctet = (length - 1) * 8
+    private val Mask = (1 << (length * 8)) - 1
+
+    override def decode(bs: ByteIterator): Int = {
+      var count = length
+      var decoded = 0
+      while (count > 0) {
+        decoded >>>= 8
+        decoded += (bs.next().toInt & 0xFF) << highestOctet
+        count -= 1
+      }
+      decoded & Mask
+    }
+  }
+
+  private class DelimiterFramingStage(val separatorBytes: ByteString, val maximumLineBytes: Int, val allowTruncation: Boolean)
+    extends PushPullStage[ByteString, ByteString] {
+    private val firstSeparatorByte = separatorBytes.head
+    private var buffer = ByteString.empty
+    private var nextPossibleMatch = 0
+
+    override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = {
+      buffer ++= chunk
+      doParse(ctx)
+    }
+
+    override def onPull(ctx: Context[ByteString]): SyncDirective = {
+      doParse(ctx)
+    }
+
+    override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
+      if (buffer.nonEmpty) ctx.absorbTermination()
+      else ctx.finish()
+    }
+
+    private def tryPull(ctx: Context[ByteString]): SyncDirective = {
+      if (ctx.isFinishing) {
+        if (allowTruncation) ctx.pushAndFinish(buffer)
+        else
+          ctx.fail(new FramingException(
+            "Stream finished but there was a truncated final frame in the buffer"))
+      } else ctx.pull()
+    }
+
+    @tailrec
+    private def doParse(ctx: Context[ByteString]): SyncDirective = {
+      val possibleMatchPos = buffer.indexOf(firstSeparatorByte, from = nextPossibleMatch)
+      if (possibleMatchPos > maximumLineBytes)
+        ctx.fail(new FramingException(s"Read ${buffer.size} bytes " +
+          s"which is more than $maximumLineBytes without seeing a line terminator"))
+      else {
+        if (possibleMatchPos == -1) {
+          // No matching character. Fail if the buffer has already exceeded the allowed
+          // maximum (otherwise input without any delimiter could buffer unboundedly),
+          // else accumulate more bytes.
+          if (buffer.size > maximumLineBytes)
+            ctx.fail(new FramingException(s"Read ${buffer.size} bytes " +
+              s"which is more than $maximumLineBytes without seeing a line terminator"))
+          else {
+            nextPossibleMatch = buffer.size
+            tryPull(ctx)
+          }
+        } else if (possibleMatchPos + separatorBytes.size > buffer.size) {
+          // We have found a possible match (we found the first character of the terminator
+          // sequence) but we don't yet have enough bytes. We remember the position to
+          // retry from next time.
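+          // (Unlike the no-match branch above, we must not skip past possibleMatchPos:
+          // the remaining delimiter bytes may still arrive with the next chunk.)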
+          nextPossibleMatch = possibleMatchPos
+          tryPull(ctx)
+        } else {
+          if (buffer.slice(possibleMatchPos, possibleMatchPos + separatorBytes.size)
+            == separatorBytes) {
+            // Found a match
+            val parsedFrame = buffer.slice(0, possibleMatchPos).compact
+            buffer = buffer.drop(possibleMatchPos + separatorBytes.size)
+            nextPossibleMatch = 0
+            if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame)
+            else ctx.push(parsedFrame)
+          } else {
+            nextPossibleMatch += 1
+            doParse(ctx)
+          }
+        }
+      }
+    }
+
+    override def postStop(): Unit = buffer = null
+  }
+
+  private class LengthFieldFramingStage(
+    val lengthFieldLength: Int,
+    val lengthFieldOffset: Int,
+    val maximumFrameLength: Int,
+    val byteOrder: ByteOrder) extends PushPullStage[ByteString, ByteString] {
+    private var buffer = ByteString.empty
+    private val minimumChunkSize = lengthFieldOffset + lengthFieldLength
+    private val intDecoder: IntDecoder = byteOrder match {
+      case ByteOrder.BIG_ENDIAN    ⇒ new BigEndianCodec(lengthFieldLength)
+      case ByteOrder.LITTLE_ENDIAN ⇒ new LittleEndianCodec(lengthFieldLength)
+    }
+    private var frameSize = Int.MaxValue
+
+    private def parseLength: Int = intDecoder.decode(buffer.iterator.drop(lengthFieldOffset))
+
+    private def tryPull(ctx: Context[ByteString]): SyncDirective = {
+      if (ctx.isFinishing) ctx.fail(new FramingException(
+        "Stream finished but there was a truncated final frame in the buffer"))
+      else ctx.pull()
+    }
+
+    override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = {
+      buffer ++= chunk
+      doParse(ctx)
+    }
+
+    override def onPull(ctx: Context[ByteString]): SyncDirective = {
+      doParse(ctx)
+    }
+
+    override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
+      if (buffer.nonEmpty) ctx.absorbTermination()
+      else ctx.finish()
+    }
+
+    private def emitFrame(ctx: Context[ByteString]): SyncDirective = {
+      val parsedFrame = buffer.take(frameSize).compact
+      buffer = buffer.drop(frameSize)
+      frameSize = Int.MaxValue
+      if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame)
+      else ctx.push(parsedFrame)
+    }
+
+    private def doParse(ctx: Context[ByteString]): SyncDirective = {
+      if (buffer.size >= frameSize) {
+        emitFrame(ctx)
+      } else if (buffer.size >= minimumChunkSize) {
+        frameSize = parseLength + minimumChunkSize
+        if (frameSize > maximumFrameLength)
+          ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength " +
+            s"but decoded frame header reported size $frameSize"))
+        else if (buffer.size >= frameSize)
+          emitFrame(ctx)
+        else tryPull(ctx)
+      } else tryPull(ctx)
+    }
+
+    override def postStop(): Unit = buffer = null
+  }
+
+}

diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/BidiFlow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/BidiFlow.scala
index 162aa2ba0b..802da41af6 100644
--- a/akka-stream/src/main/scala/akka/stream/javadsl/BidiFlow.scala
+++ b/akka-stream/src/main/scala/akka/stream/javadsl/BidiFlow.scala
@@ -4,10 +4,7 @@ package akka.stream.javadsl
 
 import akka.japi.function
-import akka.stream.scaladsl
-import akka.stream.Graph
-import akka.stream.BidiShape
-import akka.stream.OperationAttributes
+import akka.stream._
 
 object BidiFlow {
 
@@ -19,6 +16,32 @@ object BidiFlow {
   */
  def wrap[I1, O1, I2, O2, M](g: Graph[BidiShape[I1, O1, I2, O2], M]): BidiFlow[I1, O1, I2, O2, M] =
    new BidiFlow(scaladsl.BidiFlow.wrap(g))
 
+  /**
+   * Wraps two Flows to create a ''BidiFlow''. The materialized value of the resulting BidiFlow is determined
+   * by the ``combine`` function passed as the third argument.
+   *
+   * {{{
+   *        +----------------------------+
+   *        | Resulting BidiFlow         |
+   *        |                            |
+   *        |  +----------------------+  |
+   *  I1 ~~>|  |        Flow1         |  |~~> O1
+   *        |  +----------------------+  |
+   *        |                            |
+   *        |  +----------------------+  |
+   *  O2 <~~|  |        Flow2         |  |<~~ I2
+   *        |  +----------------------+  |
+   *        +----------------------------+
+   * }}}
+   *
+   */
+  def wrap[I1, O1, I2, O2, M1, M2, M](
+    flow1: Graph[FlowShape[I1, O1], M1],
+    flow2: Graph[FlowShape[I2, O2], M2],
+    combine: function.Function2[M1, M2, M]): BidiFlow[I1, O1, I2, O2, M] = {
+    new BidiFlow(scaladsl.BidiFlow.wrap(flow1, flow2)(combinerToScala(combine)))
+  }
+
   /**
    * Create a BidiFlow where the top and bottom flows are just one simple mapping
    * stage each, expressed by the two functions.

diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala
index 66d4e3f8da..0fb6dcde3b 100644
--- a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala
+++ b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala
@@ -132,6 +132,34 @@ object BidiFlow extends BidiFlowApply {
   */
  def wrap[I1, O1, I2, O2, Mat](graph: Graph[BidiShape[I1, O1, I2, O2], Mat]): BidiFlow[I1, O1, I2, O2, Mat] =
    new BidiFlow(graph.module)
 
+  /**
+   * Wraps two Flows to create a ''BidiFlow''. The materialized value of the resulting BidiFlow is determined
+   * by the combiner function passed in the second argument list.
+   *
+   * {{{
+   *        +----------------------------+
+   *        | Resulting BidiFlow         |
+   *        |                            |
+   *        |  +----------------------+  |
+   *  I1 ~~>|  |        Flow1         |  |~~> O1
+   *        |  +----------------------+  |
+   *        |                            |
+   *        |  +----------------------+  |
+   *  O2 <~~|  |        Flow2         |  |<~~ I2
+   *        |  +----------------------+  |
+   *        +----------------------------+
+   * }}}
+   *
+   */
+  def wrap[I1, O1, I2, O2, M1, M2, M](
+    flow1: Graph[FlowShape[I1, O1], M1],
+    flow2: Graph[FlowShape[I2, O2], M2])(combine: (M1, M2) ⇒ M): BidiFlow[I1, O1, I2, O2, M] = {
+    BidiFlow(flow1, flow2)(combine) { implicit b ⇒
+      (f1, f2) ⇒
+        BidiShape(f1.inlet, f1.outlet, f2.inlet, f2.outlet)
+    }
+  }
+
   /**
    * Create a BidiFlow where the top and bottom flows are just one simple mapping
    * stage each, expressed by the two functions.
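
The pieces added by this patch compose as follows. This is a hedged usage sketch, not part of the patch itself: the object name, actor system setup and sample data are illustrative assumptions, and stream/system shutdown is omitted. It uses only APIs exercised by the patch (Framing.lines, Framing.lengthField, simpleFramingProtocol with join, as in FramingSpec).

    import java.nio.ByteOrder

    import akka.actor.ActorSystem
    import akka.stream.ActorFlowMaterializer
    import akka.stream.io.Framing
    import akka.stream.scaladsl.{ Flow, Source }
    import akka.util.ByteString

    object FramingUsageSketch extends App {
      implicit val system = ActorSystem("framing-sketch")
      implicit val materializer = ActorFlowMaterializer()

      // Delimiter based framing: chunk boundaries are irrelevant, only the "\n" markers count.
      Source(List(ByteString("first li"), ByteString("ne\nsecond line\n")))
        .via(Framing.lines("\n", maximumLineBytes = 128))
        .runForeach(line => println(s"line: $line"))

      // Length-field based framing: a 2-byte big endian length header in front of each payload.
      val payload = ByteString("Hello")
      val frame = ByteString((payload.size >> 8) & 0xFF, payload.size & 0xFF) ++ payload
      Source.single(frame)
        .via(Framing.lengthField(2, 0, 1024, ByteOrder.BIG_ENDIAN))
        .runForeach(f => println(s"decoded a ${f.size}-byte frame (header included)"))

      // The simple framing protocol joined with a loopback Flow, as in FramingSpec:
      // outgoing messages are length-prefixed, echoed back and de-framed again.
      val loopback = Framing.simpleFramingProtocol(1024).join(Flow[ByteString])
      Source.single(ByteString("ping"))
        .via(loopback)
        .runForeach(msg => println(s"echoed: ${msg.utf8String}"))
    }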