diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala index 8e599f6376..28572c8418 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamConcistencySpec.scala @@ -55,9 +55,10 @@ abstract class AeronStreamConsistencySpec import AeronStreamConsistencySpec._ + val driver = MediaDriver.launchEmbedded() + val aeron = { val ctx = new Aeron.Context - val driver = MediaDriver.launchEmbedded() ctx.aeronDirectoryName(driver.aeronDirectoryName) Aeron.connect(ctx) } @@ -78,9 +79,12 @@ abstract class AeronStreamConsistencySpec s"aeron:udp?endpoint=${a.host.get}:${aeronPort(roleName)}" } + val streamId = 1 + override def afterAll(): Unit = { taskRunner.stop() aeron.close() + driver.close() super.afterAll() } @@ -89,8 +93,8 @@ abstract class AeronStreamConsistencySpec "start echo" in { runOn(second) { // just echo back - Source.fromGraph(new AeronSource(channel(second), aeron, taskRunner)) - .runWith(new AeronSink(channel(first), aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner)) } enterBarrier("echo-started") } @@ -103,7 +107,7 @@ abstract class AeronStreamConsistencySpec val killSwitch = KillSwitches.shared("test") val started = TestProbe() val startMsg = "0".getBytes("utf-8") - Source.fromGraph(new AeronSource(channel(first), aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner)) .via(killSwitch.flow) .runForeach { bytes ⇒ if (bytes.length == 1 && bytes(0) == startMsg(0)) @@ -124,14 +128,14 @@ abstract class AeronStreamConsistencySpec within(10.seconds) { Source(1 to 100).map(_ ⇒ startMsg) .throttle(1, 200.milliseconds, 1, ThrottleMode.Shaping) - .runWith(new AeronSink(channel(second), aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) started.expectMsg(Done) } Source(1 to totalMessages) .throttle(10000, 1.second, 1000, ThrottleMode.Shaping) .map { n ⇒ n.toString.getBytes("utf-8") } - .runWith(new AeronSink(channel(second), aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) Await.ready(done, 20.seconds) killSwitch.shutdown() diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala index 958c5d9174..d8cb7f34fb 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamLatencySpec.scala @@ -107,6 +107,8 @@ abstract class AeronStreamLatencySpec s"aeron:udp?endpoint=${a.host.get}:${aeronPort(roleName)}" } + val streamId = 1 + lazy val reporterExecutor = Executors.newFixedThreadPool(1) def reporter(name: String): TestRateReporter = { val r = new TestRateReporter(name) @@ -118,6 +120,7 @@ abstract class AeronStreamLatencySpec reporterExecutor.shutdown() taskRunner.stop() aeron.close() + driver.close() IoUtil.delete(new File(driver.aeronDirectoryName), true) runOn(first) { println(plots.plot50.csv(system.name + "50")) @@ -196,7 +199,7 @@ abstract class AeronStreamLatencySpec val killSwitch = KillSwitches.shared(testName) val started = TestProbe() val startMsg = "0".getBytes("utf-8") - Source.fromGraph(new AeronSource(channel(first), aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(first), streamId, aeron, taskRunner)) .via(killSwitch.flow) .runForeach { bytes ⇒ if (bytes.length == 1 && bytes(0) == startMsg(0)) @@ -217,7 +220,7 @@ abstract class AeronStreamLatencySpec within(10.seconds) { Source(1 to 50).map(_ ⇒ startMsg) .throttle(1, 200.milliseconds, 1, ThrottleMode.Shaping) - .runWith(new AeronSink(channel(second), aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) started.expectMsg(Done) } @@ -232,7 +235,7 @@ abstract class AeronStreamLatencySpec sendTimes.set(n - 1, System.nanoTime()) payload } - .runWith(new AeronSink(channel(second), aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) barrier.await((totalMessages / messageRate) + 10, SECONDS) } @@ -250,8 +253,8 @@ abstract class AeronStreamLatencySpec "start echo" in { runOn(second) { // just echo back - Source.fromGraph(new AeronSource(channel(second), aeron, taskRunner)) - .runWith(new AeronSink(channel(first), aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) + .runWith(new AeronSink(channel(first), streamId, aeron, taskRunner)) } enterBarrier("echo-started") } diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala index 3e68635e7f..f374461700 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/AeronStreamMaxThroughputSpec.scala @@ -113,6 +113,8 @@ abstract class AeronStreamMaxThroughputSpec s"aeron:udp?endpoint=${a.host.get}:${aeronPort(roleName)}" } + val streamId = 1 + lazy val reporterExecutor = Executors.newFixedThreadPool(1) def reporter(name: String): TestRateReporter = { val r = new TestRateReporter(name) @@ -124,6 +126,7 @@ abstract class AeronStreamMaxThroughputSpec reporterExecutor.shutdown() taskRunner.stop() aeron.close() + driver.close() runOn(second) { println(plot.csv(system.name)) } @@ -169,7 +172,7 @@ abstract class AeronStreamMaxThroughputSpec var count = 0L val done = TestLatch(1) val killSwitch = KillSwitches.shared(testName) - Source.fromGraph(new AeronSource(channel(second), aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel(second), streamId, aeron, taskRunner)) .via(killSwitch.flow) .runForeach { bytes ⇒ rep.onMessage(1, bytes.length) @@ -200,7 +203,7 @@ abstract class AeronStreamMaxThroughputSpec val t0 = System.nanoTime() Source.fromIterator(() ⇒ iterate(1, totalMessages)) .map { n ⇒ payload } - .runWith(new AeronSink(channel(second), aeron, taskRunner)) + .runWith(new AeronSink(channel(second), streamId, aeron, taskRunner)) printStats("sender") enterBarrier(testName + "-done") diff --git a/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala b/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala index 948f8cb856..f3427d7c92 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/AeronSink.scala @@ -51,7 +51,7 @@ object AeronSink { /** * @param channel eg. "aeron:udp?endpoint=localhost:40123" */ -class AeronSink(channel: String, aeron: Aeron, taskRunner: TaskRunner) extends GraphStage[SinkShape[AeronSink.Bytes]] { +class AeronSink(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner) extends GraphStage[SinkShape[AeronSink.Bytes]] { import AeronSink._ import TaskRunner._ @@ -62,7 +62,6 @@ class AeronSink(channel: String, aeron: Aeron, taskRunner: TaskRunner) extends G new GraphStageLogic(shape) with InHandler { private val buffer = new UnsafeBuffer(ByteBuffer.allocateDirect(128 * 1024)) - private val streamId = 10 private val pub = aeron.addPublication(channel, streamId) private val spinning = 1000 diff --git a/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala b/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala index 5231cc8f0f..c03aa69dea 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/AeronSource.scala @@ -61,7 +61,7 @@ object AeronSource { /** * @param channel eg. "aeron:udp?endpoint=localhost:40123" */ -class AeronSource(channel: String, aeron: Aeron, taskRunner: TaskRunner) extends GraphStage[SourceShape[AeronSource.Bytes]] { +class AeronSource(channel: String, streamId: Int, aeron: Aeron, taskRunner: TaskRunner) extends GraphStage[SourceShape[AeronSource.Bytes]] { import AeronSource._ import TaskRunner._ @@ -71,7 +71,6 @@ class AeronSource(channel: String, aeron: Aeron, taskRunner: TaskRunner) extends override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler { - private val streamId = 10 private val sub = aeron.addSubscription(channel, streamId) private val spinning = 1000 private val yielding = 0 diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArterySubsystem.scala b/akka-remote/src/main/scala/akka/remote/artery/ArterySubsystem.scala index bf950c5591..49bbd62e4e 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArterySubsystem.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArterySubsystem.scala @@ -5,7 +5,6 @@ package akka.remote.artery import java.util.concurrent.ConcurrentHashMap - import akka.actor.{ ActorRef, Address, ExtendedActorSystem } import akka.event.{ Logging, LoggingAdapter } import akka.remote.EndpointManager.Send @@ -14,9 +13,9 @@ import akka.remote.{ DefaultMessageDispatcher, RemoteActorRef, RemoteActorRefPro import akka.stream.scaladsl.{ Sink, Source, SourceQueueWithComplete, Tcp } import akka.stream.{ ActorMaterializer, Materializer, OverflowStrategy } import akka.{ Done, NotUsed } - import scala.concurrent.duration._ import scala.concurrent.{ Await, Future } +import akka.dispatch.sysmsg.SystemMessage /** * INTERNAL API @@ -49,8 +48,7 @@ private[remote] class ArterySubsystem(_system: ExtendedActorSystem, _provider: R system, materializer, provider, - AkkaPduProtobufCodec, - new DefaultMessageDispatcher(system, provider, log)) + AkkaPduProtobufCodec) transport.start() } @@ -100,13 +98,24 @@ private[akka] class Association( val materializer: Materializer, val remoteAddress: Address, val transport: Transport) { + @volatile private[this] var queue: SourceQueueWithComplete[Send] = _ - private[this] val sink: Sink[Send, Any] = transport.outbound(remoteAddress) + @volatile private[this] var systemMessageQueue: SourceQueueWithComplete[Send] = _ def send(message: Any, senderOption: Option[ActorRef], recipient: RemoteActorRef): Unit = { // TODO: lookup subchannel // FIXME: Use a different envelope than the old Send, but make sure the new is handled by deadLetters properly - queue.offer(Send(message, senderOption, recipient, None)) + message match { + case _: SystemMessage | _: SystemMessageDelivery.SystemMessageReply ⇒ + implicit val ec = materializer.executionContext + systemMessageQueue.offer(Send(message, senderOption, recipient, None)).onFailure { + case e ⇒ + // FIXME proper error handling, and quarantining + println(s"# System message dropped, due to $e") // FIXME + } + case _ ⇒ + queue.offer(Send(message, senderOption, recipient, None)) + } } def quarantine(uid: Option[Int]): Unit = () @@ -114,7 +123,11 @@ private[akka] class Association( // Idempotent def associate(): Unit = { if (queue eq null) - queue = Source.queue(256, OverflowStrategy.dropBuffer).to(sink).run()(materializer) + queue = Source.queue(256, OverflowStrategy.dropBuffer) + .to(transport.outbound(remoteAddress)).run()(materializer) + if (systemMessageQueue eq null) + systemMessageQueue = Source.queue(256, OverflowStrategy.dropBuffer) + .to(transport.outboundSystemMessage(remoteAddress)).run()(materializer) } } diff --git a/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala b/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala new file mode 100644 index 0000000000..c8e3036e10 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/SystemMessageDelivery.scala @@ -0,0 +1,310 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ +package akka.remote.artery + +import java.util.ArrayDeque + +import scala.annotation.tailrec +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.concurrent.duration._ +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +import akka.Done +import akka.actor.ActorRef +import akka.actor.Address +import akka.remote.EndpointManager.Send +import akka.remote.artery.Transport.InboundEnvelope +import akka.stream.Attributes +import akka.stream.FlowShape +import akka.stream.Inlet +import akka.stream.Outlet +import akka.stream.stage.AsyncCallback +import akka.stream.stage.GraphStage +import akka.stream.stage.GraphStageLogic +import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.stage.InHandler +import akka.stream.stage.OutHandler +import akka.stream.stage.TimerGraphStageLogic + +/** + * INTERNAL API + */ +private[akka] object SystemMessageDelivery { + // FIXME serialization of these messages + final case class SystemMessageEnvelope(message: AnyRef, seqNo: Long, ackReplyTo: ActorRef) + sealed trait SystemMessageReply + final case class Ack(seq: Long, from: Address) extends SystemMessageReply + final case class Nack(seq: Long, from: Address) extends SystemMessageReply + + private case object ResendTick +} + +/** + * INTERNAL API + */ +private[akka] class SystemMessageDelivery( + replyJunction: SystemMessageReplyJunction.Junction, + resendInterval: FiniteDuration, + localAddress: Address, + remoteAddress: Address, + ackRecipient: ActorRef) + extends GraphStage[FlowShape[Send, Send]] { + + import SystemMessageDelivery._ + import SystemMessageReplyJunction._ + + val in: Inlet[Send] = Inlet("SystemMessageDelivery.in") + val out: Outlet[Send] = Outlet("SystemMessageDelivery.out") + override val shape: FlowShape[Send, Send] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new TimerGraphStageLogic(shape) with InHandler with OutHandler { + + var registered = false + var seqNo = 0L // sequence number for the first message will be 1 + val unacknowledged = new ArrayDeque[Send] + var resending = new ArrayDeque[Send] + var resendingFromSeqNo = -1L + var stopping = false + + override def preStart(): Unit = { + this.schedulePeriodically(ResendTick, resendInterval) + def filter(env: InboundEnvelope): Boolean = + env.message match { + case Ack(_, from) if from == remoteAddress ⇒ true + case Nack(_, from) if from == remoteAddress ⇒ true + case _ ⇒ false + } + + implicit val ec = materializer.executionContext + replyJunction.addReplyInterest(filter, ackCallback).foreach { + getAsyncCallback[Done] { _ ⇒ + registered = true + if (isAvailable(out)) + pull(in) // onPull from downstream already called + }.invoke + } + + replyJunction.stopped.onComplete { + getAsyncCallback[Try[Done]] { + // FIXME quarantine + case Success(_) ⇒ completeStage() + case Failure(cause) ⇒ failStage(cause) + }.invoke + } + } + + override def postStop(): Unit = { + replyJunction.removeReplyInterest(ackCallback) + } + + override def onUpstreamFinish(): Unit = { + if (unacknowledged.isEmpty) + super.onUpstreamFinish() + else + stopping = true + } + + override protected def onTimer(timerKey: Any): Unit = + timerKey match { + case ResendTick ⇒ + if (resending.isEmpty && !unacknowledged.isEmpty) { + resending = unacknowledged.clone() + tryResend() + } + } + + val ackCallback = getAsyncCallback[SystemMessageReply] { reply ⇒ + reply match { + case Ack(n, _) ⇒ + ack(n) + case Nack(n, _) ⇒ + ack(n) + if (n > resendingFromSeqNo) + resending = unacknowledged.clone() + tryResend() + } + } + + private def ack(n: Long): Unit = { + if (n > seqNo) + throw new IllegalArgumentException(s"Unexpected ack $n, when highest sent seqNo is $seqNo") + clearUnacknowledged(n) + } + + @tailrec private def clearUnacknowledged(ackedSeqNo: Long): Unit = { + if (!unacknowledged.isEmpty && + unacknowledged.peek().message.asInstanceOf[SystemMessageEnvelope].seqNo <= ackedSeqNo) { + unacknowledged.removeFirst() + if (stopping && unacknowledged.isEmpty) + completeStage() + else + clearUnacknowledged(ackedSeqNo) + } + } + + private def tryResend(): Unit = { + if (isAvailable(out) && !resending.isEmpty) + push(out, resending.poll()) + } + + // InHandler + override def onPush(): Unit = { + grab(in) match { + case s @ Send(reply: SystemMessageReply, _, _, _) ⇒ + // pass through + if (isAvailable(out)) + push(out, s) + else { + // it's ok to drop the replies, but we can try + resending.offer(s) + } + + case s @ Send(msg: AnyRef, _, _, _) ⇒ + seqNo += 1 + val sendMsg = s.copy(message = SystemMessageEnvelope(msg, seqNo, ackRecipient)) + // FIXME quarantine if unacknowledged is full + unacknowledged.offer(sendMsg) + if (resending.isEmpty && isAvailable(out)) + push(out, sendMsg) + else { + resending.offer(sendMsg) + tryResend() + } + } + } + + // OutHandler + override def onPull(): Unit = { + if (registered) { // otherwise it will be pulled after replyJunction.addReplyInterest + if (resending.isEmpty && !hasBeenPulled(in) && !stopping) + pull(in) + else + tryResend() + } + } + + setHandlers(in, out, this) + } +} + +/** + * INTERNAL API + */ +private[akka] class SystemMessageAcker(localAddress: Address) extends GraphStage[FlowShape[InboundEnvelope, InboundEnvelope]] { + import SystemMessageDelivery._ + + val in: Inlet[InboundEnvelope] = Inlet("SystemMessageAcker.in") + val out: Outlet[InboundEnvelope] = Outlet("SystemMessageAcker.out") + override val shape: FlowShape[InboundEnvelope, InboundEnvelope] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + + var seqNo = 1L + + // InHandler + override def onPush(): Unit = { + grab(in) match { + case env @ InboundEnvelope(_, _, sysEnv @ SystemMessageEnvelope(_, n, ackReplyTo), _) ⇒ + if (n == seqNo) { + ackReplyTo.tell(Ack(n, localAddress), ActorRef.noSender) + seqNo += 1 + val unwrapped = env.copy(message = sysEnv.message) + push(out, unwrapped) + } else if (n < seqNo) { + ackReplyTo.tell(Ack(n, localAddress), ActorRef.noSender) + pull(in) + } else { + ackReplyTo.tell(Nack(seqNo - 1, localAddress), ActorRef.noSender) + pull(in) + } + case env ⇒ + // messages that don't need acking + push(out, env) + } + + } + + // OutHandler + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } +} + +/** + * INTERNAL API + */ +private[akka] object SystemMessageReplyJunction { + import SystemMessageDelivery._ + + trait Junction { + def addReplyInterest(filter: InboundEnvelope ⇒ Boolean, replyCallback: AsyncCallback[SystemMessageReply]): Future[Done] + def removeReplyInterest(callback: AsyncCallback[SystemMessageReply]): Unit + def stopped: Future[Done] + } +} + +/** + * INTERNAL API + */ +private[akka] class SystemMessageReplyJunction + extends GraphStageWithMaterializedValue[FlowShape[InboundEnvelope, InboundEnvelope], SystemMessageReplyJunction.Junction] { + import SystemMessageReplyJunction._ + import SystemMessageDelivery._ + + val in: Inlet[InboundEnvelope] = Inlet("SystemMessageReplyJunction.in") + val out: Outlet[InboundEnvelope] = Outlet("SystemMessageReplyJunction.out") + override val shape: FlowShape[InboundEnvelope, InboundEnvelope] = FlowShape(in, out) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val logic = new GraphStageLogic(shape) with InHandler with OutHandler with Junction { + + private var replyHandlers: Vector[(InboundEnvelope ⇒ Boolean, AsyncCallback[SystemMessageReply])] = Vector.empty + private val stoppedPromise = Promise[Done]() + + override def postStop(): Unit = stoppedPromise.success(Done) + + // InHandler + override def onPush(): Unit = { + grab(in) match { + case env @ InboundEnvelope(_, _, reply: SystemMessageReply, _) ⇒ + replyHandlers.foreach { + case (f, callback) ⇒ + if (f(env)) + callback.invoke(reply) + } + pull(in) + case env ⇒ + push(out, env) + } + } + + // OutHandler + override def onPull(): Unit = pull(in) + + override def addReplyInterest(filter: InboundEnvelope ⇒ Boolean, replyCallback: AsyncCallback[SystemMessageReply]): Future[Done] = { + val p = Promise[Done]() + getAsyncCallback[Unit](_ ⇒ { + replyHandlers :+= (filter -> replyCallback) + p.success(Done) + }).invoke(()) + p.future + } + + override def removeReplyInterest(callback: AsyncCallback[SystemMessageReply]): Unit = { + replyHandlers = replyHandlers.filterNot { case (_, c) ⇒ c == callback } + } + + override def stopped: Future[Done] = stoppedPromise.future + + setHandlers(in, out, this) + } + (logic, logic) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/artery/Transport.scala b/akka-remote/src/main/scala/akka/remote/artery/Transport.scala index 6dd7d9a7b4..fb9563def7 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/Transport.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/Transport.scala @@ -4,6 +4,8 @@ package akka.remote.artery +import scala.concurrent.duration._ +import akka.actor.Props import scala.concurrent.duration._ import java.net.InetSocketAddress import java.nio.ByteOrder @@ -29,6 +31,28 @@ import io.aeron.AvailableImageHandler import io.aeron.Image import io.aeron.UnavailableImageHandler import io.aeron.exceptions.ConductorServiceTimeoutException +import akka.actor.LocalRef +import akka.actor.InternalActorRef +import akka.dispatch.sysmsg.SystemMessage +import akka.actor.PossiblyHarmful +import akka.actor.RepointableRef +import akka.actor.ActorSelectionMessage +import akka.remote.RemoteRef +import akka.actor.ActorSelection +import akka.actor.ActorRef +import akka.stream.scaladsl.Keep + +/** + * INTERNAL API + */ +private[akka] object Transport { + // FIXME avoid allocating this envelope? + final case class InboundEnvelope( + recipient: InternalActorRef, + recipientAddress: Address, + message: AnyRef, + senderOption: Option[ActorRef]) +} /** * INTERNAL API @@ -39,14 +63,37 @@ private[akka] class Transport( val system: ExtendedActorSystem, val materializer: Materializer, val provider: RemoteActorRefProvider, - val codec: AkkaPduCodec, - val inboundDispatcher: InboundMessageDispatcher) { + val codec: AkkaPduCodec) { + import Transport._ private val log: LoggingAdapter = Logging(system.eventStream, getClass.getName) + private val remoteDaemon = provider.remoteDaemon private implicit val mat = materializer // TODO support port 0 private val inboundChannel = s"aeron:udp?endpoint=${localAddress.host.get}:${localAddress.port.get}" + private def outboundChannel(a: Address) = s"aeron:udp?endpoint=${a.host.get}:${a.port.get}" + private val systemMessageStreamId = 1 + private val ordinaryStreamId = 3 + + private val systemMessageResendInterval: FiniteDuration = 1.second // FIXME config + + private var systemMessageReplyJunction: SystemMessageReplyJunction.Junction = _ + + // Need an ActorRef that is passed in the `SystemMessageEnvelope.ackReplyTo`. + // Those messages are not actually handled by this actor, but intercepted by the + // SystemMessageReplyJunction stage. + private val systemMessageReplyRecepient = system.systemActorOf(Props.empty, "systemMessageReplyTo") + + private val driver = { + // TODO also support external media driver + val driverContext = new MediaDriver.Context + // FIXME settings from config + driverContext.clientLivenessTimeoutNs(SECONDS.toNanos(10)) + driverContext.imageLivenessTimeoutNs(SECONDS.toNanos(10)) + driverContext.driverTimeoutMs(SECONDS.toNanos(10)) + MediaDriver.launchEmbedded(driverContext) + } private val aeron = { val ctx = new Aeron.Context @@ -76,13 +123,6 @@ private[akka] class Transport( } } }) - // TODO also support external media driver - val driverContext = new MediaDriver.Context - // FIXME settings from config - driverContext.clientLivenessTimeoutNs(SECONDS.toNanos(10)) - driverContext.imageLivenessTimeoutNs(SECONDS.toNanos(10)) - driverContext.driverTimeoutMs(SECONDS.toNanos(10)) - val driver = MediaDriver.launchEmbedded(driverContext) ctx.aeronDirectoryName(driver.aeronDirectoryName) Aeron.connect(ctx) @@ -92,7 +132,13 @@ private[akka] class Transport( def start(): Unit = { taskRunner.start() - Source.fromGraph(new AeronSource(inboundChannel, aeron, taskRunner)) + systemMessageReplyJunction = Source.fromGraph(new AeronSource(inboundChannel, systemMessageStreamId, aeron, taskRunner)) + .async // FIXME use dedicated dispatcher for AeronSource + .map(ByteString.apply) // TODO we should use ByteString all the way + .viaMat(inboundSystemMessageFlow)(Keep.right) + .to(Sink.ignore) + .run() + Source.fromGraph(new AeronSource(inboundChannel, ordinaryStreamId, aeron, taskRunner)) .async // FIXME use dedicated dispatcher for AeronSource .map(ByteString.apply) // TODO we should use ByteString all the way .via(inboundFlow) @@ -103,17 +149,26 @@ private[akka] class Transport( // FIXME stop the AeronSource first? taskRunner.stop() aeron.close() + driver.close() Future.successful(Done) } val killSwitch: SharedKillSwitch = KillSwitches.shared("transportKillSwitch") def outbound(remoteAddress: Address): Sink[Send, Any] = { - val outboundChannel = s"aeron:udp?endpoint=${remoteAddress.host.get}:${remoteAddress.port.get}" Flow.fromGraph(killSwitch.flow[Send]) .via(encoder) .map(_.toArray) // TODO we should use ByteString all the way - .to(new AeronSink(outboundChannel, aeron, taskRunner)) + .to(new AeronSink(outboundChannel(remoteAddress), ordinaryStreamId, aeron, taskRunner)) + } + + def outboundSystemMessage(remoteAddress: Address): Sink[Send, Any] = { + Flow.fromGraph(killSwitch.flow[Send]) + .via(new SystemMessageDelivery(systemMessageReplyJunction, systemMessageResendInterval, + localAddress, remoteAddress, systemMessageReplyRecepient)) + .via(encoder) + .map(_.toArray) // TODO we should use ByteString all the way + .to(new AeronSink(outboundChannel(remoteAddress), systemMessageStreamId, aeron, taskRunner)) } // TODO: Try out parallelized serialization (mapAsync) for performance @@ -141,14 +196,86 @@ private[akka] class Transport( pdu } - val messageDispatcher: Sink[AkkaPduCodec.Message, Any] = Sink.foreach[AkkaPduCodec.Message] { m ⇒ - inboundDispatcher.dispatch(m.recipient, m.recipientAddress, m.serializedMessage, m.senderOption) + val messageDispatcher: Sink[InboundEnvelope, Future[Done]] = Sink.foreach[InboundEnvelope] { m ⇒ + dispatchInboundMessage(m.recipient, m.recipientAddress, m.message, m.senderOption) } + val deserializer: Flow[AkkaPduCodec.Message, InboundEnvelope, NotUsed] = + Flow[AkkaPduCodec.Message].map { m ⇒ + InboundEnvelope( + m.recipient, + m.recipientAddress, + MessageSerializer.deserialize(system, m.serializedMessage), + m.senderOption) + } + val inboundFlow: Flow[ByteString, ByteString, NotUsed] = { Flow.fromSinkAndSource( - decoder.to(messageDispatcher), + decoder.via(deserializer).to(messageDispatcher), Source.maybe[ByteString].via(killSwitch.flow)) } + val inboundSystemMessageFlow: Flow[ByteString, ByteString, SystemMessageReplyJunction.Junction] = { + Flow.fromSinkAndSourceMat( + decoder.via(deserializer) + .via(new SystemMessageAcker(localAddress)) + .viaMat(new SystemMessageReplyJunction)(Keep.right) + .to(messageDispatcher), + Source.maybe[ByteString].via(killSwitch.flow))((a, b) ⇒ a) + } + + private def dispatchInboundMessage(recipient: InternalActorRef, + recipientAddress: Address, + message: AnyRef, + senderOption: Option[ActorRef]): Unit = { + + import provider.remoteSettings._ + + val sender: ActorRef = senderOption.getOrElse(system.deadLetters) + val originalReceiver = recipient.path + + def msgLog = s"RemoteMessage: [$message] to [$recipient]<+[$originalReceiver] from [$sender()]" + + recipient match { + + case `remoteDaemon` ⇒ + if (UntrustedMode) log.debug("dropping daemon message in untrusted mode") + else { + if (LogReceive) log.debug("received daemon message {}", msgLog) + remoteDaemon ! message + } + + case l @ (_: LocalRef | _: RepointableRef) if l.isLocal ⇒ + if (LogReceive) log.debug("received local message {}", msgLog) + message match { + case sel: ActorSelectionMessage ⇒ + if (UntrustedMode && (!TrustedSelectionPaths.contains(sel.elements.mkString("/", "/", "")) || + sel.msg.isInstanceOf[PossiblyHarmful] || l != provider.rootGuardian)) + log.debug("operating in UntrustedMode, dropping inbound actor selection to [{}], " + + "allow it by adding the path to 'akka.remote.trusted-selection-paths' configuration", + sel.elements.mkString("/", "/", "")) + else + // run the receive logic for ActorSelectionMessage here to make sure it is not stuck on busy user actor + ActorSelection.deliverSelection(l, sender, sel) + case msg: PossiblyHarmful if UntrustedMode ⇒ + log.debug("operating in UntrustedMode, dropping inbound PossiblyHarmful message of type [{}]", msg.getClass.getName) + case msg: SystemMessage ⇒ l.sendSystemMessage(msg) + case msg ⇒ l.!(msg)(sender) + } + + case r @ (_: RemoteRef | _: RepointableRef) if !r.isLocal && !UntrustedMode ⇒ + if (LogReceive) log.debug("received remote-destined message {}", msgLog) + if (provider.transport.addresses(recipientAddress)) + // if it was originally addressed to us but is in fact remote from our point of view (i.e. remote-deployed) + r.!(message)(sender) + else + log.error("dropping message [{}] for non-local recipient [{}] arriving at [{}] inbound addresses are [{}]", + message.getClass, r, recipientAddress, provider.transport.addresses.mkString(", ")) + + case r ⇒ log.error("dropping message [{}] for unknown recipient [{}] arriving at [{}] inbound addresses are [{}]", + message.getClass, r, recipientAddress, provider.transport.addresses.mkString(", ")) + + } + } + } diff --git a/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala b/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala index ce337bdfa3..834e124a5a 100644 --- a/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala +++ b/akka-remote/src/test/scala/akka/remote/artery/AeronStreamsApp.scala @@ -1,3 +1,6 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ package akka.remote.artery import scala.concurrent.duration._ @@ -30,6 +33,7 @@ object AeronStreamsApp { val channel1 = "aeron:udp?endpoint=localhost:40123" val channel2 = "aeron:udp?endpoint=localhost:40124" + val streamId = 1 val throughputN = 10000000 val latencyRate = 10000 // per second val latencyN = 10 * latencyRate @@ -153,7 +157,7 @@ object AeronStreamsApp { var t0 = System.nanoTime() var count = 0L var payloadSize = 0L - Source.fromGraph(new AeronSource(channel1, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) .map { bytes ⇒ r.onMessage(1, bytes.length) bytes @@ -191,19 +195,19 @@ object AeronStreamsApp { r.onMessage(1, payload.length) payload } - .runWith(new AeronSink(channel1, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) } def runEchoReceiver(): Unit = { // just echo back on channel2 reporterExecutor.execute(reporter) val r = reporter - Source.fromGraph(new AeronSource(channel1, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) .map { bytes ⇒ r.onMessage(1, bytes.length) bytes } - .runWith(new AeronSink(channel2, aeron, taskRunner)) + .runWith(new AeronSink(channel2, streamId, aeron, taskRunner)) } def runEchoSender(): Unit = { @@ -215,7 +219,7 @@ object AeronStreamsApp { var repeat = 3 val count = new AtomicInteger var t0 = System.nanoTime() - Source.fromGraph(new AeronSource(channel2, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel2, streamId, aeron, taskRunner)) .map { bytes ⇒ r.onMessage(1, bytes.length) bytes @@ -250,7 +254,7 @@ object AeronStreamsApp { sendTimes.set(n - 1, System.nanoTime()) payload } - .runWith(new AeronSink(channel1, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) barrier.await() } @@ -260,7 +264,7 @@ object AeronStreamsApp { def runDebugReceiver(): Unit = { import system.dispatcher - Source.fromGraph(new AeronSource(channel1, aeron, taskRunner)) + Source.fromGraph(new AeronSource(channel1, streamId, aeron, taskRunner)) .map(bytes ⇒ new String(bytes, "utf-8")) .runForeach { s ⇒ println(s) @@ -281,7 +285,7 @@ object AeronStreamsApp { println(s) s.getBytes("utf-8") } - .runWith(new AeronSink(channel1, aeron, taskRunner)) + .runWith(new AeronSink(channel1, streamId, aeron, taskRunner)) } def runStats(): Unit = { diff --git a/akka-remote/src/test/scala/akka/remote/artery/RemoteSendConsistencySpec.scala b/akka-remote/src/test/scala/akka/remote/artery/RemoteSendConsistencySpec.scala index b6c320206c..c94e21ad64 100644 --- a/akka-remote/src/test/scala/akka/remote/artery/RemoteSendConsistencySpec.scala +++ b/akka-remote/src/test/scala/akka/remote/artery/RemoteSendConsistencySpec.scala @@ -1,3 +1,6 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ package akka.remote.artery import scala.concurrent.duration._ @@ -6,23 +9,29 @@ import akka.testkit.{ AkkaSpec, ImplicitSender } import com.typesafe.config.ConfigFactory import RemoteSendConsistencySpec._ import akka.actor.Actor.Receive +import akka.testkit.SocketUtil object RemoteSendConsistencySpec { - val commonConfig = """ + val Seq(portA, portB) = SocketUtil.temporaryServerAddresses(2, "localhost", udp = true).map(_.getPort) + + val commonConfig = ConfigFactory.parseString(s""" akka { actor.provider = "akka.remote.RemoteActorRefProvider" remote.artery.enabled = on remote.artery.hostname = localhost + remote.artery.port = $portA } - """ + """) + + val configB = ConfigFactory.parseString(s"akka.remote.artery.port = $portB") + .withFallback(commonConfig) } class RemoteSendConsistencySpec extends AkkaSpec(commonConfig) with ImplicitSender { - val configB = ConfigFactory.parseString("akka.remote.artery.port = 20201") - val systemB = ActorSystem("systemB", configB.withFallback(system.settings.config)) + val systemB = ActorSystem("systemB", RemoteSendConsistencySpec.configB) val addressB = systemB.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress println(addressB) val rootB = RootActorPath(addressB) diff --git a/akka-remote/src/test/scala/akka/remote/artery/SystemMessageDeliverySpec.scala b/akka-remote/src/test/scala/akka/remote/artery/SystemMessageDeliverySpec.scala new file mode 100644 index 0000000000..a95802d69b --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/artery/SystemMessageDeliverySpec.scala @@ -0,0 +1,282 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ +package akka.remote.artery + +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.concurrent.duration._ +import scala.concurrent.forkjoin.ThreadLocalRandom + +import akka.Done +import akka.NotUsed +import akka.actor.Actor +import akka.actor.ActorIdentity +import akka.actor.ActorRef +import akka.actor.ActorSystem +import akka.actor.ExtendedActorSystem +import akka.actor.Identify +import akka.actor.InternalActorRef +import akka.actor.PoisonPill +import akka.actor.Props +import akka.actor.RootActorPath +import akka.actor.Stash +import akka.remote.EndpointManager.Send +import akka.remote.RemoteActorRef +import akka.remote.artery.SystemMessageDelivery._ +import akka.remote.artery.Transport.InboundEnvelope +import akka.stream.ActorMaterializer +import akka.stream.ActorMaterializerSettings +import akka.stream.ThrottleMode +import akka.stream.scaladsl.Flow +import akka.stream.scaladsl.Sink +import akka.stream.scaladsl.Source +import akka.stream.stage.AsyncCallback +import akka.stream.testkit.TestSubscriber +import akka.stream.testkit.scaladsl.TestSink +import akka.testkit.AkkaSpec +import akka.testkit.ImplicitSender +import akka.testkit.SocketUtil +import akka.testkit.TestActors +import akka.testkit.TestProbe +import com.typesafe.config.ConfigFactory + +object SystemMessageDeliverySpec { + + val Seq(portA, portB) = SocketUtil.temporaryServerAddresses(2, "localhost", udp = true).map(_.getPort) + + val commonConfig = ConfigFactory.parseString(s""" + akka { + actor.provider = "akka.remote.RemoteActorRefProvider" + remote.artery.enabled = on + remote.artery.hostname = localhost + remote.artery.port = $portA + } + akka.actor.serialize-creators = off + akka.actor.serialize-messages = off + """) + + val configB = ConfigFactory.parseString(s"akka.remote.artery.port = $portB") + .withFallback(commonConfig) + + class TestReplyJunction(sendCallbackTo: ActorRef) extends SystemMessageReplyJunction.Junction { + + def addReplyInterest(filter: InboundEnvelope ⇒ Boolean, replyCallback: AsyncCallback[SystemMessageReply]): Future[Done] = { + sendCallbackTo ! replyCallback + Future.successful(Done) + } + + override def removeReplyInterest(callback: AsyncCallback[SystemMessageReply]): Unit = () + + override def stopped: Future[Done] = Promise[Done]().future + } + + def replyConnectorProps(dropRate: Double): Props = + Props(new ReplyConnector(dropRate)) + + class ReplyConnector(dropRate: Double) extends Actor with Stash { + override def receive = { + case callback: AsyncCallback[SystemMessageReply] @unchecked ⇒ + context.become(active(callback)) + unstashAll() + case _ ⇒ stash() + } + + def active(callback: AsyncCallback[SystemMessageReply]): Receive = { + case reply: SystemMessageReply ⇒ + if (ThreadLocalRandom.current().nextDouble() >= dropRate) + callback.invoke(reply) + } + } + +} + +class SystemMessageDeliverySpec extends AkkaSpec(SystemMessageDeliverySpec.commonConfig) with ImplicitSender { + import SystemMessageDeliverySpec._ + + val addressA = system.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + val systemB = ActorSystem("systemB", configB) + val addressB = systemB.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + val rootB = RootActorPath(addressB) + val matSettings = ActorMaterializerSettings(system).withFuzzing(true) + implicit val mat = ActorMaterializer(matSettings)(system) + + override def afterTermination(): Unit = shutdown(systemB) + + def setupManualCallback(ackRecipient: ActorRef, resendInterval: FiniteDuration, + dropSeqNumbers: Vector[Long], sendCount: Int): (TestSubscriber.Probe[String], AsyncCallback[SystemMessageReply]) = { + val callbackProbe = TestProbe() + val replyJunction = new TestReplyJunction(callbackProbe.ref) + + val sink = + send(sendCount, resendInterval, replyJunction, ackRecipient) + .via(drop(dropSeqNumbers)) + .via(inbound) + .map(_.message.asInstanceOf[String]) + .runWith(TestSink.probe) + + val callback = callbackProbe.expectMsgType[AsyncCallback[SystemMessageReply]] + (sink, callback) + } + + def send(sendCount: Int, resendInterval: FiniteDuration, replyJunction: SystemMessageReplyJunction.Junction, + ackRecipient: ActorRef): Source[Send, NotUsed] = { + val remoteRef = null.asInstanceOf[RemoteActorRef] // not used + Source(1 to sendCount) + .map(n ⇒ Send("msg-" + n, None, remoteRef, None)) + .via(new SystemMessageDelivery(replyJunction, resendInterval, addressA, addressB, ackRecipient)) + } + + def inbound: Flow[Send, InboundEnvelope, NotUsed] = { + val recipient = null.asInstanceOf[InternalActorRef] // not used + Flow[Send] + .map { + case Send(sysEnv: SystemMessageEnvelope, _, _, _) ⇒ + InboundEnvelope(recipient, addressB, sysEnv, None) + } + .async + .via(new SystemMessageAcker(addressB)) + } + + def drop(dropSeqNumbers: Vector[Long]): Flow[Send, Send, NotUsed] = { + Flow[Send] + .statefulMapConcat(() ⇒ { + var dropping = dropSeqNumbers + + { + case s @ Send(SystemMessageEnvelope(_, seqNo, _), _, _, _) ⇒ + val i = dropping.indexOf(seqNo) + if (i >= 0) { + dropping = dropping.updated(i, -1L) + Nil + } else + List(s) + } + }) + } + + def randomDrop[T](dropRate: Double): Flow[T, T, NotUsed] = Flow[T].mapConcat { elem ⇒ + if (ThreadLocalRandom.current().nextDouble() < dropRate) Nil + else List(elem) + } + + "System messages" must { + + "be delivered with real actors" in { + val actorOnSystemB = systemB.actorOf(TestActors.echoActorProps, "echo") + + val remoteRef = { + system.actorSelection(rootB / "user" / "echo") ! Identify(None) + expectMsgType[ActorIdentity].ref.get + } + + watch(remoteRef) + remoteRef ! PoisonPill + expectTerminated(remoteRef) + } + + "be resent when some in the middle are lost" in { + val ackRecipient = TestProbe() + val (sink, replyCallback) = + setupManualCallback(ackRecipient.ref, resendInterval = 60.seconds, dropSeqNumbers = Vector(3L, 4L), sendCount = 5) + + sink.request(100) + sink.expectNext("msg-1") + sink.expectNext("msg-2") + ackRecipient.expectMsg(Ack(1L, addressB)) + ackRecipient.expectMsg(Ack(2L, addressB)) + // 3 and 4 was dropped + ackRecipient.expectMsg(Nack(2L, addressB)) + sink.expectNoMsg(100.millis) // 3 was dropped + replyCallback.invoke(Nack(2L, addressB)) + // resending 3, 4, 5 + sink.expectNext("msg-3") + ackRecipient.expectMsg(Ack(3L, addressB)) + sink.expectNext("msg-4") + ackRecipient.expectMsg(Ack(4L, addressB)) + sink.expectNext("msg-5") + ackRecipient.expectMsg(Ack(5L, addressB)) + ackRecipient.expectNoMsg(100.millis) + replyCallback.invoke(Ack(5L, addressB)) + sink.expectComplete() + } + + "be resent when first is lost" in { + val ackRecipient = TestProbe() + val (sink, replyCallback) = + setupManualCallback(ackRecipient.ref, resendInterval = 60.seconds, dropSeqNumbers = Vector(1L), sendCount = 3) + + sink.request(100) + ackRecipient.expectMsg(Nack(0L, addressB)) // from receiving 2 + ackRecipient.expectMsg(Nack(0L, addressB)) // from receiving 3 + sink.expectNoMsg(100.millis) // 1 was dropped + replyCallback.invoke(Nack(0L, addressB)) + replyCallback.invoke(Nack(0L, addressB)) + // resending 1, 2, 3 + sink.expectNext("msg-1") + ackRecipient.expectMsg(Ack(1L, addressB)) + sink.expectNext("msg-2") + ackRecipient.expectMsg(Ack(2L, addressB)) + sink.expectNext("msg-3") + ackRecipient.expectMsg(Ack(3L, addressB)) + replyCallback.invoke(Ack(3L, addressB)) + sink.expectComplete() + } + + "be resent when last is lost" in { + val ackRecipient = TestProbe() + val (sink, replyCallback) = + setupManualCallback(ackRecipient.ref, resendInterval = 1.second, dropSeqNumbers = Vector(3L), sendCount = 3) + + sink.request(100) + sink.expectNext("msg-1") + ackRecipient.expectMsg(Ack(1L, addressB)) + replyCallback.invoke(Ack(1L, addressB)) + sink.expectNext("msg-2") + ackRecipient.expectMsg(Ack(2L, addressB)) + replyCallback.invoke(Ack(2L, addressB)) + sink.expectNoMsg(200.millis) // 3 was dropped + // resending 3 due to timeout + sink.expectNext("msg-3") + ackRecipient.expectMsg(Ack(3L, addressB)) + replyCallback.invoke(Ack(3L, addressB)) + sink.expectComplete() + } + + "deliver all during stress and random dropping" in { + val N = 10000 + val dropRate = 0.1 + val replyConnector = system.actorOf(replyConnectorProps(dropRate)) + val replyJunction = new TestReplyJunction(replyConnector) + + val output = + send(N, 1.second, replyJunction, replyConnector) + .via(randomDrop(dropRate)) + .via(inbound) + .map(_.message.asInstanceOf[String]) + .runWith(Sink.seq) + + Await.result(output, 20.seconds) should ===((1 to N).map("msg-" + _).toVector) + } + + "deliver all during throttling and random dropping" in { + val N = 500 + val dropRate = 0.1 + val replyConnector = system.actorOf(replyConnectorProps(dropRate)) + val replyJunction = new TestReplyJunction(replyConnector) + + val output = + send(N, 1.second, replyJunction, replyConnector) + .throttle(200, 1.second, 10, ThrottleMode.shaping) + .via(randomDrop(dropRate)) + .via(inbound) + .map(_.message.asInstanceOf[String]) + .runWith(Sink.seq) + + Await.result(output, 20.seconds) should ===((1 to N).map("msg-" + _).toVector) + } + + } + +}