diff --git a/akka-multi-node-testkit/src/main/scala/akka/remote/testconductor/Conductor.scala b/akka-multi-node-testkit/src/main/scala/akka/remote/testconductor/Conductor.scala index 5ac5c688a4..cb1114fb12 100644 --- a/akka-multi-node-testkit/src/main/scala/akka/remote/testconductor/Conductor.scala +++ b/akka-multi-node-testkit/src/main/scala/akka/remote/testconductor/Conductor.scala @@ -122,10 +122,17 @@ trait Conductor { this: TestConductorExt ⇒ def blackhole(node: RoleName, target: RoleName, direction: Direction): Future[Done] = throttle(node, target, direction, 0f) - private def requireTestConductorTranport(): Unit = - if (!transport.defaultAddress.protocol.contains(".trttl.gremlin.")) - throw new ConfigurationException("To use this feature you must activate the failure injector adapters " + - "(trttl, gremlin) by specifying `testTransport(on = true)` in your MultiNodeConfig.") + private def requireTestConductorTranport(): Unit = { + if (transport.provider.remoteSettings.EnableArtery) { + if (!transport.provider.remoteSettings.TestMode) + throw new ConfigurationException("To use this feature you must activate the test mode " + + "by specifying `testTransport(on = true)` in your MultiNodeConfig.") + } else { + if (!transport.defaultAddress.protocol.contains(".trttl.gremlin.")) + throw new ConfigurationException("To use this feature you must activate the failure injector adapters " + + "(trttl, gremlin) by specifying `testTransport(on = true)` in your MultiNodeConfig.") + } + } /** * Switch the Netty pipeline of the remote support into pass through mode for diff --git a/akka-multi-node-testkit/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala b/akka-multi-node-testkit/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala index bd0446bce2..cad8949dfa 100644 --- a/akka-multi-node-testkit/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala +++ b/akka-multi-node-testkit/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala @@ -99,6 +99,7 @@ abstract class MultiNodeConfig { if (_testTransport) ConfigFactory.parseString( """ akka.remote.netty.tcp.applied-adapters = [trttl, gremlin] + akka.remote.artery.advanced.test-mode = on """) else ConfigFactory.empty diff --git a/akka-remote/src/main/resources/reference.conf b/akka-remote/src/main/resources/reference.conf index 59da3085ef..77463125b4 100644 --- a/akka-remote/src/main/resources/reference.conf +++ b/akka-remote/src/main/resources/reference.conf @@ -107,6 +107,9 @@ akka { large-message-destinations = [] advanced { + # For enabling testing features, such as blackhole in akka-remote-testkit. + test-mode = off + # Settings for the materializer that is used for the remote streams. materializer = ${akka.stream.materializer} materializer { diff --git a/akka-remote/src/main/scala/akka/remote/RemoteSettings.scala b/akka-remote/src/main/scala/akka/remote/RemoteSettings.scala index dcd033b248..3eb9efdb80 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteSettings.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteSettings.scala @@ -29,6 +29,7 @@ final class RemoteSettings(val config: Config) { val EmbeddedMediaDriver = getBoolean("akka.remote.artery.advanced.embedded-media-driver") val AeronDirectoryName = getString("akka.remote.artery.advanced.aeron-dir") requiring (dir ⇒ EmbeddedMediaDriver || dir.nonEmpty, "aeron-dir must be defined when using external media driver") + val TestMode: Boolean = getBoolean("akka.remote.artery.advanced.test-mode") val LogReceive: Boolean = getBoolean("akka.remote.log-received-messages") diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala index ef33bd4cbb..cda46826a3 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala @@ -4,7 +4,12 @@ package akka.remote.artery import java.io.File -import java.nio.ByteOrder +import java.net.InetSocketAddress +import java.nio.channels.DatagramChannel +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.concurrent.Promise @@ -12,24 +17,24 @@ import scala.concurrent.duration._ import scala.util.Failure import scala.util.Success import scala.util.Try + import akka.Done import akka.NotUsed import akka.actor.ActorRef import akka.actor.Address +import akka.actor.Cancellable import akka.actor.ExtendedActorSystem import akka.actor.InternalActorRef -import akka.actor.Props import akka.event.Logging import akka.event.LoggingAdapter import akka.remote.AddressUidExtension import akka.remote.EndpointManager.Send import akka.remote.EventPublisher -import akka.remote.MessageSerializer import akka.remote.RemoteActorRef import akka.remote.RemoteActorRefProvider +import akka.remote.RemoteSettings import akka.remote.RemoteTransport import akka.remote.RemotingLifecycleEvent -import akka.remote.SeqNo import akka.remote.ThisActorSystemQuarantinedEvent import akka.remote.UniqueAddress import akka.remote.artery.InboundControlJunction.ControlMessageObserver @@ -37,22 +42,22 @@ import akka.remote.artery.InboundControlJunction.ControlMessageSubject import akka.remote.artery.OutboundControlJunction.OutboundControlIngress import akka.remote.transport.AkkaPduCodec import akka.remote.transport.AkkaPduProtobufCodec -import akka.serialization.Serialization import akka.stream.AbruptTerminationException import akka.stream.ActorMaterializer +import akka.stream.ActorMaterializerSettings import akka.stream.KillSwitches import akka.stream.Materializer import akka.stream.SharedKillSwitch import akka.stream.scaladsl.Flow -import akka.stream.scaladsl.Framing import akka.stream.scaladsl.Keep import akka.stream.scaladsl.Sink import akka.stream.scaladsl.Source -import akka.util.{ ByteString, ByteStringBuilder, WildcardTree } import akka.util.Helpers.ConfigOps import akka.util.Helpers.Requiring +import akka.util.WildcardTree import io.aeron.Aeron import io.aeron.AvailableImageHandler +import io.aeron.CncFileDescriptor import io.aeron.Image import io.aeron.UnavailableImageHandler import io.aeron.driver.MediaDriver @@ -71,6 +76,7 @@ import akka.actor.Cancellable import scala.collection.JavaConverters._ import akka.stream.ActorMaterializerSettings + /** * INTERNAL API */ @@ -216,7 +222,6 @@ private[akka] trait OutboundContext { */ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: RemoteActorRefProvider) extends RemoteTransport(_system, _provider) with InboundContext { - import provider.remoteSettings import FlightRecorderEvents._ // these vars are initialized once in the start method @@ -240,6 +245,8 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R private val killSwitch: SharedKillSwitch = KillSwitches.shared("transportKillSwitch") @volatile private[this] var _shutdown = false + private val testStages: CopyOnWriteArrayList[TestManagementApi] = new CopyOnWriteArrayList + // FIXME config private val systemMessageResendInterval: FiniteDuration = 1.second private val handshakeRetryInterval: FiniteDuration = 1.second @@ -283,6 +290,8 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R private val associationRegistry = new AssociationRegistry( remoteAddress ⇒ new Association(this, materializer, remoteAddress, controlSubject, largeMessageDestinations)) + def remoteSettings: RemoteSettings = provider.remoteSettings + override def start(): Unit = { startMediaDriver() startAeron() @@ -396,13 +405,24 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } private def runInboundControlStream(): Unit = { - val (c, completed) = Source.fromGraph( - new AeronSource(inboundChannel, controlStreamId, aeron, taskRunner, envelopePool, flightRecorder.createEventSink()) - ) - .viaMat(inboundControlFlow)(Keep.right) - .toMat(Sink.ignore)(Keep.both) - .run()(materializer) - controlSubject = c + val (ctrl, completed) = + if (remoteSettings.TestMode) { + val (mgmt, (ctrl, completed)) = + aeronSource(controlStreamId, envelopePool) + .via(inboundFlow) + .viaMat(inboundTestFlow)(Keep.right) + .toMat(inboundControlSink)(Keep.both) + .run()(materializer) + testStages.add(mgmt) + (ctrl, completed) + } else { + aeronSource(controlStreamId, envelopePool) + .via(inboundFlow) + .toMat(inboundControlSink)(Keep.right) + .run()(materializer) + } + + controlSubject = ctrl controlSubject.attach(new ControlMessageObserver { override def notify(inboundEnvelope: InboundEnvelope): Unit = { @@ -435,21 +455,46 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } private def runInboundOrdinaryMessagesStream(): Unit = { - val completed = Source.fromGraph( - new AeronSource(inboundChannel, ordinaryStreamId, aeron, taskRunner, envelopePool, flightRecorder.createEventSink()) - ) - .via(inboundFlow) - .runWith(Sink.ignore)(materializer) + val completed = + if (remoteSettings.TestMode) { + val (mgmt, c) = aeronSource(ordinaryStreamId, envelopePool) + .via(inboundFlow) + .viaMat(inboundTestFlow)(Keep.right) + .toMat(inboundSink)(Keep.both) + .run()(materializer) + testStages.add(mgmt) + c + } else { + aeronSource(ordinaryStreamId, envelopePool) + .via(inboundFlow) + .toMat(inboundSink)(Keep.right) + .run()(materializer) + } attachStreamRestart("Inbound message stream", completed, () ⇒ runInboundOrdinaryMessagesStream()) } private def runInboundLargeMessagesStream(): Unit = { - val completed = Source.fromGraph( - new AeronSource(inboundChannel, largeStreamId, aeron, taskRunner, largeEnvelopePool, flightRecorder.createEventSink() - )) + val completed = + if (remoteSettings.TestMode) { + val (mgmt, c) = aeronSource(largeStreamId, largeEnvelopePool) + .via(inboundLargeFlow) + .viaMat(inboundTestFlow)(Keep.right) + .toMat(inboundSink)(Keep.both) + .run()(materializer) + testStages.add(mgmt) + c + } else { + aeronSource(largeStreamId, largeEnvelopePool) + .via(inboundLargeFlow) + .toMat(inboundSink)(Keep.right) + .run()(materializer) + } + + aeronSource(largeStreamId, largeEnvelopePool) .via(inboundLargeFlow) - .runWith(Sink.ignore)(materializer) + .toMat(inboundSink)(Keep.right) + .run()(materializer) attachStreamRestart("Inbound large message stream", completed, () ⇒ runInboundLargeMessagesStream()) } @@ -502,6 +547,17 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R private[remote] def isShutdown(): Boolean = _shutdown + override def managementCommand(cmd: Any): Future[Boolean] = { + if (testStages.isEmpty) + Future.successful(false) + else { + import scala.collection.JavaConverters._ + import system.dispatcher + val allTestStages = testStages.asScala.toVector ++ associationRegistry.allAssociations.flatMap(_.testStages) + Future.sequence(allTestStages.map(_.send(cmd))).map(_ ⇒ true) + } + } + // InboundContext override def sendControl(to: Address, message: ControlMessage) = association(to).sendControl(message) @@ -572,6 +628,10 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R def encoder: Flow[Send, EnvelopeBuffer, NotUsed] = createEncoder(envelopePool) + def aeronSource(streamId: Int, pool: EnvelopeBufferPool): Source[EnvelopeBuffer, NotUsed] = + Source.fromGraph(new AeronSource(inboundChannel, streamId, aeron, taskRunner, pool, + flightRecorder.createEventSink())) + val messageDispatcherSink: Sink[InboundEnvelope, Future[Done]] = Sink.foreach[InboundEnvelope] { m ⇒ messageDispatcher.dispatch(m.recipient, m.recipientAddress, m.message, m.senderOption) } @@ -584,33 +644,31 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R def decoder: Flow[EnvelopeBuffer, InboundEnvelope, NotUsed] = createDecoder(envelopePool) - def inboundSink: Sink[InboundEnvelope, NotUsed] = + def inboundSink: Sink[InboundEnvelope, Future[Done]] = Flow[InboundEnvelope] .via(new InboundHandshake(this, inControlStream = false)) .via(new InboundQuarantineCheck(this)) - .to(messageDispatcherSink) + .toMat(messageDispatcherSink)(Keep.right) - def inboundFlow: Flow[EnvelopeBuffer, ByteString, NotUsed] = { - Flow.fromSinkAndSource( - decoder.to(inboundSink), - Source.maybe[ByteString].via(killSwitch.flow)) + def inboundFlow: Flow[EnvelopeBuffer, InboundEnvelope, NotUsed] = { + Flow[EnvelopeBuffer] + .via(killSwitch.flow) + .via(decoder) } - def inboundLargeFlow: Flow[EnvelopeBuffer, ByteString, NotUsed] = { - Flow.fromSinkAndSource( - createDecoder(largeEnvelopePool).to(inboundSink), - Source.maybe[ByteString].via(killSwitch.flow)) + def inboundLargeFlow: Flow[EnvelopeBuffer, InboundEnvelope, NotUsed] = { + Flow[EnvelopeBuffer] + .via(killSwitch.flow) + .via(createDecoder(largeEnvelopePool)) } - def inboundControlFlow: Flow[EnvelopeBuffer, ByteString, ControlMessageSubject] = { - Flow.fromSinkAndSourceMat( - decoder - .via(new InboundHandshake(this, inControlStream = true)) - .via(new InboundQuarantineCheck(this)) - .viaMat(new InboundControlJunction)(Keep.right) - .via(new SystemMessageAcker(this)) - .to(messageDispatcherSink), - Source.maybe[ByteString].via(killSwitch.flow))((a, b) ⇒ a) + def inboundControlSink: Sink[InboundEnvelope, (ControlMessageSubject, Future[Done])] = { + Flow[InboundEnvelope] + .via(new InboundHandshake(this, inControlStream = true)) + .via(new InboundQuarantineCheck(this)) + .viaMat(new InboundControlJunction)(Keep.right) + .via(new SystemMessageAcker(this)) + .toMat(messageDispatcherSink)(Keep.both) } private def initializeFlightRecorder(): (FileChannel, File, FlightRecorder) = { @@ -622,6 +680,12 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R (fileChannel, afrFile, new FlightRecorder(fileChannel)) } + def inboundTestFlow: Flow[InboundEnvelope, InboundEnvelope, TestManagementApi] = + Flow.fromGraph(new InboundTestStage(this)) + + def outboundTestFlow(association: Association): Flow[Send, Send, TestManagementApi] = + Flow.fromGraph(new OutboundTestStage(association)) + } /** diff --git a/akka-remote/src/main/scala/akka/remote/artery/Association.scala b/akka-remote/src/main/scala/akka/remote/artery/Association.scala index 38b1980a62..fefce1f329 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/Association.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/Association.scala @@ -6,6 +6,7 @@ package akka.remote.artery import java.util.Queue import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference @@ -61,7 +62,7 @@ private[remote] class Association( import Association._ private val log = Logging(transport.system, getClass.getName) - private val controlQueueSize = transport.provider.remoteSettings.SysMsgBufferSize + private val controlQueueSize = transport.remoteSettings.SysMsgBufferSize // FIXME config queue size, and it should perhaps also be possible to use some kind of LinkedQueue // such as agrona.ManyToOneConcurrentLinkedQueue or AbstractNodeQueue for less memory consumption private val queueSize = 3072 @@ -85,6 +86,13 @@ private[remote] class Association( @volatile private[this] var _outboundControlIngress: OutboundControlIngress = _ @volatile private[this] var materializing = new CountDownLatch(1) + private val _testStages: CopyOnWriteArrayList[TestManagementApi] = new CopyOnWriteArrayList + + def testStages(): List[TestManagementApi] = { + import scala.collection.JavaConverters._ + _testStages.asScala.toList + } + def outboundControlIngress: OutboundControlIngress = { if (_outboundControlIngress ne null) _outboundControlIngress @@ -268,9 +276,22 @@ private[remote] class Association( val wrapper = getOrCreateQueueWrapper(controlQueue, queueSize) controlQueue = wrapper // use new underlying queue immediately for restarts - val (queueValue, (control, completed)) = Source.fromGraph(new SendQueue[Send]) - .toMat(transport.outboundControl(this))(Keep.both) - .run()(materializer) + + val (queueValue, (control, completed)) = + if (transport.remoteSettings.TestMode) { + val ((queueValue, mgmt), (control, completed)) = + Source.fromGraph(new SendQueue[Send]) + .viaMat(transport.outboundTestFlow(this))(Keep.both) + .toMat(transport.outboundControl(this))(Keep.both) + .run()(materializer) + _testStages.add(mgmt) + (queueValue, (control, completed)) + } else { + Source.fromGraph(new SendQueue[Send]) + .toMat(transport.outboundControl(this))(Keep.both) + .run()(materializer) + } + queueValue.inject(wrapper.queue) // replace with the materialized value, still same underlying queue controlQueue = queueValue @@ -296,21 +317,46 @@ private[remote] class Association( private def runOutboundOrdinaryMessagesStream(): Unit = { val wrapper = getOrCreateQueueWrapper(queue, queueSize) queue = wrapper // use new underlying queue immediately for restarts - val (queueValue, completed) = Source.fromGraph(new SendQueue[Send]) - .toMat(transport.outbound(this))(Keep.both) - .run()(materializer) + + val (queueValue, completed) = + if (transport.remoteSettings.TestMode) { + val ((queueValue, mgmt), completed) = Source.fromGraph(new SendQueue[Send]) + .viaMat(transport.outboundTestFlow(this))(Keep.both) + .toMat(transport.outbound(this))(Keep.both) + .run()(materializer) + _testStages.add(mgmt) + (queueValue, completed) + } else { + Source.fromGraph(new SendQueue[Send]) + .toMat(transport.outbound(this))(Keep.both) + .run()(materializer) + } + queueValue.inject(wrapper.queue) // replace with the materialized value, still same underlying queue queue = queueValue + attachStreamRestart("Outbound message stream", completed, _ ⇒ runOutboundOrdinaryMessagesStream()) } private def runOutboundLargeMessagesStream(): Unit = { val wrapper = getOrCreateQueueWrapper(queue, largeQueueSize) largeQueue = wrapper // use new underlying queue immediately for restarts - val (queueValue, completed) = Source.fromGraph(new SendQueue[Send]) - .toMat(transport.outboundLarge(this))(Keep.both) - .run()(materializer) + + val (queueValue, completed) = + if (transport.remoteSettings.TestMode) { + val ((queueValue, mgmt), completed) = Source.fromGraph(new SendQueue[Send]) + .viaMat(transport.outboundTestFlow(this))(Keep.both) + .toMat(transport.outboundLarge(this))(Keep.both) + .run()(materializer) + _testStages.add(mgmt) + (queueValue, completed) + } else { + Source.fromGraph(new SendQueue[Send]) + .toMat(transport.outboundLarge(this))(Keep.both) + .run()(materializer) + } + queueValue.inject(wrapper.queue) // replace with the materialized value, still same underlying queue largeQueue = queueValue @@ -375,4 +421,7 @@ private[remote] class AssociationRegistry(createAssociation: Address ⇒ Associa throw new IllegalArgumentException(s"UID collision old [$previous] new [$a]") a } + + def allAssociations: Set[Association] = + associationsByAddress.get.values.toSet } diff --git a/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala b/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala new file mode 100644 index 0000000000..ec70bb69cf --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala @@ -0,0 +1,187 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ +package akka.remote.artery + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.concurrent.duration._ + +import akka.Done +import akka.actor.Address +import akka.remote.EndpointManager.Send +import akka.remote.transport.ThrottlerTransportAdapter.Blackhole +import akka.remote.transport.ThrottlerTransportAdapter.Direction +import akka.remote.transport.ThrottlerTransportAdapter.SetThrottle +import akka.remote.transport.ThrottlerTransportAdapter.Unthrottled +import akka.stream.Attributes +import akka.stream.FlowShape +import akka.stream.Inlet +import akka.stream.Outlet +import akka.stream.stage.AsyncCallback +import akka.stream.stage.CallbackWrapper +import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.stage.InHandler +import akka.stream.stage.OutHandler +import akka.stream.stage.TimerGraphStageLogic + +/** + * INTERNAL API + */ +private[remote] trait TestManagementApi { + def send(command: Any)(implicit ec: ExecutionContext): Future[Done] +} + +/** + * INTERNAL API + */ +private[remote] class TestManagementApiImpl(stopped: Future[Done], callback: AsyncCallback[TestManagementMessage]) + extends TestManagementApi { + + override def send(command: Any)(implicit ec: ExecutionContext): Future[Done] = { + if (stopped.isCompleted) + Future.successful(Done) + else { + val done = Promise[Done]() + callback.invoke(TestManagementMessage(command, done)) + Future.firstCompletedOf(List(done.future, stopped)) + } + } +} + +/** + * INTERNAL API + */ +private[remote] final case class TestManagementMessage(command: Any, done: Promise[Done]) + +/** + * INTERNAL API + */ +private[remote] class OutboundTestStage(outboundContext: OutboundContext) + extends GraphStageWithMaterializedValue[FlowShape[Send, Send], TestManagementApi] { + val in: Inlet[Send] = Inlet("OutboundTestStage.in") + val out: Outlet[Send] = Outlet("OutboundTestStage.out") + override val shape: FlowShape[Send, Send] = FlowShape(in, out) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val stoppedPromise = Promise[Done]() + + // FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way + val logic = new TimerGraphStageLogic(shape) with CallbackWrapper[TestManagementMessage] with InHandler with OutHandler with StageLogging { + + private var blackhole = Set.empty[Address] + + private val callback = getAsyncCallback[TestManagementMessage] { + case TestManagementMessage(command, done) ⇒ + command match { + case SetThrottle(address, Direction.Send | Direction.Both, Blackhole) ⇒ + log.info("blackhole outbound messages to {}", address) + blackhole += address + case SetThrottle(address, Direction.Send | Direction.Both, Unthrottled) ⇒ + log.info("accept outbound messages to {}", address) + blackhole -= address + case _ ⇒ // not interested + } + done.success(Done) + } + + override def preStart(): Unit = { + initCallback(callback.invoke) + } + + override def postStop(): Unit = stoppedPromise.success(Done) + + // InHandler + override def onPush(): Unit = { + val env = grab(in) + if (blackhole(outboundContext.remoteAddress)) { + log.debug( + "dropping outbound message [{}] to [{}] because of blackhole", + env.message.getClass.getName, outboundContext.remoteAddress) + pull(in) // drop message + } else + push(out, env) + } + + // OutHandler + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } + + val managementApi: TestManagementApi = new TestManagementApiImpl(stoppedPromise.future, logic) + + (logic, managementApi) + } + +} + +/** + * INTERNAL API + */ +private[remote] class InboundTestStage(inboundContext: InboundContext) + extends GraphStageWithMaterializedValue[FlowShape[InboundEnvelope, InboundEnvelope], TestManagementApi] { + val in: Inlet[InboundEnvelope] = Inlet("InboundTestStage.in") + val out: Outlet[InboundEnvelope] = Outlet("InboundTestStage.out") + override val shape: FlowShape[InboundEnvelope, InboundEnvelope] = FlowShape(in, out) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val stoppedPromise = Promise[Done]() + + // FIXME see issue #20503 related to CallbackWrapper, we might implement this in a better way + val logic = new TimerGraphStageLogic(shape) with CallbackWrapper[TestManagementMessage] with InHandler with OutHandler with StageLogging { + + private var blackhole = Set.empty[Address] + + private val callback = getAsyncCallback[TestManagementMessage] { + case TestManagementMessage(command, done) ⇒ + command match { + case SetThrottle(address, Direction.Receive | Direction.Both, Blackhole) ⇒ + log.info("blackhole inbound messages from {}", address) + blackhole += address + case SetThrottle(address, Direction.Receive | Direction.Both, Unthrottled) ⇒ + log.info("accept inbound messages from {}", address) + blackhole -= address + case _ ⇒ // not interested + } + done.success(Done) + } + + override def preStart(): Unit = { + initCallback(callback.invoke) + } + + override def postStop(): Unit = stoppedPromise.success(Done) + + // InHandler + override def onPush(): Unit = { + val env = grab(in) + inboundContext.association(env.originUid) match { + case null ⇒ + // unknown, handshake not completed + push(out, env) + case association ⇒ + if (blackhole(association.remoteAddress)) { + log.debug( + "dropping inbound message [{}] from [{}] with UID [{}] because of blackhole", + env.message.getClass.getName, association.remoteAddress, env.originUid) + pull(in) // drop message + } else + push(out, env) + } + } + + // OutHandler + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } + + val managementApi: TestManagementApi = new TestManagementApiImpl(stoppedPromise.future, logic) + + (logic, managementApi) + } + +} +