diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/RemoteQuarantinePiercingSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/RemoteQuarantinePiercingSpec.scala new file mode 100644 index 0000000000..92a8720d85 --- /dev/null +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/RemoteQuarantinePiercingSpec.scala @@ -0,0 +1,119 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ +package akka.remote + +import language.postfixOps +import scala.concurrent.duration._ +import com.typesafe.config.ConfigFactory +import akka.actor._ +import akka.remote.testconductor.RoleName +import akka.remote.transport.ThrottlerTransportAdapter.{ ForceDisassociate, Direction } +import akka.remote.testkit.MultiNodeConfig +import akka.remote.testkit.MultiNodeSpec +import akka.remote.testkit.STMultiNodeSpec +import akka.testkit._ +import akka.actor.ActorIdentity +import akka.remote.testconductor.RoleName +import akka.actor.Identify +import scala.concurrent.Await + +object RemoteQuarantinePiercingSpec extends MultiNodeConfig { + val first = role("first") + val second = role("second") + + commonConfig(debugConfig(on = false).withFallback( + ConfigFactory.parseString(""" + akka.loglevel = INFO + akka.remote.log-remote-lifecycle-events = INFO + akka.remote.quarantine-systems-for = 1 d + akka.remote.gate-invalid-addresses-for = 0.5 s + """))) + + class Subject extends Actor { + def receive = { + case "shutdown" ⇒ context.system.shutdown() + case "identify" ⇒ sender ! (AddressUidExtension(context.system).addressUid, self) + } + } + +} + +class RemoteQuarantinePiercingMultiJvmNode1 extends RemoteQuarantinePiercingSpec +class RemoteQuarantinePiercingMultiJvmNode2 extends RemoteQuarantinePiercingSpec + +abstract class RemoteQuarantinePiercingSpec extends MultiNodeSpec(RemoteQuarantinePiercingSpec) + with STMultiNodeSpec + with ImplicitSender { + + import RemoteQuarantinePiercingSpec._ + + override def initialParticipants = roles.size + + def identify(role: RoleName, actorName: String): (Int, ActorRef) = { + system.actorSelection(node(role) / "user" / actorName) ! "identify" + expectMsgType[(Int, ActorRef)] + } + + "RemoteNodeShutdownAndComesBack" must { + + "allow piercing through the quarantine when remote UID is new" taggedAs LongRunningTest in { + runOn(first) { + val secondAddress = node(second).address + enterBarrier("actors-started") + + // Acquire ActorRef from first system + val (uidFirst, subjectFirst) = identify(second, "subject") + enterBarrier("actor-identified") + + // Manually Quarantine the other system + RARP(system).provider.transport.quarantine(node(second).address, uidFirst) + + // Quarantine is up -- Cannot communicate with remote system any more + system.actorSelection(RootActorPath(secondAddress) / "user" / "subject") ! "identify" + expectNoMsg(2.seconds) + + // Shut down the other system -- which results in restart (see runOn(second)) + Await.result(testConductor.shutdown(second), 30.seconds) + + // Now wait until second system becomes alive again + within(30.seconds) { + // retry because the Subject actor might not be started yet + awaitAssert { + system.actorSelection(RootActorPath(secondAddress) / "user" / "subject") ! "identify" + val (uidSecond, subjectSecond) = expectMsgType[(Int, ActorRef)](1.second) + uidSecond must not be (uidFirst) + subjectSecond must not be (subjectFirst) + } + } + + // If we got here the Quarantine was successfully pierced since it is configured to last 1 day + + system.actorSelection(RootActorPath(secondAddress) / "user" / "subject") ! "shutdown" + + } + + runOn(second) { + val addr = system.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + system.actorOf(Props[Subject], "subject") + enterBarrier("actors-started") + + enterBarrier("actor-identified") + + system.awaitTermination(30.seconds) + + val freshSystem = ActorSystem(system.name, ConfigFactory.parseString(s""" + akka.remote.netty.tcp { + hostname = ${addr.host.get} + port = ${addr.port.get} + } + """).withFallback(system.settings.config)) + freshSystem.actorOf(Props[Subject], "subject") + + freshSystem.awaitTermination(30.seconds) + } + + } + + } +} diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala index 3d76ad36d3..013a9c84b0 100644 --- a/akka-remote/src/main/scala/akka/remote/Endpoint.scala +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -167,11 +167,12 @@ private[remote] object ReliableDeliverySupervisor { handleOrActive: Option[AkkaProtocolHandle], localAddress: Address, remoteAddress: Address, - transport: Transport, + refuseUid: Option[Int], + transport: AkkaProtocolTransport, settings: RemoteSettings, codec: AkkaPduCodec, receiveBuffers: ConcurrentHashMap[Link, ResendState]): Props = - Props(classOf[ReliableDeliverySupervisor], handleOrActive, localAddress, remoteAddress, transport, settings, + Props(classOf[ReliableDeliverySupervisor], handleOrActive, localAddress, remoteAddress, refuseUid, transport, settings, codec, receiveBuffers) } @@ -182,7 +183,8 @@ private[remote] class ReliableDeliverySupervisor( handleOrActive: Option[AkkaProtocolHandle], val localAddress: Address, val remoteAddress: Address, - val transport: Transport, + val refuseUid: Option[Int], + val transport: AkkaProtocolTransport, val settings: RemoteSettings, val codec: AkkaPduCodec, val receiveBuffers: ConcurrentHashMap[Link, ResendState]) extends Actor { @@ -386,6 +388,7 @@ private[remote] class ReliableDeliverySupervisor( handleOrActive = currentHandle, localAddress = localAddress, remoteAddress = remoteAddress, + refuseUid, transport = transport, settings = settings, AkkaPduProtobufCodec, @@ -427,12 +430,13 @@ private[remote] object EndpointWriter { handleOrActive: Option[AkkaProtocolHandle], localAddress: Address, remoteAddress: Address, - transport: Transport, + refuseUid: Option[Int], + transport: AkkaProtocolTransport, settings: RemoteSettings, codec: AkkaPduCodec, receiveBuffers: ConcurrentHashMap[Link, ResendState], reliableDeliverySupervisor: Option[ActorRef]): Props = - Props(classOf[EndpointWriter], handleOrActive, localAddress, remoteAddress, transport, settings, codec, + Props(classOf[EndpointWriter], handleOrActive, localAddress, remoteAddress, refuseUid, transport, settings, codec, receiveBuffers, reliableDeliverySupervisor) /** @@ -470,7 +474,8 @@ private[remote] class EndpointWriter( handleOrActive: Option[AkkaProtocolHandle], localAddress: Address, remoteAddress: Address, - transport: Transport, + refuseUid: Option[Int], + transport: AkkaProtocolTransport, settings: RemoteSettings, codec: AkkaPduCodec, val receiveBuffers: ConcurrentHashMap[Link, ResendState], @@ -532,7 +537,7 @@ private[remote] class EndpointWriter( reader = startReadEndpoint(h) Writing case None ⇒ - transport.associate(remoteAddress).mapTo[AkkaProtocolHandle].map(Handle(_)) pipeTo self + transport.associate(remoteAddress, refuseUid).map(Handle(_)) pipeTo self Initializing }, stateData = ()) diff --git a/akka-remote/src/main/scala/akka/remote/Remoting.scala b/akka-remote/src/main/scala/akka/remote/Remoting.scala index fb194e7a47..89ec8b6380 100644 --- a/akka-remote/src/main/scala/akka/remote/Remoting.scala +++ b/akka-remote/src/main/scala/akka/remote/Remoting.scala @@ -6,13 +6,11 @@ package akka.remote import akka.actor.SupervisorStrategy._ import akka.actor._ import akka.event.{ Logging, LoggingAdapter } -import akka.japi.Util.immutableSeq -import akka.pattern.{ AskTimeoutException, gracefulStop, pipe, ask } +import akka.pattern.{ gracefulStop, pipe, ask } import akka.remote.EndpointManager._ import akka.remote.Remoting.TransportSupervisor import akka.remote.transport.Transport.{ ActorAssociationEventListener, AssociationEventListener, InboundAssociation } import akka.remote.transport._ -import akka.util.Timeout import com.typesafe.config.Config import java.net.URLEncoder import java.util.concurrent.TimeoutException @@ -55,7 +53,7 @@ private[remote] object Remoting { final val EndpointManagerName = "endpointManager" - def localAddressForRemote(transportMapping: Map[String, Set[(Transport, Address)]], remote: Address): Address = { + def localAddressForRemote(transportMapping: Map[String, Set[(AkkaProtocolTransport, Address)]], remote: Address): Address = { transportMapping.get(remote.protocol) match { case Some(transports) ⇒ @@ -106,7 +104,7 @@ private[remote] object Remoting { private[remote] class Remoting(_system: ExtendedActorSystem, _provider: RemoteActorRefProvider) extends RemoteTransport(_system, _provider) { @volatile private var endpointManager: Option[ActorRef] = None - @volatile private var transportMapping: Map[String, Set[(Transport, Address)]] = _ + @volatile private var transportMapping: Map[String, Set[(AkkaProtocolTransport, Address)]] = _ // This is effectively a write-once variable similar to a lazy val. The reason for not using a lazy val is exception // handling. @volatile var addresses: Set[Address] = _ @@ -167,10 +165,10 @@ private[remote] class Remoting(_system: ExtendedActorSystem, _provider: RemoteAc endpointManager = Some(manager) try { - val addressesPromise: Promise[Seq[(Transport, Address)]] = Promise() + val addressesPromise: Promise[Seq[(AkkaProtocolTransport, Address)]] = Promise() manager ! Listen(addressesPromise) - val transports: Seq[(Transport, Address)] = Await.result(addressesPromise.future, + val transports: Seq[(AkkaProtocolTransport, Address)] = Await.result(addressesPromise.future, StartupTimeout.duration) if (transports.isEmpty) throw new RemoteTransportException("No transport drivers were loaded.", null) @@ -234,7 +232,7 @@ private[remote] object EndpointManager { // Messages between Remoting and EndpointManager sealed trait RemotingCommand extends NoSerializationVerificationNeeded - case class Listen(addressesPromise: Promise[Seq[(Transport, Address)]]) extends RemotingCommand + case class Listen(addressesPromise: Promise[Seq[(AkkaProtocolTransport, Address)]]) extends RemotingCommand case object StartupFinished extends RemotingCommand case object ShutdownAndFlush extends RemotingCommand case class Send(message: Any, senderOption: Option[ActorRef], recipient: RemoteActorRef, seqOpt: Option[SeqNo] = None) @@ -251,10 +249,10 @@ private[remote] object EndpointManager { // Messages internal to EndpointManager case object Prune extends NoSerializationVerificationNeeded - case class ListensResult(addressesPromise: Promise[Seq[(Transport, Address)]], - results: Seq[(Transport, Address, Promise[AssociationEventListener])]) + case class ListensResult(addressesPromise: Promise[Seq[(AkkaProtocolTransport, Address)]], + results: Seq[(AkkaProtocolTransport, Address, Promise[AssociationEventListener])]) extends NoSerializationVerificationNeeded - case class ListensFailure(addressesPromise: Promise[Seq[(Transport, Address)]], cause: Throwable) + case class ListensFailure(addressesPromise: Promise[Seq[(AkkaProtocolTransport, Address)]], cause: Throwable) extends NoSerializationVerificationNeeded // Helper class to store address pairs @@ -382,7 +380,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends // will be not part of this map! val endpoints = new EndpointRegistry // Mapping between transports and the local addresses they listen to - var transportMapping: Map[Address, Transport] = Map() + var transportMapping: Map[Address, AkkaProtocolTransport] = Map() def retryGateEnabled = settings.RetryGateClosedFor > Duration.Zero val pruneInterval: FiniteDuration = if (retryGateEnabled) settings.RetryGateClosedFor * 2 else Duration.Zero @@ -394,10 +392,10 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends override val supervisorStrategy = OneForOneStrategy(loggingEnabled = false) { - case InvalidAssociation(localAddress, remoteAddress, _) ⇒ + case e @ InvalidAssociation(localAddress, remoteAddress, reason) ⇒ log.warning("Tried to associate with unreachable remote address [{}]. " + - "Address is now gated for {} ms, all messages to this address will be delivered to dead letters.", - remoteAddress, settings.UnknownAddressGateClosedFor.toMillis) + "Address is now gated for {} ms, all messages to this address will be delivered to dead letters. Reason: {}", + remoteAddress, settings.UnknownAddressGateClosedFor.toMillis, reason.getMessage) endpoints.markAsFailed(sender, Deadline.now + settings.UnknownAddressGateClosedFor) context.system.eventStream.publish(AddressTerminated(remoteAddress)) Stop @@ -468,7 +466,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends case ia: InboundAssociation ⇒ context.system.scheduler.scheduleOnce(10.milliseconds, self, ia) case ManagementCommand(_) ⇒ - sender ! ManagementCommandAck(false) + sender ! ManagementCommandAck(status = false) case StartupFinished ⇒ context.become(accepting) case ShutdownAndFlush ⇒ @@ -504,25 +502,30 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends case s @ Send(message, senderOption, recipientRef, _) ⇒ val recipientAddress = recipientRef.path.address - def createAndRegisterWritingEndpoint(): ActorRef = endpoints.registerWritableEndpoint(recipientAddress, createEndpoint( - recipientAddress, - recipientRef.localAddressToUse, - transportMapping(recipientRef.localAddressToUse), - settings, - handleOption = None, - writing = true)) + def createAndRegisterWritingEndpoint(refuseUid: Option[Int]): ActorRef = + endpoints.registerWritableEndpoint( + recipientAddress, + createEndpoint( + recipientAddress, + recipientRef.localAddressToUse, + transportMapping(recipientRef.localAddressToUse), + settings, + handleOption = None, + writing = true, + refuseUid)) endpoints.writableEndpointWithPolicyFor(recipientAddress) match { case Some(Pass(endpoint)) ⇒ endpoint ! s case Some(Gated(timeOfRelease)) ⇒ - if (timeOfRelease.isOverdue()) createAndRegisterWritingEndpoint() ! s - else extendedSystem.deadLetters ! s - case Some(Quarantined(uid, timeOfRelease)) ⇒ - if (timeOfRelease.isOverdue()) createAndRegisterWritingEndpoint() ! s + if (timeOfRelease.isOverdue()) createAndRegisterWritingEndpoint(refuseUid = None) ! s else extendedSystem.deadLetters ! s + case Some(Quarantined(uid, _)) ⇒ + // timeOfRelease is only used for garbage collection reasons, therefore it is ignored here. We still have + // the Quarantined tombstone and we know what UID we don't want to accept, so use it. + createAndRegisterWritingEndpoint(refuseUid = Some(uid)) ! s case None ⇒ - createAndRegisterWritingEndpoint() ! s + createAndRegisterWritingEndpoint(refuseUid = None) ! s } @@ -541,14 +544,15 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends ep ! EndpointWriter.StopReading(ep) case _ ⇒ val writing = settings.UsePassiveConnections && !endpoints.hasWritableEndpointFor(handle.remoteAddress) - eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, inbound = true)) val endpoint = createEndpoint( handle.remoteAddress, handle.localAddress, transportMapping(handle.localAddress), settings, Some(handle), - writing) + writing, + refuseUid = None) if (writing) endpoints.registerWritableEndpoint(handle.remoteAddress, endpoint) else { @@ -576,7 +580,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends // Shutdown all endpoints and signal to sender when ready (and whether all endpoints were shut down gracefully) def shutdownAll[T](resources: TraversableOnce[T])(shutdown: T ⇒ Future[Boolean]): Future[Boolean] = { - (Future sequence resources.map(shutdown(_))) map { _.foldLeft(true) { _ && _ } } recover { + (Future sequence resources.map(shutdown)) map { _.forall(identity) } recover { case NonFatal(_) ⇒ false } } @@ -600,7 +604,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends case Terminated(_) ⇒ // why should we care now? } - private def listens: Future[Seq[(Transport, Address, Promise[AssociationEventListener])]] = { + private def listens: Future[Seq[(AkkaProtocolTransport, Address, Promise[AssociationEventListener])]] = { /* * Constructs chains of adapters on top of each driver as given in configuration. The resulting structure looks * like the following: @@ -619,9 +623,9 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends .createInstanceFor[Transport](fqn, args).recover({ case exception ⇒ throw new IllegalArgumentException( - (s"Cannot instantiate transport [$fqn]. " + + s"Cannot instantiate transport [$fqn]. " + "Make sure it extends [akka.remote.transport.Transport] and has constructor with " + - "[akka.actor.ExtendedActorSystem] and [com.typesafe.config.Config] parameters"), exception) + "[akka.actor.ExtendedActorSystem] and [com.typesafe.config.Config] parameters", exception) }).get @@ -629,7 +633,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends // The chain at this point: // Adapter <- ... <- Adapter <- Driver val wrappedTransport = - adapters.map { TransportAdaptersExtension.get(context.system).getAdapterProvider(_) }.foldLeft(driver) { + adapters.map { TransportAdaptersExtension.get(context.system).getAdapterProvider }.foldLeft(driver) { (t: Transport, provider: TransportAdapterProvider) ⇒ // The TransportAdapterProvider will wrap the given Transport and returns with a wrapped one provider.create(t, context.system.asInstanceOf[ExtendedActorSystem]) @@ -651,14 +655,15 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends if (pendingReadHandoffs.contains(takingOverFrom)) { val handle = pendingReadHandoffs(takingOverFrom) pendingReadHandoffs -= takingOverFrom - eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, inbound = true)) val endpoint = createEndpoint( handle.remoteAddress, handle.localAddress, transportMapping(handle.localAddress), settings, Some(handle), - writing = false) + writing = false, + refuseUid = None) endpoints.registerReadOnlyEndpoint(handle.remoteAddress, endpoint) } } @@ -670,16 +675,19 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends private def createEndpoint(remoteAddress: Address, localAddress: Address, - transport: Transport, + transport: AkkaProtocolTransport, endpointSettings: RemoteSettings, handleOption: Option[AkkaProtocolHandle], - writing: Boolean): ActorRef = { + writing: Boolean, + refuseUid: Option[Int]): ActorRef = { assert(transportMapping contains localAddress) + assert(writing || refuseUid.isEmpty) if (writing) context.watch(context.actorOf(RARP(extendedSystem).configureDispatcher(ReliableDeliverySupervisor.props( handleOption, localAddress, remoteAddress, + refuseUid, transport, endpointSettings, AkkaPduProtobufCodec, @@ -689,6 +697,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends handleOption, localAddress, remoteAddress, + refuseUid, transport, endpointSettings, AkkaPduProtobufCodec, diff --git a/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala b/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala index 3ef720b4ee..92fcd88e7f 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala @@ -49,6 +49,10 @@ private[remote] object AkkaProtocolTransport { //Couldn't these go into the Remo val AkkaOverhead: Int = 0 //Don't know yet val UniqueId = new java.util.concurrent.atomic.AtomicInteger(0) + case class AssociateUnderlyingRefuseUid( + remoteAddress: Address, + statusPromise: Promise[AssociationHandle], + refuseUid: Option[Int]) extends NoSerializationVerificationNeeded } case class HandshakeInfo(origin: Address, uid: Int, cookie: Option[String]) @@ -86,6 +90,15 @@ private[remote] class AkkaProtocolTransport( override def managementCommand(cmd: Any): Future[Boolean] = wrappedTransport.managementCommand(cmd) + def associate(remoteAddress: Address, refuseUid: Option[Int]): Future[AkkaProtocolHandle] = { + // Prepare a future, and pass its promise to the manager + val statusPromise: Promise[AssociationHandle] = Promise() + + manager ! AssociateUnderlyingRefuseUid(removeScheme(remoteAddress), statusPromise, refuseUid) + + statusPromise.future.mapTo[AkkaProtocolHandle] + } + override val maximumOverhead: Int = AkkaProtocolTransport.AkkaOverhead protected def managerName = s"akkaprotocolmanager.${wrappedTransport.schemeIdentifier}${UniqueId.getAndIncrement}" protected def managerProps = { @@ -115,27 +128,39 @@ private[transport] class AkkaProtocolManager( val stateActorAssociationHandler = associationListener val stateActorSettings = settings val failureDetector = createTransportFailureDetector() - context.actorOf(RARP(context.system).configureDispatcher(Props(classOf[ProtocolStateActor], + context.actorOf(RARP(context.system).configureDispatcher(ProtocolStateActor.inboundProps( HandshakeInfo(stateActorLocalAddress, AddressUidExtension(context.system).addressUid, stateActorSettings.SecureCookie), handle, stateActorAssociationHandler, stateActorSettings, AkkaPduProtobufCodec, - failureDetector)).withDeploy(Deploy.local), actorNameFor(handle.remoteAddress)) + failureDetector)), actorNameFor(handle.remoteAddress)) case AssociateUnderlying(remoteAddress, statusPromise) ⇒ - val stateActorLocalAddress = localAddress - val stateActorSettings = settings - val stateActorWrappedTransport = wrappedTransport - val failureDetector = createTransportFailureDetector() - context.actorOf(RARP(context.system).configureDispatcher(Props(classOf[ProtocolStateActor], - HandshakeInfo(stateActorLocalAddress, AddressUidExtension(context.system).addressUid, stateActorSettings.SecureCookie), - remoteAddress, - statusPromise, - stateActorWrappedTransport, - stateActorSettings, - AkkaPduProtobufCodec, - failureDetector)).withDeploy(Deploy.local), actorNameFor(remoteAddress)) + createOutboundStateActor(remoteAddress, statusPromise, None) + case AssociateUnderlyingRefuseUid(remoteAddress, statusPromise, refuseUid) ⇒ + createOutboundStateActor(remoteAddress, statusPromise, refuseUid) + + } + + private def createOutboundStateActor( + remoteAddress: Address, + statusPromise: Promise[AssociationHandle], + refuseUid: Option[Int]): Unit = { + + val stateActorLocalAddress = localAddress + val stateActorSettings = settings + val stateActorWrappedTransport = wrappedTransport + val failureDetector = createTransportFailureDetector() + context.actorOf(RARP(context.system).configureDispatcher(ProtocolStateActor.outboundProps( + HandshakeInfo(stateActorLocalAddress, AddressUidExtension(context.system).addressUid, stateActorSettings.SecureCookie), + remoteAddress, + statusPromise, + stateActorWrappedTransport, + stateActorSettings, + AkkaPduProtobufCodec, + failureDetector, + refuseUid)), actorNameFor(remoteAddress)) } private def createTransportFailureDetector(): FailureDetector = @@ -215,10 +240,34 @@ private[transport] object ProtocolStateActor { extends ProtocolStateData case object TimeoutReason + case object ForbiddenUidReason + + private[remote] def outboundProps( + handshakeInfo: HandshakeInfo, + remoteAddress: Address, + statusPromise: Promise[AssociationHandle], + transport: Transport, + settings: AkkaProtocolSettings, + codec: AkkaPduCodec, + failureDetector: FailureDetector, + refuseUid: Option[Int]): Props = + Props(classOf[ProtocolStateActor], handshakeInfo, remoteAddress, statusPromise, transport, settings, codec, + failureDetector, refuseUid).withDeploy(Deploy.local) + + private[remote] def inboundProps( + handshakeInfo: HandshakeInfo, + wrappedHandle: AssociationHandle, + associationListener: AssociationEventListener, + settings: AkkaProtocolSettings, + codec: AkkaPduCodec, + failureDetector: FailureDetector): Props = + Props(classOf[ProtocolStateActor], handshakeInfo, wrappedHandle, associationListener, settings, codec, + failureDetector).withDeploy(Deploy.local) } private[transport] class ProtocolStateActor(initialData: InitialProtocolStateData, private val localHandshakeInfo: HandshakeInfo, + private val refuseUid: Option[Int], private val settings: AkkaProtocolSettings, private val codec: AkkaPduCodec, private val failureDetector: FailureDetector) @@ -235,8 +284,9 @@ private[transport] class ProtocolStateActor(initialData: InitialProtocolStateDat transport: Transport, settings: AkkaProtocolSettings, codec: AkkaPduCodec, - failureDetector: FailureDetector) = { - this(OutboundUnassociated(remoteAddress, statusPromise, transport), handshakeInfo, settings, codec, failureDetector) + failureDetector: FailureDetector, + refuseUid: Option[Int]) = { + this(OutboundUnassociated(remoteAddress, statusPromise, transport), handshakeInfo, refuseUid, settings, codec, failureDetector) } // Inbound case @@ -246,7 +296,7 @@ private[transport] class ProtocolStateActor(initialData: InitialProtocolStateDat settings: AkkaProtocolSettings, codec: AkkaPduCodec, failureDetector: FailureDetector) = { - this(InboundUnassociated(associationListener, wrappedHandle), handshakeInfo, settings, codec, failureDetector) + this(InboundUnassociated(associationListener, wrappedHandle), handshakeInfo, refuseUid = None, settings, codec, failureDetector) } val localAddress = localHandshakeInfo.origin @@ -295,6 +345,10 @@ private[transport] class ProtocolStateActor(initialData: InitialProtocolStateDat case Event(InboundPayload(p), OutboundUnderlyingAssociated(statusPromise, wrappedHandle)) ⇒ decodePdu(p) match { + case Associate(handshakeInfo) if refuseUid.exists(_ == handshakeInfo.uid) ⇒ + sendDisassociate(wrappedHandle, Quarantined) + stop(FSM.Failure(ForbiddenUidReason)) + case Associate(handshakeInfo) ⇒ failureDetector.heartbeat() goto(Open) using AssociatedWaitHandler( @@ -427,9 +481,14 @@ private[transport] class ProtocolStateActor(initialData: InitialProtocolStateDat case StopEvent(reason, _, OutboundUnderlyingAssociated(statusPromise, wrappedHandle)) ⇒ statusPromise.tryFailure(reason match { - case FSM.Failure(TimeoutReason) ⇒ new AkkaProtocolException("No response from remote. Handshake timed out") - case FSM.Failure(info: DisassociateInfo) ⇒ disassociateException(info) - case _ ⇒ new AkkaProtocolException("Transport disassociated before handshake finished") + case FSM.Failure(TimeoutReason) ⇒ + new AkkaProtocolException("No response from remote. Handshake timed out") + case FSM.Failure(info: DisassociateInfo) ⇒ + disassociateException(info) + case FSM.Failure(ForbiddenUidReason) ⇒ + InvalidAssociationException("The remote system has a UID that has been quarantined. Association aborted.") + case _ ⇒ + new AkkaProtocolException("Transport disassociated before handshake finished") }) wrappedHandle.disassociate() @@ -468,6 +527,7 @@ private[transport] class ProtocolStateActor(initialData: InitialProtocolStateDat override protected def logTermination(reason: FSM.Reason): Unit = reason match { case FSM.Failure(TimeoutReason) ⇒ // no logging case FSM.Failure(_: DisassociateInfo) ⇒ // no logging + case FSM.Failure(ForbiddenUidReason) ⇒ // no logging case other ⇒ super.logTermination(reason) } diff --git a/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala b/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala index 9f5b175db8..d8ed939266 100644 --- a/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala @@ -124,13 +124,13 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re "register itself as reader on injecteted handles" in { val (failureDetector, _, _, handle) = collaborators - system.actorOf(Props(classOf[ProtocolStateActor], + system.actorOf(ProtocolStateActor.inboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), handle, ActorAssociationEventListener(testActor), new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector)) awaitCond(handle.readHandlerPromise.isCompleted) } @@ -138,13 +138,13 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re "in inbound mode accept payload after Associate PDU received" in { val (failureDetector, registry, _, handle) = collaborators - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.inboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), handle, ActorAssociationEventListener(testActor), new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector)) reader ! testAssociate(uid = 33, cookie = None) @@ -173,13 +173,13 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re "in inbound mode disassociate when an unexpected message arrives instead of Associate" in { val (failureDetector, registry, _, handle) = collaborators - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.inboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), handle, ActorAssociationEventListener(testActor), new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector)) // a stray message will force a disassociate reader ! testHeartbeat @@ -199,14 +199,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, 42, None)) failureDetector.called must be(true) @@ -233,13 +234,13 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re "ignore incoming associations with wrong cookie" in { val (failureDetector, registry, _, handle) = collaborators - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.inboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = Some("abcde")), handle, ActorAssociationEventListener(testActor), new AkkaProtocolSettings(ConfigFactory.parseString("akka.remote.require-cookie = on").withFallback(conf)), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector)) reader ! testAssociate(uid = 33, Some("xyzzy")) @@ -252,13 +253,13 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re "accept incoming associations with correct cookie" in { val (failureDetector, registry, _, handle) = collaborators - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.inboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = Some("abcde")), handle, ActorAssociationEventListener(testActor), new AkkaProtocolSettings(ConfigFactory.parseString("akka.remote.require-cookie = on").withFallback(conf)), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector)) // Send the correct cookie reader ! testAssociate(uid = 33, Some("abcde")) @@ -284,14 +285,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - system.actorOf(Props(classOf[ProtocolStateActor], + system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = Some("abcde")), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(ConfigFactory.parseString("akka.remote.require-cookie = on").withFallback(conf)), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, uid = 42, cookie = Some("abcde"))) } @@ -302,14 +304,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, uid = 42, cookie = None)) @@ -337,14 +340,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - val reader = system.actorOf(Props(classOf[ProtocolStateActor], + val reader = system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, uid = 42, cookie = None)) @@ -372,14 +376,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - val stateActor = system.actorOf(Props(classOf[ProtocolStateActor], + val stateActor = system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, uid = 42, cookie = None)) @@ -410,14 +415,15 @@ class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.Re val statusPromise: Promise[AssociationHandle] = Promise() - val stateActor = system.actorOf(Props(classOf[ProtocolStateActor], + val stateActor = system.actorOf(ProtocolStateActor.outboundProps( HandshakeInfo(origin = localAddress, uid = 42, cookie = None), remoteAddress, statusPromise, transport, new AkkaProtocolSettings(conf), codec, - failureDetector).withDeploy(Deploy.local)) + failureDetector, + refuseUid = None)) awaitCond(lastActivityIsAssociate(registry, uid = 42, cookie = None))