From 53efa43ab552f9ae69fc3dcdded99bcebb99c710 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Endre=20S=C3=A1ndor=20Varga?= Date: Wed, 24 Apr 2013 14:49:03 +0200 Subject: [PATCH] Added delimiter-byte codec that fixes SslTlsSpec #3256 --- .../scala/akka/io/DelimiterFramingSpec.scala | 130 ++++++++++++++++++ .../src/main/scala/akka/io/Pipelines.scala | 84 ++++++++++- .../scala/akka/io/TcpPipelineHandler.scala | 34 +++++ .../scala/akka/io/ssl/SslTlsSupportSpec.scala | 61 ++++---- 4 files changed, 283 insertions(+), 26 deletions(-) create mode 100644 akka-actor-tests/src/test/scala/akka/io/DelimiterFramingSpec.scala diff --git a/akka-actor-tests/src/test/scala/akka/io/DelimiterFramingSpec.scala b/akka-actor-tests/src/test/scala/akka/io/DelimiterFramingSpec.scala new file mode 100644 index 0000000000..b890b0ee84 --- /dev/null +++ b/akka-actor-tests/src/test/scala/akka/io/DelimiterFramingSpec.scala @@ -0,0 +1,130 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ +package akka.io + +import akka.testkit.{ TestProbe, AkkaSpec } +import java.net.InetSocketAddress +import akka.util.ByteString +import akka.actor.{ Props, ActorLogging, Actor, ActorContext } +import akka.TestUtils +import java.util.concurrent.atomic.AtomicInteger + +class DelimiterFramingSpec extends AkkaSpec { + + val addresses = TestUtils.temporaryServerAddresses(4) + + "DelimiterFramingSpec" must { + + "send and receive delimiter based frames correctly (one byte delimiter, exclude)" in { + testSetup(serverAddress = addresses(0), delimiter = "\n", includeDelimiter = false) + } + + "send and receive delimiter based frames correctly (multi-byte delimiter, exclude)" in { + testSetup(serverAddress = addresses(1), delimiter = "DELIMITER", includeDelimiter = false) + } + + "send and receive delimiter based frames correctly (one byte delimiter, include)" in { + testSetup(serverAddress = addresses(2), delimiter = "\n", includeDelimiter = true) + } + + "send and receive delimiter based frames correctly (multi-byte delimiter, include)" in { + testSetup(serverAddress = addresses(3), delimiter = "DELIMITER", includeDelimiter = true) + } + + } + + val counter = new AtomicInteger + + def testSetup(serverAddress: InetSocketAddress, delimiter: String, includeDelimiter: Boolean): Unit = { + val bindHandler = system.actorOf(Props(classOf[AkkaLineEchoServer], this, delimiter, includeDelimiter)) + val probe = TestProbe() + probe.send(IO(Tcp), Tcp.Bind(bindHandler, serverAddress)) + probe.expectMsgType[Tcp.Bound] + + val client = new AkkaLineClient(serverAddress, delimiter, includeDelimiter) + client.run() + client.close() + } + + class AkkaLineClient(address: InetSocketAddress, delimiter: String, includeDelimiter: Boolean) { + + val expectedDelimiter = if (includeDelimiter) delimiter else "" + + val probe = TestProbe() + probe.send(IO(Tcp), Tcp.Connect(address)) + + val connected = probe.expectMsgType[Tcp.Connected] + val connection = probe.sender + + val init = new TcpPipelineHandler.Init( + new StringByteStringAdapter >> + new DelimiterFraming(maxSize = 1024, delimiter = ByteString(delimiter), includeDelimiter = includeDelimiter) >> + new TcpReadWriteAdapter) { + override def makeContext(actorContext: ActorContext): HasLogging = new HasLogging { + override def getLogger = system.log + } + } + + import init._ + + val handler = system.actorOf(TcpPipelineHandler(init, connection, probe.ref), + "client" + counter.incrementAndGet()) + probe.send(connection, Tcp.Register(handler)) + + def run() { + probe.send(handler, Command(s"testone$delimiter")) + probe.expectMsg(Event(s"testone$expectedDelimiter")) + probe.send(handler, Command(s"two${delimiter}thr")) + probe.expectMsg(Event(s"two$expectedDelimiter")) + Thread.sleep(1000) + probe.send(handler, Command(s"ee$delimiter")) + probe.expectMsg(Event(s"three$expectedDelimiter")) + probe.send(handler, Command(s"${delimiter}${delimiter}")) + probe.expectMsg(Event(expectedDelimiter)) + probe.expectMsg(Event(expectedDelimiter)) + } + + def close() { + probe.send(handler, Tcp.Close) + probe.expectMsgType[Tcp.Event] match { + case _: Tcp.ConnectionClosed ⇒ true + } + TestUtils.verifyActorTermination(handler) + } + + } + + class AkkaLineEchoServer(delimiter: String, includeDelimiter: Boolean) extends Actor with ActorLogging { + + import Tcp.Connected + + def receive: Receive = { + case Connected(remote, _) ⇒ + val init = + new TcpPipelineHandler.Init( + new StringByteStringAdapter >> + new DelimiterFraming(maxSize = 1024, delimiter = ByteString(delimiter), includeDelimiter = includeDelimiter) >> + new TcpReadWriteAdapter) { + override def makeContext(actorContext: ActorContext): HasLogging = + new HasLogging { + override def getLogger = log + } + } + import init._ + + val connection = sender + val handler = system.actorOf( + TcpPipelineHandler(init, sender, self), "server" + counter.incrementAndGet()) + + connection ! Tcp.Register(handler) + + context become { + case Event(data) ⇒ + if (includeDelimiter) sender ! Command(data) + else sender ! Command(data + delimiter) + } + } + } + +} diff --git a/akka-actor/src/main/scala/akka/io/Pipelines.scala b/akka-actor/src/main/scala/akka/io/Pipelines.scala index 3219b6bd78..a473ddfb84 100644 --- a/akka-actor/src/main/scala/akka/io/Pipelines.scala +++ b/akka-actor/src/main/scala/akka/io/Pipelines.scala @@ -812,13 +812,95 @@ class LengthFieldFrame(maxSize: Int, frames match { case Nil ⇒ Nil case one :: Nil ⇒ ctx.singleEvent(one) - case many ⇒ many.reverse map (Left(_)) + case many ⇒ many reverseMap (Left(_)) } } } } //#length-field-frame +/** + * Pipeline stage for delimiter byte based framing and de-framing. Useful for string oriented protocol using '\n' + * or 0 as delimiter values. + * + * @param maxSize The maximum size of the frame the pipeline is willing to decode. Not checked for encoding, as the + * sender might decide to pass through multiple chunks in one go (multiple lines in case of a line-based + * protocol) + * @param delimiter The sequence of bytes that will be used as the delimiter for decoding. + * @param includeDelimiter If enabled, the delmiter bytes will be part of the decoded messages. In the case of sends + * the delimiter has to be appended to the end of frames by the user. It is also possible + * to send multiple frames by embedding multiple delimiters in the passed ByteString + */ +class DelimiterFraming(maxSize: Int, delimiter: ByteString = ByteString('\n'), includeDelimiter: Boolean = false) + extends SymmetricPipelineStage[PipelineContext, ByteString, ByteString] { + + require(maxSize > 0, "maxSize must be positive") + require(delimiter.nonEmpty, "delimiter must not be empty") + + override def apply(ctx: PipelineContext) = new SymmetricPipePair[ByteString, ByteString] { + val singleByteDelimiter: Boolean = delimiter.size == 1 + var buffer: ByteString = ByteString.empty + + @tailrec + private def extractParts(nextChunk: ByteString, acc: List[ByteString]): List[ByteString] = { + val firstByteOfDelimiter = delimiter.head + val matchPosition = nextChunk.indexOf(firstByteOfDelimiter) + if (matchPosition == -1) { + val minSize = buffer.size + nextChunk.size + if (minSize > maxSize) throw new IllegalArgumentException( + s"Received too large frame of size $minSize (max = $maxSize)") + buffer ++= nextChunk + acc + } else { + val missingBytes: Int = if (includeDelimiter) matchPosition + delimiter.size else matchPosition + val expectedSize = buffer.size + missingBytes + if (expectedSize > maxSize) throw new IllegalArgumentException( + s"Received frame already of size $expectedSize (max = $maxSize)") + + if (singleByteDelimiter || nextChunk.slice(matchPosition, matchPosition + delimiter.size) == delimiter) { + val decoded = buffer ++ nextChunk.take(missingBytes) + buffer = ByteString.empty + extractParts(nextChunk.drop(matchPosition + delimiter.size), decoded :: acc) + } else { + buffer ++= nextChunk.take(matchPosition + 1) + extractParts(nextChunk.drop(matchPosition + 1), acc) + } + + } + } + + override val eventPipeline = { + bs: ByteString ⇒ + val parts = extractParts(bs, Nil) + parts match { + case Nil ⇒ Nil + case one :: Nil ⇒ ctx.singleEvent(one) + case many ⇒ many reverseMap (Left(_)) + } + } + + override val commandPipeline = { + bs: ByteString ⇒ ctx.singleCommand(bs) + } + } +} + +/** + * Simple convenience pipeline stage for turning Strings into ByteStrings and vice versa. + * + * @param charset The character set to be used for encoding and decoding the raw byte representation of the strings. + */ +class StringByteStringAdapter(charset: String = "utf-8") + extends PipelineStage[PipelineContext, String, ByteString, String, ByteString] { + + override def apply(ctx: PipelineContext) = new PipePair[String, ByteString, String, ByteString] { + + val commandPipeline = (str: String) ⇒ ctx.singleCommand(ByteString(str, charset)) + + val eventPipeline = (bs: ByteString) ⇒ ctx.singleEvent(bs.decodeString(charset)) + } +} + /** * This trait expresses that the pipeline’s context needs to provide a logging * facility. diff --git a/akka-actor/src/main/scala/akka/io/TcpPipelineHandler.scala b/akka-actor/src/main/scala/akka/io/TcpPipelineHandler.scala index 69505ab930..9c881bb1eb 100644 --- a/akka-actor/src/main/scala/akka/io/TcpPipelineHandler.scala +++ b/akka-actor/src/main/scala/akka/io/TcpPipelineHandler.scala @@ -12,9 +12,12 @@ import scala.util.Success import scala.util.Failure import akka.actor.Terminated import akka.actor.Props +import akka.util.ByteString object TcpPipelineHandler { + case class EscapeEvent(ev: Tcp.Event) extends Tcp.Command + /** * This class wraps up a pipeline with its external (i.e. “top”) command and * event types and providing unique wrappers for sending commands and @@ -102,6 +105,7 @@ class TcpPipelineHandler[Ctx <: PipelineContext, Cmd, Evt]( case Tcp.Write(data, Tcp.NoAck) ⇒ connection ! cmd case Tcp.Write(data, ack) ⇒ connection ! Tcp.Write(data, Ack(ack)) case Tell(receiver, msg, sender) ⇒ receiver.tell(msg, sender) + case EscapeEvent(ev) ⇒ handler ! ev case _ ⇒ connection ! cmd } case Failure(ex) ⇒ throw ex @@ -114,6 +118,36 @@ class TcpPipelineHandler[Ctx <: PipelineContext, Cmd, Evt]( case Command(cmd) ⇒ pipes.injectCommand(cmd) case evt: Tcp.Event ⇒ pipes.injectEvent(evt) case Terminated(`handler`) ⇒ connection ! Tcp.Abort + case cmd: Tcp.Command ⇒ pipes.managementCommand(cmd) } } + +/** + * Adapts a ByteString oriented pipeline stage to a stage that communicates via Tcp Commands and Events. Every ByteString + * passed down to this stage will be converted to Tcp.Write commands, while incoming Tcp.Receive events will be unwrapped + * and their contents passed up as raw ByteStrings. This adapter should be used together with TcpPipelineHandler. + * + * While this adapter communicates to the stage above it via raw ByteStrings, it is possible to inject Tcp Command + * by sending them to the management port, and the adapter will simply pass them down to the stage below. Incoming Tcp Events + * that are not Receive events will be passed directly to the handler registered for TcpPipelineHandler. + * @tparam Ctx + */ +class TcpReadWriteAdapter[Ctx <: PipelineContext] extends PipelineStage[Ctx, ByteString, Tcp.Command, ByteString, Tcp.Event] { + + override def apply(ctx: Ctx) = new PipePair[ByteString, Tcp.Command, ByteString, Tcp.Event] { + + override val commandPipeline = { + data: ByteString ⇒ ctx.singleCommand(Tcp.Write(data)) + } + + override val eventPipeline = (evt: Tcp.Event) ⇒ evt match { + case Tcp.Received(data) ⇒ ctx.singleEvent(data) + case ev: Tcp.Event ⇒ ctx.singleCommand(TcpPipelineHandler.EscapeEvent(ev)) + } + + override val managementPort: Mgmt = { + case cmd: Tcp.Command ⇒ ctx.singleCommand(cmd) + } + } +} \ No newline at end of file diff --git a/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala b/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala index c98bdba7cf..ab0dcd6def 100644 --- a/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala +++ b/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala @@ -24,24 +24,20 @@ package akka.io.ssl +import akka.TestUtils +import akka.event.Logging +import akka.event.LoggingAdapter +import akka.io._ +import akka.remote.security.provider.AkkaProvider +import akka.testkit.{ TestProbe, AkkaSpec } +import akka.util.{ ByteString, Timeout } import java.io.{ BufferedWriter, OutputStreamWriter, InputStreamReader, BufferedReader } -import javax.net.ssl._ import java.net.{ InetSocketAddress, SocketException } import java.security.{ KeyStore, SecureRandom } import java.util.concurrent.atomic.AtomicInteger +import javax.net.ssl._ import scala.concurrent.duration._ -import com.typesafe.config.{ ConfigFactory, Config } -import akka.actor._ -import akka.event.LoggingAdapter -import akka.testkit.{ TestProbe, AkkaSpec } -import akka.util.{ ByteString, Timeout } -import akka.io.{ IO, Tcp, PipelineContext } -import akka.TestUtils -import akka.event.Logging -import akka.io.TcpPipelineHandler -import akka.io.SslTlsSupport -import akka.io.HasLogging -import akka.remote.security.provider.AkkaProvider +import akka.actor.{ Props, ActorLogging, Actor, ActorContext } // TODO move this into akka-actor once AkkaProvider for SecureRandom does not have external dependencies class SslTlsSupportSpec extends AkkaSpec { @@ -103,26 +99,38 @@ class SslTlsSupportSpec extends AkkaSpec { val connected = probe.expectMsgType[Tcp.Connected] val connection = probe.sender - val init = new TcpPipelineHandler.Init(new SslTlsSupport(sslEngine(connected.remoteAddress, client = true))) { + val init = new TcpPipelineHandler.Init( + new StringByteStringAdapter >> + new DelimiterFraming(maxSize = 1024, delimiter = ByteString('\n'), includeDelimiter = true) >> + new TcpReadWriteAdapter[HasLogging] >> + new SslTlsSupport(sslEngine(connected.remoteAddress, client = true))) { override def makeContext(actorContext: ActorContext): HasLogging = new HasLogging { override def getLogger = system.log } } + import init._ + val handler = system.actorOf(TcpPipelineHandler(init, connection, probe.ref), "client" + counter.incrementAndGet()) probe.send(connection, Tcp.Register(handler)) def run() { - probe.send(handler, Command(Tcp.Write(ByteString("3+4\n")))) - probe.expectMsg(Event(Tcp.Received(ByteString("7\n")))) - probe.send(handler, Command(Tcp.Write(ByteString("20+22\n")))) - probe.expectMsg(Event(Tcp.Received(ByteString("42\n")))) + probe.send(handler, Command("3+4\n")) + probe.expectMsg(Event("7\n")) + probe.send(handler, Command("20+22\n")) + probe.expectMsg(Event("42\n")) + probe.send(handler, Command("12+24\n11+1")) + Thread.sleep(1000) // Exercise framing by waiting at a mid-frame point + probe.send(handler, Command("1\n0+0\n")) + probe.expectMsg(Event("36\n")) + probe.expectMsg(Event("22\n")) + probe.expectMsg(Event("0\n")) } def close() { - probe.send(handler, Command(Tcp.Close)) - probe.expectMsgType[Event].evt match { + probe.send(handler, Tcp.Close) + probe.expectMsgType[Tcp.Event] match { case _: Tcp.ConnectionClosed ⇒ true } TestUtils.verifyActorTermination(handler) @@ -133,13 +141,16 @@ class SslTlsSupportSpec extends AkkaSpec { //#server class AkkaSslServer extends Actor with ActorLogging { - import Tcp.{ Connected, Received } + import Tcp.Connected def receive: Receive = { case Connected(remote, _) ⇒ val init = new TcpPipelineHandler.Init( - new SslTlsSupport(sslEngine(remote, client = false))) { + new StringByteStringAdapter >> + new DelimiterFraming(maxSize = 1024, delimiter = ByteString('\n'), includeDelimiter = true) >> + new TcpReadWriteAdapter[HasLogging] >> + new SslTlsSupport(sslEngine(remote, client = false))) { override def makeContext(actorContext: ActorContext): HasLogging = new HasLogging { override def getLogger = log @@ -154,11 +165,11 @@ class SslTlsSupportSpec extends AkkaSpec { connection ! Tcp.Register(handler) context become { - case Event(Received(data)) ⇒ - val input = data.utf8String.dropRight(1) + case Event(data) ⇒ + val input = data.dropRight(1) log.debug("akka-io Server received {} from {}", input, sender) val response = serverResponse(input) - sender ! Command(Tcp.Write(ByteString(response))) + sender ! Command(response) log.debug("akka-io Server sent: {}", response.dropRight(1)) } }