diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala index c0edac7986..526f321bd0 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala @@ -4,6 +4,7 @@ package akka.remote.artery import java.util.concurrent.atomic.AtomicInteger + import scala.concurrent.Await import scala.concurrent.duration._ import akka.Done @@ -23,6 +24,8 @@ import akka.actor.ExtendedActorSystem import org.agrona.IoUtil import java.io.File +import akka.util.ByteString + object AeronStreamConsistencySpec extends MultiNodeConfig { val first = role("first") val second = role("second") @@ -65,6 +68,8 @@ abstract class AeronStreamConsistencySpec r } + val pool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumPooledBuffers) + lazy implicit val mat = ActorMaterializer()(system) import system.dispatcher @@ -90,8 +95,8 @@ abstract class AeronStreamConsistencySpec "start echo" in { runOn(second) { // just echo back - Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) - .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner, pool)) + .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner, pool)) } enterBarrier("echo-started") } @@ -104,35 +109,47 @@ abstract class AeronStreamConsistencySpec val killSwitch = KillSwitches.shared("test") val started = TestProbe() val startMsg = "0".getBytes("utf-8") - Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner, pool)) .via(killSwitch.flow) - .runForeach { bytes ⇒ + .runForeach { envelope ⇒ + val bytes = ByteString.fromByteBuffer(envelope.byteBuffer) if (bytes.length == 1 && bytes(0) == startMsg(0)) started.ref ! 
Done else { val c = count.incrementAndGet() - val x = new String(bytes, "utf-8").toInt + val x = new String(bytes.toArray, "utf-8").toInt if (x != c) { throw new IllegalArgumentException(s"# wrong message $x expected $c") } if (c == totalMessages) done.countDown() } + pool.release(envelope) }.onFailure { case e ⇒ e.printStackTrace } within(10.seconds) { - Source(1 to 100).map(_ ⇒ startMsg) + Source(1 to 100).map { _ ⇒ + val envelope = pool.acquire() + envelope.byteBuffer.put(startMsg) + envelope.byteBuffer.flip() + envelope + } .throttle(1, 200.milliseconds, 1, ThrottleMode.Shaping) - .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner, pool)) started.expectMsg(Done) } Source(1 to totalMessages) .throttle(10000, 1.second, 1000, ThrottleMode.Shaping) - .map { n ⇒ n.toString.getBytes("utf-8") } - .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) + .map { n ⇒ + val envelope = pool.acquire() + envelope.byteBuffer.put(n.toString.getBytes("utf-8")) + envelope.byteBuffer.flip() + envelope + } + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner, pool)) Await.ready(done, 20.seconds) killSwitch.shutdown() diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala index d364a52cae..6ce14f0b3b 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala @@ -7,6 +7,7 @@ import java.util.concurrent.CyclicBarrier import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicLongArray + import scala.concurrent.duration._ import akka.actor._ import akka.remote.testconductor.RoleName @@ -22,11 +23,14 @@ import io.aeron.Aeron import io.aeron.driver.MediaDriver import org.HdrHistogram.Histogram import java.util.concurrent.atomic.AtomicBoolean + import akka.stream.KillSwitches import akka.Done import org.agrona.IoUtil import java.io.File import java.io.File + +import akka.util.ByteString import io.aeron.CncFileDescriptor object AeronStreamLatencySpec extends MultiNodeConfig { @@ -76,6 +80,8 @@ abstract class AeronStreamLatencySpec val driver = MediaDriver.launchEmbedded() + val pool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumPooledBuffers) + val stats = new AeronStat(AeronStat.mapCounters(new File(driver.aeronDirectoryName, CncFileDescriptor.CNC_FILE))) @@ -193,9 +199,10 @@ abstract class AeronStreamLatencySpec val killSwitch = KillSwitches.shared(testName) val started = TestProbe() val startMsg = "0".getBytes("utf-8") - Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner, pool)) .via(killSwitch.flow) - .runForeach { bytes ⇒ + .runForeach { envelope ⇒ + val bytes = ByteString.fromByteBuffer(envelope.byteBuffer) if (bytes.length == 1 && bytes(0) == startMsg(0)) started.ref ! 
Done else { @@ -209,12 +216,18 @@ abstract class AeronStreamLatencySpec barrier.await() // this is always the last party } } + pool.release(envelope) } within(10.seconds) { - Source(1 to 50).map(_ ⇒ startMsg) + Source(1 to 50).map { _ ⇒ + val envelope = pool.acquire() + envelope.byteBuffer.put(startMsg) + envelope.byteBuffer.flip() + envelope + } .throttle(1, 200.milliseconds, 1, ThrottleMode.Shaping) - .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner, pool)) started.expectMsg(Done) } @@ -226,10 +239,13 @@ abstract class AeronStreamLatencySpec Source(1 to totalMessages) .throttle(messageRate, 1.second, math.max(messageRate / 10, 1), ThrottleMode.Shaping) .map { n ⇒ + val envelope = pool.acquire() + envelope.byteBuffer.put(payload) + envelope.byteBuffer.flip() sendTimes.set(n - 1, System.nanoTime()) - payload + envelope } - .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner, pool)) barrier.await((totalMessages / messageRate) + 10, SECONDS) } @@ -247,8 +263,8 @@ abstract class AeronStreamLatencySpec "start echo" in { runOn(second) { // just echo back - Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) - .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner, pool)) + .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner, pool)) } enterBarrier("echo-started") } diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala index e9356cad08..e4db9dca62 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala @@ -5,6 +5,7 @@ package akka.remote.artery import java.net.InetAddress import java.util.concurrent.Executors + import scala.collection.AbstractIterator import scala.concurrent.Await import scala.concurrent.duration._ @@ -21,6 +22,8 @@ import io.aeron.Aeron import io.aeron.driver.MediaDriver import akka.stream.KillSwitches import java.io.File + +import akka.util.ByteString import io.aeron.CncFileDescriptor import org.agrona.IoUtil @@ -81,6 +84,8 @@ abstract class AeronStreamMaxThroughputSpec val driver = MediaDriver.launchEmbedded() + val pool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumPooledBuffers) + val stats = new AeronStat(AeronStat.mapCounters(new File(driver.aeronDirectoryName, CncFileDescriptor.CNC_FILE))) @@ -168,9 +173,10 @@ abstract class AeronStreamMaxThroughputSpec var count = 0L val done = TestLatch(1) val killSwitch = KillSwitches.shared(testName) - Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner, pool)) .via(killSwitch.flow) - .runForeach { bytes ⇒ + .runForeach { envelope ⇒ + val bytes = ByteString.fromByteBuffer(envelope.byteBuffer) rep.onMessage(1, bytes.length) count += 1 if (count == 1) { @@ -180,6 +186,7 @@ abstract class AeronStreamMaxThroughputSpec done.countDown() killSwitch.shutdown() } + pool.release(envelope) }.onFailure { case e ⇒ e.printStackTrace @@ -198,8 +205,13 @@ abstract class AeronStreamMaxThroughputSpec val payload = 
("0" * payloadSize).getBytes("utf-8") val t0 = System.nanoTime() Source.fromIterator(() ⇒ iterate(1, totalMessages)) - .map { n ⇒ payload } - .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) + .map { n ⇒ + val envelope = pool.acquire() + envelope.byteBuffer.put(payload) + envelope.byteBuffer.flip() + envelope + } + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner, pool)) printStats("sender") enterBarrier(testName + "-done") diff --git a/akka-remote/src/main/scala/akka/remote/MessageSerializer.scala b/akka-remote/src/main/scala/akka/remote/MessageSerializer.scala index e222adacd5..e11627a078 100644 --- a/akka-remote/src/main/scala/akka/remote/MessageSerializer.scala +++ b/akka-remote/src/main/scala/akka/remote/MessageSerializer.scala @@ -7,8 +7,8 @@ package akka.remote import akka.remote.WireFormats._ import akka.protobuf.ByteString import akka.actor.ExtendedActorSystem -import akka.serialization.SerializationExtension -import akka.serialization.SerializerWithStringManifest +import akka.remote.artery.{ EnvelopeBuffer, HeaderBuilder } +import akka.serialization.{ Serialization, SerializationExtension, SerializerWithStringManifest } /** * INTERNAL API @@ -47,4 +47,33 @@ private[akka] object MessageSerializer { } builder.build } + + def serializeForArtery(serialization: Serialization, message: AnyRef, headerBuilder: HeaderBuilder, envelope: EnvelopeBuffer): Unit = { + val serializer = serialization.findSerializerFor(message) + + // FIXME: This should be a FQCN instead + headerBuilder.serializer = serializer.identifier.toString + serializer match { + case ser2: SerializerWithStringManifest ⇒ + val manifest = ser2.manifest(message) + headerBuilder.classManifest = manifest + case _ ⇒ + headerBuilder.classManifest = message.getClass.getName + } + + envelope.writeHeader(headerBuilder) + // FIXME: This should directly write to the buffer instead + envelope.byteBuffer.put(serializer.toBinary(message)) + } + + def deserializeForArtery(system: ExtendedActorSystem, headerBuilder: HeaderBuilder, envelope: EnvelopeBuffer): AnyRef = { + // FIXME: Use the buffer directly + val size = envelope.byteBuffer.limit - envelope.byteBuffer.position + val bytes = Array.ofDim[Byte](size) + envelope.byteBuffer.get(bytes) + SerializationExtension(system).deserialize( + bytes, + Integer.parseInt(headerBuilder.serializer), // FIXME: Use FQCN + headerBuilder.classManifest).get + } } diff --git a/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala b/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala index 9807591bc1..611809a449 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala @@ -27,48 +27,49 @@ import io.aeron.Publication import org.agrona.concurrent.UnsafeBuffer object AeronSink { - type Bytes = Array[Byte] - private def offerTask(pub: Publication, buffer: UnsafeBuffer, msgSize: AtomicInteger, onOfferSuccess: AsyncCallback[Unit]): () ⇒ Boolean = { + class OfferTask(pub: Publication, var buffer: UnsafeBuffer, msgSize: AtomicInteger, onOfferSuccess: AsyncCallback[Unit]) + extends (() ⇒ Boolean) { + var n = 0L var localMsgSize = -1 - () ⇒ - { - n += 1 - if (localMsgSize == -1) - localMsgSize = msgSize.get - val result = pub.offer(buffer, 0, localMsgSize) - if (result >= 0) { - n = 0 - localMsgSize = -1 - onOfferSuccess.invoke(()) - true - } else { - // FIXME drop after too many attempts? 
- if (n > 1000000 && n % 100000 == 0) - println(s"# offer not accepted after $n") // FIXME - false - } + + override def apply(): Boolean = { + n += 1 + if (localMsgSize == -1) + localMsgSize = msgSize.get + val result = pub.offer(buffer, 0, localMsgSize) + if (result >= 0) { + n = 0 + localMsgSize = -1 + onOfferSuccess.invoke(()) + true + } else { + // FIXME drop after too many attempts? + if (n > 1000000 && n % 100000 == 0) + println(s"# offer not accepted after $n") // FIXME + false } + } } } /** * @param channel eg. "aeron:udp?endpoint=localhost:40123" */ -class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner) - extends GraphStageWithMaterializedValue[SinkShape[AeronSink.Bytes], Future[Done]] { +class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner, pool: EnvelopeBufferPool) + extends GraphStageWithMaterializedValue[SinkShape[EnvelopeBuffer], Future[Done]] { import AeronSink._ import TaskRunner._ - val in: Inlet[Bytes] = Inlet("AeronSink") - override val shape: SinkShape[Bytes] = SinkShape(in) + val in: Inlet[EnvelopeBuffer] = Inlet("AeronSink") + override val shape: SinkShape[EnvelopeBuffer] = SinkShape(in) override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Done]) = { val completed = Promise[Done]() val logic = new GraphStageLogic(shape) with InHandler { - private val buffer = new UnsafeBuffer(ByteBuffer.allocateDirect(128 * 1024)) + private var envelopeInFlight: EnvelopeBuffer = null private val pub = aeron.addPublication(channel, streamId) private var completedValue: Try[Done] = Success(Done) @@ -76,8 +77,9 @@ class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRu private val spinning = 1000 private var backoffCount = spinning private var lastMsgSize = 0 - private var lastMsgSizeRef = new AtomicInteger // used in the external backoff task - private val addOfferTask: Add = Add(offerTask(pub, buffer, lastMsgSizeRef, getAsyncCallback(_ ⇒ onOfferSuccess()))) + private val lastMsgSizeRef = new AtomicInteger // used in the external backoff task + private val offerTask = new OfferTask(pub, null, lastMsgSizeRef, getAsyncCallback(_ ⇒ onOfferSuccess())) + private val addOfferTask: Add = Add(offerTask) private var offerTaskInProgress = false @@ -94,15 +96,14 @@ class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRu // InHandler override def onPush(): Unit = { - val msg = grab(in) - buffer.putBytes(0, msg); + envelopeInFlight = grab(in) backoffCount = spinning - lastMsgSize = msg.length + lastMsgSize = envelopeInFlight.byteBuffer.limit publish() } @tailrec private def publish(): Unit = { - val result = pub.offer(buffer, 0, lastMsgSize) + val result = pub.offer(envelopeInFlight.aeronBuffer, 0, lastMsgSize) // FIXME handle Publication.CLOSED // TODO the backoff strategy should be measured and tuned if (result < 0) { @@ -113,6 +114,7 @@ class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRu // delegate backoff to shared TaskRunner lastMsgSizeRef.set(lastMsgSize) offerTaskInProgress = true + offerTask.buffer = envelopeInFlight.aeronBuffer taskRunner.command(addOfferTask) } } else { @@ -122,6 +124,10 @@ class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRu private def onOfferSuccess(): Unit = { offerTaskInProgress = false + pool.release(envelopeInFlight) + offerTask.buffer = null + envelopeInFlight = null + if (isClosed(in)) completeStage() else diff --git 
a/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala b/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala index c03aa69dea..48af59f14d 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala @@ -24,9 +24,8 @@ import org.agrona.DirectBuffer import org.agrona.concurrent.BackoffIdleStrategy object AeronSource { - type Bytes = Array[Byte] - private def pollTask(sub: Subscription, handler: MessageHandler, onMessage: AsyncCallback[Bytes]): () ⇒ Boolean = { + private def pollTask(sub: Subscription, handler: MessageHandler, onMessage: AsyncCallback[EnvelopeBuffer]): () ⇒ Boolean = { () ⇒ { handler.reset @@ -41,19 +40,20 @@ object AeronSource { } } - class MessageHandler { + class MessageHandler(pool: EnvelopeBufferPool) { def reset(): Unit = messageReceived = null - var messageReceived: Bytes = null + var messageReceived: EnvelopeBuffer = null - val fragmentsHandler = new Fragments(data ⇒ messageReceived = data) + val fragmentsHandler = new Fragments(data ⇒ messageReceived = data, pool) } - class Fragments(onMessage: Bytes ⇒ Unit) extends FragmentAssembler(new FragmentHandler { - override def onFragment(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Unit = { - val data = Array.ofDim[Byte](length) - buffer.getBytes(offset, data) - onMessage(data) + class Fragments(onMessage: EnvelopeBuffer ⇒ Unit, pool: EnvelopeBufferPool) extends FragmentAssembler(new FragmentHandler { + override def onFragment(aeronBuffer: DirectBuffer, offset: Int, length: Int, header: Header): Unit = { + val envelope = pool.acquire() + aeronBuffer.getBytes(offset, envelope.byteBuffer, length) + envelope.byteBuffer.flip() + onMessage(envelope) } }) } @@ -61,12 +61,13 @@ object AeronSource { /** * @param channel eg. "aeron:udp?endpoint=localhost:40123" */ -class AeronSource(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner) extends GraphStage[SourceShape[AeronSource.Bytes]] { +class AeronSource(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner, pool: EnvelopeBufferPool) + extends GraphStage[SourceShape[EnvelopeBuffer]] { import AeronSource._ import TaskRunner._ - val out: Outlet[Bytes] = Outlet("AeronSource") - override val shape: SourceShape[Bytes] = SourceShape(out) + val out: Outlet[EnvelopeBuffer] = Outlet("AeronSource") + override val shape: SourceShape[EnvelopeBuffer] = SourceShape(out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler { @@ -81,7 +82,7 @@ class AeronSource(channel: String, streamId: Int, aeron: Aeron, taskRunner: Task private var backoffCount = idleStrategyRetries // the fragmentHandler is called from `poll` in same thread, i.e. 
no async callback is needed - private val messageHandler = new MessageHandler + private val messageHandler = new MessageHandler(pool) private val addPollTask: Add = Add(pollTask(sub, messageHandler, getAsyncCallback(onMessage))) override def postStop(): Unit = { @@ -119,7 +120,7 @@ class AeronSource(channel: String, streamId: Int, aeron: Aeron, taskRunner: Task } } - private def onMessage(data: Bytes): Unit = { + private def onMessage(data: EnvelopeBuffer): Unit = { push(out, data) } diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala index 0ef73d7d47..05aa1aa075 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala @@ -246,6 +246,11 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R private val maxRestarts = 5 // FIXME config private val restartCounter = new RestartCounter(maxRestarts, restartTimeout) + val envelopePool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumPooledBuffers) + // FIXME: Compression table must be owned by each channel instead + // of having a global one + val compression = new Compression(system) + override def start(): Unit = { startMediaDriver() startAeron() @@ -318,9 +323,8 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } private def runInboundControlStream(): Unit = { - val (c, completed) = Source.fromGraph(new AeronSource(inboundChannel, controlStreamId, aeron, taskRunner)) + val (c, completed) = Source.fromGraph(new AeronSource(inboundChannel, controlStreamId, aeron, taskRunner, envelopePool)) .async // FIXME measure - .map(ByteString.apply) // TODO we should use ByteString all the way .viaMat(inboundControlFlow)(Keep.right) .toMat(Sink.ignore)(Keep.both) .run()(materializer) @@ -357,9 +361,8 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } private def runInboundOrdinaryMessagesStream(): Unit = { - val completed = Source.fromGraph(new AeronSource(inboundChannel, ordinaryStreamId, aeron, taskRunner)) + val completed = Source.fromGraph(new AeronSource(inboundChannel, ordinaryStreamId, aeron, taskRunner, envelopePool)) .async // FIXME measure - .map(ByteString.apply) // TODO we should use ByteString all the way .via(inboundFlow) .runWith(Sink.ignore)(materializer) @@ -437,8 +440,7 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R Flow.fromGraph(killSwitch.flow[Send]) .via(new OutboundHandshake(outboundContext, handshakeTimeout, handshakeRetryInterval)) .via(encoder) - .map(_.toArray) // TODO we should use ByteString all the way - .toMat(new AeronSink(outboundChannel(outboundContext.remoteAddress), ordinaryStreamId, aeron, taskRunner))(Keep.right) + .toMat(new AeronSink(outboundChannel(outboundContext.remoteAddress), ordinaryStreamId, aeron, taskRunner, envelopePool))(Keep.right) } def outboundControl(outboundContext: OutboundContext): Sink[Send, (OutboundControlIngress, Future[Done])] = { @@ -447,10 +449,7 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R .via(new SystemMessageDelivery(outboundContext, systemMessageResendInterval, remoteSettings.SysMsgBufferSize)) .viaMat(new OutboundControlJunction(outboundContext))(Keep.right) .via(encoder) - .map(_.toArray) // TODO we should use ByteString all the way - .toMat(new 
AeronSink(outboundChannel(outboundContext.remoteAddress), controlStreamId, aeron, taskRunner))(Keep.both) - - // FIXME we can also add scrubbing stage that would collapse sys msg acks/nacks and remove duplicate Quarantine messages + .toMat(new AeronSink(outboundChannel(outboundContext.remoteAddress), controlStreamId, aeron, taskRunner, envelopePool))(Keep.both) // FIXME we can also add scrubbing stage that would collapse sys msg acks/nacks and remove duplicate Quarantine messages } @@ -458,59 +457,27 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R // FIXME hack until real envelopes, encoding originAddress in sender :) private val dummySender = system.systemActorOf(Props.empty, "dummy") - // TODO: Try out parallelized serialization (mapAsync) for performance - val encoder: Flow[Send, ByteString, NotUsed] = Flow[Send].map { sendEnvelope ⇒ - val pdu: ByteString = codec.constructMessage( - sendEnvelope.recipient.localAddressToUse, - sendEnvelope.recipient, - Serialization.currentTransportInformation.withValue(Serialization.Information(localAddress.address, system)) { - MessageSerializer.serialize(system, sendEnvelope.message.asInstanceOf[AnyRef]) - }, - if (sendEnvelope.senderOption.isDefined) sendEnvelope.senderOption else Some(dummySender), // FIXME: hack until real envelopes - seqOption = Some(SeqNo(localAddress.uid)), // FIXME: hack until real envelopes - ackOption = None) - - // TODO: Drop unserializable messages - // TODO: Drop oversized messages - (new ByteStringBuilder).putInt(pdu.size)(ByteOrder.LITTLE_ENDIAN).result() ++ pdu - } - - val decoder: Flow[ByteString, AkkaPduCodec.Message, NotUsed] = - Framing.lengthField(4, maximumFrameLength = 256000) - .map { frame ⇒ - // TODO: Drop unserializable messages - val pdu = codec.decodeMessage(frame.drop(4), provider, localAddress.address)._2.get - pdu - } + val encoder: Flow[Send, EnvelopeBuffer, NotUsed] = + Flow.fromGraph(new Encoder(this, compression)) val messageDispatcherSink: Sink[InboundEnvelope, Future[Done]] = Sink.foreach[InboundEnvelope] { m ⇒ messageDispatcher.dispatch(m.recipient, m.recipientAddress, m.message, m.senderOption) } - val deserializer: Flow[AkkaPduCodec.Message, InboundEnvelope, NotUsed] = - Flow[AkkaPduCodec.Message].map { m ⇒ - InboundEnvelope( - m.recipient, - m.recipientAddress, - MessageSerializer.deserialize(system, m.serializedMessage), - if (m.senderOption.get.path.name == "dummy") None else m.senderOption, // FIXME hack until real envelopes - UniqueAddress(m.senderOption.get.path.address, m.seq.rawValue.toInt)) // FIXME hack until real envelopes - } + val decoder = Flow.fromGraph(new Decoder(this, compression)) - val inboundFlow: Flow[ByteString, ByteString, NotUsed] = { + val inboundFlow: Flow[EnvelopeBuffer, ByteString, NotUsed] = { Flow.fromSinkAndSource( decoder - .via(deserializer) .via(new InboundHandshake(this, inControlStream = false)) .via(new InboundQuarantineCheck(this)) .to(messageDispatcherSink), Source.maybe[ByteString].via(killSwitch.flow)) } - val inboundControlFlow: Flow[ByteString, ByteString, ControlMessageSubject] = { + val inboundControlFlow: Flow[EnvelopeBuffer, ByteString, ControlMessageSubject] = { Flow.fromSinkAndSourceMat( decoder - .via(deserializer) .via(new InboundHandshake(this, inControlStream = true)) .via(new InboundQuarantineCheck(this)) .viaMat(new InboundControlJunction)(Keep.right) @@ -521,10 +488,18 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } -object ArteryTransport { +/** + * INTERNAL API + 
*/ +private[remote] object ArteryTransport { + + val Version = 0 + val MaximumFrameSize = 1024 * 1024 + val MaximumPooledBuffers = 256 /** * Internal API + * * @return A port that is hopefully available */ private[remote] def autoSelectPort(hostname: String): Int = { diff --git a/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala b/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala new file mode 100644 index 0000000000..901391673d --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala @@ -0,0 +1,316 @@ +/** + * Copyright (C) 2009-2016 Lightbend Inc. + */ + +package akka.remote.artery + +import java.lang.reflect.Field +import java.nio.charset.Charset +import java.nio.{ ByteBuffer, ByteOrder } + +import org.agrona.concurrent.{ ManyToManyConcurrentArrayQueue, UnsafeBuffer } +import sun.misc.Cleaner + +import scala.util.control.NonFatal + +/** + * INTERNAL API + */ +private[remote] class OutOfBuffersException extends RuntimeException("Out of usable ByteBuffers") + +/** + * INTERNAL API + */ +private[remote] class EnvelopeBufferPool(maximumPayload: Int, maximumBuffers: Int) { + private val availableBuffers = new ManyToManyConcurrentArrayQueue[EnvelopeBuffer](maximumBuffers) + + def acquire(): EnvelopeBuffer = { + val buf = availableBuffers.poll() + if (buf ne null) { + buf.byteBuffer.clear() + buf + } else { + val newBuf = new EnvelopeBuffer(ByteBuffer.allocateDirect(maximumPayload)) + newBuf.byteBuffer.order(ByteOrder.LITTLE_ENDIAN) + newBuf + } + } + + def release(buffer: EnvelopeBuffer) = if (!availableBuffers.offer(buffer)) buffer.tryForceDrop() + +} + +/** + * INTERNAL API + */ +private[remote] object EnvelopeBuffer { + + val TagTypeMask = 0xFF000000 + val TagValueMask = 0x0000FFFF + + val VersionOffset = 0 + val UidOffset = 4 + val SenderActorRefTagOffset = 8 + val RecipientActorRefTagOffset = 12 + val SerializerTagOffset = 16 + val ClassManifestTagOffset = 24 + + val LiteralsSectionOffset = 32 + + val UsAscii = Charset.forName("US-ASCII") + + val DeadLettersCode = 0 +} + +/** + * INTERNAL API + */ +private[remote] trait LiteralCompressionTable { + + def compressActorRef(ref: String): Int + def decompressActorRef(idx: Int): String + + def compressSerializer(serializer: String): Int + def decompressSerializer(idx: Int): String + + def compressClassManifest(manifest: String): Int + def decompressClassManifest(idx: Int): String + +} + +object HeaderBuilder { + def apply(compressionTable: LiteralCompressionTable): HeaderBuilder = new HeaderBuilderImpl(compressionTable) +} + +/** + * INTERNAL API + */ +sealed trait HeaderBuilder { + def version_=(v: Int): Unit + def version: Int + + def uid_=(u: Int): Unit + def uid: Int + + def senderActorRef_=(ref: String): Unit + def senderActorRef: String + + def setNoSender(): Unit + + def recipientActorRef_=(ref: String): Unit + def recipientActorRef: String + + def serializer_=(serializer: String): Unit + def serializer: String + + def classManifest_=(manifest: String): Unit + def classManifest: String +} + +/** + * INTERNAL API + */ +private[remote] final class HeaderBuilderImpl(val compressionTable: LiteralCompressionTable) extends HeaderBuilder { + var version: Int = _ + var uid: Int = _ + + // Fields only available for EnvelopeBuffer + var _senderActorRef: String = null + var _senderActorRefIdx: Int = -1 + var _recipientActorRef: String = null + var _recipientActorRefIdx: Int = -1 + + var _serializer: String = null + var _serializerIdx: Int = -1 + var _classManifest: String = null + var 
_classManifestIdx: Int = -1 + + def senderActorRef_=(ref: String): Unit = { + _senderActorRef = ref + _senderActorRefIdx = compressionTable.compressActorRef(ref) + } + + def setNoSender(): Unit = { + _senderActorRef = null + _senderActorRefIdx = EnvelopeBuffer.DeadLettersCode + } + + def senderActorRef: String = { + if (_senderActorRef ne null) _senderActorRef + else { + _senderActorRef = compressionTable.decompressActorRef(_senderActorRefIdx) + _senderActorRef + } + } + + def recipientActorRef_=(ref: String): Unit = { + _recipientActorRef = ref + _recipientActorRefIdx = compressionTable.compressActorRef(ref) + } + + def recipientActorRef: String = { + if (_recipientActorRef ne null) _recipientActorRef + else { + _recipientActorRef = compressionTable.decompressActorRef(_recipientActorRefIdx) + _recipientActorRef + } + } + + override def serializer_=(serializer: String): Unit = { + _serializer = serializer + _serializerIdx = compressionTable.compressSerializer(serializer) + } + + override def serializer: String = { + if (_serializer ne null) _serializer + else { + _serializer = compressionTable.decompressSerializer(_serializerIdx) + _serializer + } + } + + override def classManifest_=(manifest: String): Unit = { + _classManifest = manifest + _classManifestIdx = compressionTable.compressClassManifest(manifest) + } + + override def classManifest: String = { + if (_classManifest ne null) _classManifest + else { + _classManifest = compressionTable.decompressClassManifest(_classManifestIdx) + _classManifest + } + } + + override def toString = s"HeaderBuilderImpl($version, $uid, ${_senderActorRef}, ${_senderActorRefIdx}, ${_recipientActorRef}, ${_recipientActorRefIdx}, ${_serializer}, ${_serializerIdx}, ${_classManifest}, ${_classManifestIdx})" +} + +/** + * INTERNAL API + */ +private[remote] final class EnvelopeBuffer(val byteBuffer: ByteBuffer) { + import EnvelopeBuffer._ + val aeronBuffer = new UnsafeBuffer(byteBuffer) + + private val cleanerField: Field = try { + val cleaner = byteBuffer.getClass.getDeclaredField("cleaner") + cleaner.setAccessible(true) + cleaner + } catch { + case NonFatal(_) ⇒ null + } + + def tryForceDrop(): Unit = { + if (cleanerField ne null) cleanerField.get(byteBuffer) match { + case cleaner: Cleaner ⇒ cleaner.clean() + case _ ⇒ + } + } + + def writeHeader(h: HeaderBuilder): Unit = { + val header = h.asInstanceOf[HeaderBuilderImpl] + byteBuffer.clear() + + // Write fixed length parts + byteBuffer.putInt(header.version) + byteBuffer.putInt(header.uid) + + // Write compressable, variable-length parts always to the actual position of the buffer + // Write tag values explicitly in their proper offset + byteBuffer.position(LiteralsSectionOffset) + + // Serialize sender + if (header._senderActorRefIdx != -1) + byteBuffer.putInt(SenderActorRefTagOffset, header._senderActorRefIdx | TagTypeMask) + else + writeLiteral(SenderActorRefTagOffset, header._senderActorRef) + + // Serialize recipient + if (header._recipientActorRefIdx != -1) + byteBuffer.putInt(RecipientActorRefTagOffset, header._recipientActorRefIdx | TagTypeMask) + else + writeLiteral(RecipientActorRefTagOffset, header._recipientActorRef) + + // Serialize serializer + if (header._serializerIdx != -1) + byteBuffer.putInt(SerializerTagOffset, header._serializerIdx | TagTypeMask) + else + writeLiteral(SerializerTagOffset, header._serializer) + + // Serialize class manifest + if (header._classManifestIdx != -1) + byteBuffer.putInt(ClassManifestTagOffset, header._classManifestIdx | TagTypeMask) + else + 
writeLiteral(ClassManifestTagOffset, header._classManifest) + } + + def parseHeader(h: HeaderBuilder): Unit = { + val header = h.asInstanceOf[HeaderBuilderImpl] + + // Read fixed length parts + header.version = byteBuffer.getInt + header.uid = byteBuffer.getInt + + // Read compressable, variable-length parts always from the actual position of the buffer + // Read tag values explicitly from their proper offset + byteBuffer.position(LiteralsSectionOffset) + + // Deserialize sender + val senderTag = byteBuffer.getInt(SenderActorRefTagOffset) + if ((senderTag & TagTypeMask) != 0) { + val idx = senderTag & TagValueMask + header._senderActorRef = null + header._senderActorRefIdx = idx + } else { + header._senderActorRef = readLiteral() + } + + // Deserialize recipient + val recipientTag = byteBuffer.getInt(RecipientActorRefTagOffset) + if ((recipientTag & TagTypeMask) != 0) { + val idx = recipientTag & TagValueMask + header._recipientActorRef = null + header._recipientActorRefIdx = idx + } else { + header._recipientActorRef = readLiteral() + } + + // Deserialize serializer + val serializerTag = byteBuffer.getInt(SerializerTagOffset) + if ((serializerTag & TagTypeMask) != 0) { + val idx = serializerTag & TagValueMask + header._serializer = null + header._serializerIdx = idx + } else { + header._serializer = readLiteral() + } + + // Deserialize class manifest + val classManifestTag = byteBuffer.getInt(ClassManifestTagOffset) + if ((classManifestTag & TagTypeMask) != 0) { + val idx = classManifestTag & TagValueMask + header._classManifest = null + header._classManifestIdx = idx + } else { + header._classManifest = readLiteral() + } + } + + private def readLiteral(): String = { + val length = byteBuffer.getShort + val bytes = Array.ofDim[Byte](length) + byteBuffer.get(bytes) + new String(bytes, UsAscii) + } + + private def writeLiteral(tagOffset: Int, literal: String): Unit = { + if (literal.length > 65535) + throw new IllegalArgumentException("Literals longer than 65535 cannot be encoded in the envelope") + + val literalBytes = literal.getBytes(UsAscii) + byteBuffer.putInt(tagOffset, byteBuffer.position()) + byteBuffer.putShort(literalBytes.length.toShort) + byteBuffer.put(literalBytes) + } + +} \ No newline at end of file diff --git a/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala b/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala new file mode 100644 index 0000000000..db412d7862 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala @@ -0,0 +1,103 @@ +package akka.remote.artery + +import akka.actor.{ ActorRef, InternalActorRef } +import akka.remote.EndpointManager.Send +import akka.remote.{ MessageSerializer, UniqueAddress } +import akka.serialization.{ Serialization, SerializationExtension } +import akka.stream._ +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } + +// TODO: Long UID +class Encoder( + transport: ArteryTransport, + compressionTable: LiteralCompressionTable) + extends GraphStage[FlowShape[Send, EnvelopeBuffer]] { + + val in: Inlet[Send] = Inlet("Artery.Encoder.in") + val out: Outlet[EnvelopeBuffer] = Outlet("Artery.Encoder.out") + val shape: FlowShape[Send, EnvelopeBuffer] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + + private val pool = transport.envelopePool + private val headerBuilder = HeaderBuilder(compressionTable) + headerBuilder.version = ArteryTransport.Version + 
headerBuilder.uid = transport.localAddress.uid + private val localAddress = transport.localAddress.address + private val serialization = SerializationExtension(transport.system) + + override def onPush(): Unit = { + val send = grab(in) + val envelope = pool.acquire() + + headerBuilder.recipientActorRef = send.recipient.path.toSerializationFormat + send.senderOption match { + case Some(sender) ⇒ + headerBuilder.senderActorRef = sender.path.toSerializationFormatWithAddress(localAddress) + case None ⇒ + //headerBuilder.setNoSender() + headerBuilder.senderActorRef = transport.system.deadLetters.path.toSerializationFormatWithAddress(localAddress) + } + + // FIXME: Thunk allocation + Serialization.currentTransportInformation.withValue(Serialization.Information(localAddress, transport.system)) { + MessageSerializer.serializeForArtery(serialization, send.message.asInstanceOf[AnyRef], headerBuilder, envelope) + } + + //println(s"${headerBuilder.senderActorRef} --> ${headerBuilder.recipientActorRef} ${headerBuilder.classManifest}") + + envelope.byteBuffer.flip() + push(out, envelope) + } + + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } +} + +class Decoder( + transport: ArteryTransport, + compressionTable: LiteralCompressionTable) extends GraphStage[FlowShape[EnvelopeBuffer, InboundEnvelope]] { + val in: Inlet[EnvelopeBuffer] = Inlet("Artery.Decoder.in") + val out: Outlet[InboundEnvelope] = Outlet("Artery.Decoder.out") + val shape: FlowShape[EnvelopeBuffer, InboundEnvelope] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private val pool = transport.envelopePool + private val localAddress = transport.localAddress.address + private val provider = transport.provider + private val headerBuilder = HeaderBuilder(compressionTable) + + override def onPush(): Unit = { + val envelope = grab(in) + envelope.parseHeader(headerBuilder) + + //println(s"${headerBuilder.recipientActorRef} <-- ${headerBuilder.senderActorRef} ${headerBuilder.classManifest}") + + // FIXME: Instead of using Strings, the headerBuilder should automatically return cached ActorRef instances + // in case of compression is enabled + // FIXME: Is localAddress really needed? + val recipient: InternalActorRef = + provider.resolveActorRefWithLocalAddress(headerBuilder.recipientActorRef, localAddress) + val sender: ActorRef = + provider.resolveActorRefWithLocalAddress(headerBuilder.senderActorRef, localAddress) + + val decoded = InboundEnvelope( + recipient, + localAddress, // FIXME: Is this needed anymore? What should we do here? 
+ MessageSerializer.deserializeForArtery(transport.system, headerBuilder, envelope), + Some(sender), // FIXME: No need for an option, decode simply to deadLetters instead + UniqueAddress(sender.path.address, headerBuilder.uid)) + + pool.release(envelope) + push(out, decoded) + } + + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/artery/Compression.scala b/akka-remote/src/main/scala/akka/remote/artery/Compression.scala new file mode 100644 index 0000000000..364a7a5c9d --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/Compression.scala @@ -0,0 +1,23 @@ +package akka.remote.artery + +import akka.actor.ActorSystem + +/** + * INTERNAL API + */ +// FIXME: Dummy compression table, needs to be replaced by the real deal +// Currently disables all compression +private[remote] class Compression(system: ActorSystem) extends LiteralCompressionTable { + // FIXME: Of course it is foolish to store this as String, but this is a stub + val deadLettersString = system.deadLetters.path.toSerializationFormat + + override def compressActorRef(ref: String): Int = -1 + override def decompressActorRef(idx: Int): String = ??? + + override def compressSerializer(serializer: String): Int = -1 + override def decompressSerializer(idx: Int): String = ??? + + override def compressClassManifest(manifest: String): Int = -1 + override def decompressClassManifest(idx: Int): String = ??? + +} diff --git a/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala b/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala index 834e124a5a..26fcdb904d 100644 --- a/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala +++ b/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala @@ -91,6 +91,8 @@ object AeronStreamsApp { }) lazy val reporterExecutor = Executors.newFixedThreadPool(1) + lazy val pool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumFrameSize) + def stopReporter(): Unit = { reporter.halt() reporterExecutor.shutdown() @@ -140,10 +142,10 @@ object AeronStreamsApp { if (args.length != 0 && args(0) == "echo-receiver") runEchoReceiver() - if (args(0) == "debug-receiver") + if (args.length != 0 && args(0) == "debug-receiver") runDebugReceiver() - if (args(0) == "debug-sender") + if (args.length != 0 && args(0) == "debug-sender") runDebugSender() if (args.length >= 2 && args(1) == "stats") @@ -157,20 +159,21 @@ object AeronStreamsApp { var t0 = System.nanoTime() var count = 0L var payloadSize = 0L - Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) - .map { bytes ⇒ - r.onMessage(1, bytes.length) - bytes + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner, pool)) + .map { envelope ⇒ + r.onMessage(1, envelope.byteBuffer.limit) + envelope } - .runForeach { bytes ⇒ + .runForeach { envelope ⇒ count += 1 if (count == 1) { t0 = System.nanoTime() - payloadSize = bytes.length + payloadSize = envelope.byteBuffer.limit } else if (count == throughputN) { exit(0) printTotal(throughputN, "receive", t0, payloadSize) } + pool.release(envelope) }.onFailure { case e ⇒ e.printStackTrace @@ -193,21 +196,24 @@ object AeronStreamsApp { } .map { _ ⇒ r.onMessage(1, payload.length) - payload + val envelope = pool.acquire() + envelope.byteBuffer.put(payload) + envelope.byteBuffer.flip() + envelope } - .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner, 
pool)) } def runEchoReceiver(): Unit = { // just echo back on channel2 reporterExecutor.execute(reporter) val r = reporter - Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) - .map { bytes ⇒ - r.onMessage(1, bytes.length) - bytes + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner, pool)) + .map { envelope ⇒ + r.onMessage(1, envelope.byteBuffer.limit) + envelope } - .runWith(new AeronSink(channel2, streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel2, streamId, aeron, taskRunner, pool)) } def runEchoSender(): Unit = { @@ -219,21 +225,22 @@ object AeronStreamsApp { var repeat = 3 val count = new AtomicInteger var t0 = System.nanoTime() - Source.fromGraph(new AeronSource(channel2, streamId, aeron, taskRunner)) - .map { bytes ⇒ - r.onMessage(1, bytes.length) - bytes + Source.fromGraph(new AeronSource(channel2, streamId, aeron, taskRunner, pool)) + .map { envelope ⇒ + r.onMessage(1, envelope.byteBuffer.limit) + envelope } - .runForeach { bytes ⇒ + .runForeach { envelope ⇒ val c = count.incrementAndGet() val d = System.nanoTime() - sendTimes.get(c - 1) if (c % (latencyN / 10) == 0) println(s"# receive offset $c => ${d / 1000} µs") // FIXME histogram.recordValue(d) if (c == latencyN) { - printTotal(latencyN, "ping-pong", t0, bytes.length) + printTotal(latencyN, "ping-pong", t0, envelope.byteBuffer.limit) barrier.await() // this is always the last party } + pool.release(envelope) }.onFailure { case e ⇒ e.printStackTrace @@ -252,9 +259,12 @@ object AeronStreamsApp { if (n % (latencyN / 10) == 0) println(s"# send offset $n") // FIXME sendTimes.set(n - 1, System.nanoTime()) - payload + val envelope = pool.acquire() + envelope.byteBuffer.put(payload) + envelope.byteBuffer.flip() + envelope } - .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner, pool)) barrier.await() } @@ -264,8 +274,13 @@ object AeronStreamsApp { def runDebugReceiver(): Unit = { import system.dispatcher - Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) - .map(bytes ⇒ new String(bytes, "utf-8")) + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner, pool)) + .map { envelope ⇒ + val bytes = Array.ofDim[Byte](envelope.byteBuffer.limit) + envelope.byteBuffer.get(bytes) + pool.release(envelope) + new String(bytes, "utf-8") + } .runForeach { s ⇒ println(s) }.onFailure { @@ -283,9 +298,12 @@ object AeronStreamsApp { .map { n ⇒ val s = (fill + n.toString).takeRight(4) println(s) - s.getBytes("utf-8") + val envelope = pool.acquire() + envelope.byteBuffer.put(s.getBytes("utf-8")) + envelope.byteBuffer.flip() + envelope } - .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner, pool)) } def runStats(): Unit = { diff --git a/akka-remote/src/test/scala/akka/remote/artery/EnvelopeBufferSpec.scala b/akka-remote/src/test/scala/akka/remote/artery/EnvelopeBufferSpec.scala new file mode 100644 index 0000000000..dbd118e0ff --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/artery/EnvelopeBufferSpec.scala @@ -0,0 +1,169 @@ +package akka.remote.artery + +import java.nio.{ ByteBuffer, ByteOrder } + +import akka.testkit.AkkaSpec +import akka.util.ByteString + +class EnvelopeBufferSpec extends AkkaSpec { + + object TestCompressor extends LiteralCompressionTable { + val refToIdx = Map( + "compressable0" -> 0, + "compressable1" -> 1, + "reallylongcompressablestring" -> 2) + val idxToRef = refToIdx.map(_.swap) 
+ + val serializerToIdx = Map( + "serializer0" -> 0, + "serializer1" -> 1) + val idxToSer = serializerToIdx.map(_.swap) + + val manifestToIdx = Map( + "manifest0" -> 0, + "manifest1" -> 1) + val idxToManifest = manifestToIdx.map(_.swap) + + override def compressActorRef(ref: String): Int = refToIdx.getOrElse(ref, -1) + override def decompressActorRef(idx: Int): String = idxToRef(idx) + override def compressSerializer(serializer: String): Int = serializerToIdx.getOrElse(serializer, -1) + override def decompressSerializer(idx: Int): String = idxToSer(idx) + override def compressClassManifest(manifest: String): Int = manifestToIdx.getOrElse(manifest, -1) + override def decompressClassManifest(idx: Int): String = idxToManifest(idx) + } + + "EnvelopeBuffer" must { + val headerIn = HeaderBuilder(TestCompressor) + val headerOut = HeaderBuilder(TestCompressor) + + val byteBuffer = ByteBuffer.allocate(1024).order(ByteOrder.LITTLE_ENDIAN) + val envelope = new EnvelopeBuffer(byteBuffer) + + "be able to encode and decode headers with compressed literals" in { + headerIn.version = 1 + headerIn.uid = 42 + headerIn.senderActorRef = "compressable0" + headerIn.recipientActorRef = "compressable1" + headerIn.serializer = "serializer0" + headerIn.classManifest = "manifest1" + + envelope.writeHeader(headerIn) + envelope.byteBuffer.position() should ===(EnvelopeBuffer.LiteralsSectionOffset) // Fully compressed header + + envelope.byteBuffer.flip() + envelope.parseHeader(headerOut) + + headerOut.version should ===(1) + headerOut.uid should ===(42) + headerOut.senderActorRef should ===("compressable0") + headerOut.recipientActorRef should ===("compressable1") + headerOut.serializer should ===("serializer0") + headerOut.classManifest should ===("manifest1") + } + + "be able to encode and decode headers with uncompressed literals" in { + headerIn.version = 1 + headerIn.uid = 42 + headerIn.senderActorRef = "uncompressable0" + headerIn.recipientActorRef = "uncompressable11" + headerIn.serializer = "uncompressable222" + headerIn.classManifest = "uncompressable3333" + + val expectedHeaderLength = + EnvelopeBuffer.LiteralsSectionOffset + // Constant header part + 2 + headerIn.senderActorRef.length + // Length field + literal + 2 + headerIn.recipientActorRef.length + // Length field + literal + 2 + headerIn.serializer.length + // Length field + literal + 2 + headerIn.classManifest.length // Length field + literal + + envelope.writeHeader(headerIn) + envelope.byteBuffer.position() should ===(expectedHeaderLength) + + envelope.byteBuffer.flip() + envelope.parseHeader(headerOut) + + headerOut.version should ===(1) + headerOut.uid should ===(42) + headerOut.senderActorRef should ===("uncompressable0") + headerOut.recipientActorRef should ===("uncompressable11") + headerOut.serializer should ===("uncompressable222") + headerOut.classManifest should ===("uncompressable3333") + } + + "be able to encode and decode headers with mixed literals" in { + headerIn.version = 1 + headerIn.uid = 42 + headerIn.senderActorRef = "reallylongcompressablestring" + headerIn.recipientActorRef = "uncompressable1" + headerIn.serializer = "longuncompressedserializer" + headerIn.classManifest = "manifest1" + + envelope.writeHeader(headerIn) + envelope.byteBuffer.position() should ===( + EnvelopeBuffer.LiteralsSectionOffset + + 2 + headerIn.recipientActorRef.length + + 2 + headerIn.serializer.length) + + envelope.byteBuffer.flip() + envelope.parseHeader(headerOut) + + headerOut.version should ===(1) + headerOut.uid should ===(42) + 
headerOut.senderActorRef should ===("reallylongcompressablestring") + headerOut.recipientActorRef should ===("uncompressable1") + headerOut.serializer should ===("longuncompressedserializer") + headerOut.classManifest should ===("manifest1") + + headerIn.version = 3 + headerIn.uid = Int.MinValue + headerIn.senderActorRef = "uncompressable0" + headerIn.recipientActorRef = "reallylongcompressablestring" + headerIn.serializer = "serializer0" + headerIn.classManifest = "longlonglongliteralmanifest" + + envelope.writeHeader(headerIn) + envelope.byteBuffer.position() should ===( + EnvelopeBuffer.LiteralsSectionOffset + + 2 + headerIn.senderActorRef.length + + 2 + headerIn.classManifest.length) + + envelope.byteBuffer.flip() + envelope.parseHeader(headerOut) + + headerOut.version should ===(3) + headerOut.uid should ===(Int.MinValue) + headerOut.senderActorRef should ===("uncompressable0") + headerOut.recipientActorRef should ===("reallylongcompressablestring") + headerOut.serializer should ===("serializer0") + headerOut.classManifest should ===("longlonglongliteralmanifest") + } + + "be able to encode and decode headers with mixed literals and payload" in { + val payload = ByteString("Hello Artery!") + + headerIn.version = 1 + headerIn.uid = 42 + headerIn.senderActorRef = "reallylongcompressablestring" + headerIn.recipientActorRef = "uncompressable1" + headerIn.serializer = "serializer1" + headerIn.classManifest = "manifest1" + + envelope.writeHeader(headerIn) + envelope.byteBuffer.put(payload.toByteBuffer) + envelope.byteBuffer.flip() + + envelope.parseHeader(headerOut) + + headerOut.version should ===(1) + headerOut.uid should ===(42) + headerOut.senderActorRef should ===("reallylongcompressablestring") + headerOut.recipientActorRef should ===("uncompressable1") + headerOut.serializer should ===("serializer1") + headerOut.classManifest should ===("manifest1") + + ByteString.fromByteBuffer(envelope.byteBuffer) should ===(payload) + } + + } + +}
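Reviewer note, not part of the change set: the pieces introduced above (EnvelopeBufferPool, HeaderBuilder, EnvelopeBuffer) are easiest to follow as one small round trip, much like EnvelopeBufferSpec exercises them. The sketch below is a minimal, self-contained illustration under two assumptions: it lives in package akka.remote.artery because the classes are private[remote], and it uses a hypothetical NoCompression table that never compresses, mirroring the behaviour of the Compression stub added in this patch.

package akka.remote.artery

import java.nio.charset.StandardCharsets

object EnvelopeRoundTripSketch {

  // Hypothetical pass-through table: never compresses, so decompression is never reached.
  object NoCompression extends LiteralCompressionTable {
    override def compressActorRef(ref: String): Int = -1
    override def decompressActorRef(idx: Int): String = throw new UnsupportedOperationException
    override def compressSerializer(serializer: String): Int = -1
    override def decompressSerializer(idx: Int): String = throw new UnsupportedOperationException
    override def compressClassManifest(manifest: String): Int = -1
    override def decompressClassManifest(idx: Int): String = throw new UnsupportedOperationException
  }

  def main(args: Array[String]): Unit = {
    val pool = new EnvelopeBufferPool(ArteryTransport.MaximumFrameSize, ArteryTransport.MaximumPooledBuffers)
    val envelope = pool.acquire() // cleared, little-endian, direct ByteBuffer

    // Write side: header first, then payload, then flip before handing the buffer on.
    val headerIn = HeaderBuilder(NoCompression)
    headerIn.version = ArteryTransport.Version
    headerIn.uid = 42
    headerIn.senderActorRef = "akka://Sys/user/sender"
    headerIn.recipientActorRef = "akka://Sys/user/recipient"
    headerIn.serializer = "1" // serializer id carried as a String in this revision (see the FQCN FIXME)
    headerIn.classManifest = "ExampleManifest"
    envelope.writeHeader(headerIn)
    envelope.byteBuffer.put("hello".getBytes(StandardCharsets.US_ASCII))
    envelope.byteBuffer.flip()

    // Read side: parseHeader leaves the buffer positioned at the start of the payload.
    val headerOut = HeaderBuilder(NoCompression)
    envelope.parseHeader(headerOut)
    val payload = Array.ofDim[Byte](envelope.byteBuffer.remaining())
    envelope.byteBuffer.get(payload)
    println(s"${headerOut.senderActorRef} -> ${headerOut.recipientActorRef}: " +
      new String(payload, StandardCharsets.US_ASCII))

    // Return the buffer; if the pool is already full it is force-dropped via its Cleaner.
    pool.release(envelope)
  }
}

Ownership of pooled buffers follows the same pattern throughout the patch: AeronSource acquires an envelope per assembled fragment and pushes it downstream, so whoever consumes it (the Decoder, or the runForeach in the tests) must call pool.release; AeronSink releases its in-flight envelope itself once the Aeron offer succeeds, which is why the benchmark senders only acquire and never release.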