diff --git a/akka-cluster/src/main/scala/akka/cluster/Cluster.scala b/akka-cluster/src/main/scala/akka/cluster/Cluster.scala index 13f93d0482..592b04214d 100644 --- a/akka-cluster/src/main/scala/akka/cluster/Cluster.scala +++ b/akka-cluster/src/main/scala/akka/cluster/Cluster.scala @@ -61,7 +61,7 @@ class Cluster(val system: ExtendedActorSystem) extends Extension { import settings._ val selfAddress: Address = system.provider match { - case c: ClusterActorRefProvider ⇒ c.transport.address + case c: ClusterActorRefProvider ⇒ c.transport.addresses.head // FIXME: temporary workaround. See #2663 case other ⇒ throw new ConfigurationException( "ActorSystem [%s] needs to have a 'ClusterActorRefProvider' enabled in the configuration, currently uses [%s]". format(system, other.getClass.getName)) diff --git a/akka-cluster/src/test/scala/akka/cluster/ClusterSpec.scala b/akka-cluster/src/test/scala/akka/cluster/ClusterSpec.scala index a659abf313..0cb564a69b 100644 --- a/akka-cluster/src/test/scala/akka/cluster/ClusterSpec.scala +++ b/akka-cluster/src/test/scala/akka/cluster/ClusterSpec.scala @@ -38,7 +38,8 @@ object ClusterSpec { class ClusterSpec extends AkkaSpec(ClusterSpec.config) with ImplicitSender { import ClusterSpec._ - val selfAddress = system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[ClusterActorRefProvider].transport.address + // FIXME: temporary workaround. See #2663 + val selfAddress = system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[ClusterActorRefProvider].transport.addresses.head val cluster = Cluster(system) def clusterView = cluster.readView diff --git a/akka-docs/rst/java/code/docs/serialization/SerializationDocTestBase.java b/akka-docs/rst/java/code/docs/serialization/SerializationDocTestBase.java index 7fdb6420f1..48c0d2fc62 100644 --- a/akka-docs/rst/java/code/docs/serialization/SerializationDocTestBase.java +++ b/akka-docs/rst/java/code/docs/serialization/SerializationDocTestBase.java @@ -140,7 +140,7 @@ public class SerializationDocTestBase { public Address getAddress() { final ActorRefProvider provider = system.provider(); if (provider instanceof RemoteActorRefProvider) { - return ((RemoteActorRefProvider) provider).transport().address(); + return ((RemoteActorRefProvider) provider).transport().addresses().head(); } else { throw new UnsupportedOperationException("need RemoteActorRefProvider"); } diff --git a/akka-docs/rst/scala/code/docs/serialization/SerializationDocSpec.scala b/akka-docs/rst/scala/code/docs/serialization/SerializationDocSpec.scala index d979952887..af49e48999 100644 --- a/akka-docs/rst/scala/code/docs/serialization/SerializationDocSpec.scala +++ b/akka-docs/rst/scala/code/docs/serialization/SerializationDocSpec.scala @@ -216,7 +216,7 @@ package docs.serialization { object ExternalAddress extends ExtensionKey[ExternalAddressExt] class ExternalAddressExt(system: ExtendedActorSystem) extends Extension { - def addressForAkka: Address = akka.transportOf(system).address + def addressForAkka: Address = akka.transportOf(system).addresses.head } def serializeAkkaDefault(ref: ActorRef): String = diff --git a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala index 07a6c9c22f..aea3ee8cd9 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala @@ -60,7 +60,7 @@ class TestConductorExt(val system: ExtendedActorSystem) extends Extension with C /** * Transport address of this Netty-like remote transport. */ - val address = transport.address + val address = transport.addresses.head //FIXME: Workaround for old-remoting -- must be removed later /** * INTERNAL API. diff --git a/akka-remote-tests/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala b/akka-remote-tests/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala index 9d5fd4b55e..78808f6b8c 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testkit/MultiNodeSpec.scala @@ -409,7 +409,8 @@ abstract class MultiNodeSpec(val myself: RoleName, _system: ActorSystem, _roles: // useful to see which jvm is running which role, used by LogRoleReplace utility log.info("Role [{}] started with address [{}]", myself.name, - system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.address) + //FIXME: Workaround for old-remoting -- must be removed later + system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.addresses.head) } diff --git a/akka-remote/src/main/java/akka/remote/RemoteProtocol.java b/akka-remote/src/main/java/akka/remote/RemoteProtocol.java index 204a68fca5..5d19c00a06 100644 --- a/akka-remote/src/main/java/akka/remote/RemoteProtocol.java +++ b/akka-remote/src/main/java/akka/remote/RemoteProtocol.java @@ -21,7 +21,7 @@ public final class RemoteProtocol { public final int getNumber() { return value; } - + public static CommandType valueOf(int value) { switch (value) { case 1: return CONNECT; @@ -3636,6 +3636,10 @@ public final class RemoteProtocol { // required uint32 port = 3; boolean hasPort(); int getPort(); + + // optional string protocol = 4; + boolean hasProtocol(); + String getProtocol(); } public static final class AddressProtocol extends com.google.protobuf.GeneratedMessage @@ -3740,10 +3744,43 @@ public final class RemoteProtocol { return port_; } + // optional string protocol = 4; + public static final int PROTOCOL_FIELD_NUMBER = 4; + private java.lang.Object protocol_; + public boolean hasProtocol() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + public String getProtocol() { + java.lang.Object ref = protocol_; + if (ref instanceof String) { + return (String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + String s = bs.toStringUtf8(); + if (com.google.protobuf.Internal.isValidUtf8(bs)) { + protocol_ = s; + } + return s; + } + } + private com.google.protobuf.ByteString getProtocolBytes() { + java.lang.Object ref = protocol_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8((String) ref); + protocol_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private void initFields() { system_ = ""; hostname_ = ""; port_ = 0; + protocol_ = ""; } private byte memoizedIsInitialized = -1; public final boolean isInitialized() { @@ -3778,6 +3815,9 @@ public final class RemoteProtocol { if (((bitField0_ & 0x00000004) == 0x00000004)) { output.writeUInt32(3, port_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + output.writeBytes(4, getProtocolBytes()); + } getUnknownFields().writeTo(output); } @@ -3799,6 +3839,10 @@ public final class RemoteProtocol { size += com.google.protobuf.CodedOutputStream .computeUInt32Size(3, port_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(4, getProtocolBytes()); + } size += getUnknownFields().getSerializedSize(); memoizedSerializedSize = size; return size; @@ -3929,6 +3973,8 @@ public final class RemoteProtocol { bitField0_ = (bitField0_ & ~0x00000002); port_ = 0; bitField0_ = (bitField0_ & ~0x00000004); + protocol_ = ""; + bitField0_ = (bitField0_ & ~0x00000008); return this; } @@ -3979,6 +4025,10 @@ public final class RemoteProtocol { to_bitField0_ |= 0x00000004; } result.port_ = port_; + if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + to_bitField0_ |= 0x00000008; + } + result.protocol_ = protocol_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -4004,6 +4054,9 @@ public final class RemoteProtocol { if (other.hasPort()) { setPort(other.getPort()); } + if (other.hasProtocol()) { + setProtocol(other.getProtocol()); + } this.mergeUnknownFields(other.getUnknownFields()); return this; } @@ -4062,6 +4115,11 @@ public final class RemoteProtocol { port_ = input.readUInt32(); break; } + case 34: { + bitField0_ |= 0x00000008; + protocol_ = input.readBytes(); + break; + } } } } @@ -4161,6 +4219,42 @@ public final class RemoteProtocol { return this; } + // optional string protocol = 4; + private java.lang.Object protocol_ = ""; + public boolean hasProtocol() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + public String getProtocol() { + java.lang.Object ref = protocol_; + if (!(ref instanceof String)) { + String s = ((com.google.protobuf.ByteString) ref).toStringUtf8(); + protocol_ = s; + return s; + } else { + return (String) ref; + } + } + public Builder setProtocol(String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + protocol_ = value; + onChanged(); + return this; + } + public Builder clearProtocol() { + bitField0_ = (bitField0_ & ~0x00000008); + protocol_ = getDefaultInstance().getProtocol(); + onChanged(); + return this; + } + void setProtocol(com.google.protobuf.ByteString value) { + bitField0_ |= 0x00000008; + protocol_ = value; + onChanged(); + } + // @@protoc_insertion_point(builder_scope:AddressProtocol) } @@ -6424,20 +6518,20 @@ public final class RemoteProtocol { "path\030\001 \002(\t\"Q\n\017MessageProtocol\022\017\n\007message" + "\030\001 \002(\014\022\024\n\014serializerId\030\002 \002(\005\022\027\n\017messageM" + "anifest\030\003 \001(\014\"3\n\025MetadataEntryProtocol\022\013" + - "\n\003key\030\001 \002(\t\022\r\n\005value\030\002 \002(\014\"A\n\017AddressPro" + + "\n\003key\030\001 \002(\t\022\r\n\005value\030\002 \002(\014\"S\n\017AddressPro" + "tocol\022\016\n\006system\030\001 \002(\t\022\020\n\010hostname\030\002 \002(\t\022" + - "\014\n\004port\030\003 \002(\r\"\216\001\n\027DaemonMsgCreateProtoco" + - "l\022\035\n\005props\030\001 \002(\0132\016.PropsProtocol\022\037\n\006depl" + - "oy\030\002 \002(\0132\017.DeployProtocol\022\014\n\004path\030\003 \002(\t\022" + - "%\n\nsupervisor\030\004 \002(\0132\021.ActorRefProtocol\"\205", - "\001\n\rPropsProtocol\022\022\n\ndispatcher\030\001 \002(\t\022\037\n\006" + - "deploy\030\002 \002(\0132\017.DeployProtocol\022\030\n\020fromCla" + - "ssCreator\030\003 \001(\t\022\017\n\007creator\030\004 \001(\014\022\024\n\014rout" + - "erConfig\030\005 \001(\014\"S\n\016DeployProtocol\022\014\n\004path" + - "\030\001 \002(\t\022\016\n\006config\030\002 \001(\014\022\024\n\014routerConfig\030\003" + - " \001(\014\022\r\n\005scope\030\004 \001(\014*7\n\013CommandType\022\013\n\007CO" + - "NNECT\020\001\022\014\n\010SHUTDOWN\020\002\022\r\n\tHEARTBEAT\020\003B\017\n\013" + - "akka.remoteH\001" + "\014\n\004port\030\003 \002(\r\022\020\n\010protocol\030\004 \001(\t\"\216\001\n\027Daem" + + "onMsgCreateProtocol\022\035\n\005props\030\001 \002(\0132\016.Pro" + + "psProtocol\022\037\n\006deploy\030\002 \002(\0132\017.DeployProto" + + "col\022\014\n\004path\030\003 \002(\t\022%\n\nsupervisor\030\004 \002(\0132\021.", + "ActorRefProtocol\"\205\001\n\rPropsProtocol\022\022\n\ndi" + + "spatcher\030\001 \002(\t\022\037\n\006deploy\030\002 \002(\0132\017.DeployP" + + "rotocol\022\030\n\020fromClassCreator\030\003 \001(\t\022\017\n\007cre" + + "ator\030\004 \001(\014\022\024\n\014routerConfig\030\005 \001(\014\"S\n\016Depl" + + "oyProtocol\022\014\n\004path\030\001 \002(\t\022\016\n\006config\030\002 \001(\014" + + "\022\024\n\014routerConfig\030\003 \001(\014\022\r\n\005scope\030\004 \001(\014*7\n" + + "\013CommandType\022\013\n\007CONNECT\020\001\022\014\n\010SHUTDOWN\020\002\022" + + "\r\n\tHEARTBEAT\020\003B\017\n\013akka.remoteH\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { @@ -6497,7 +6591,7 @@ public final class RemoteProtocol { internal_static_AddressProtocol_fieldAccessorTable = new com.google.protobuf.GeneratedMessage.FieldAccessorTable( internal_static_AddressProtocol_descriptor, - new java.lang.String[] { "System", "Hostname", "Port", }, + new java.lang.String[] { "System", "Hostname", "Port", "Protocol", }, akka.remote.RemoteProtocol.AddressProtocol.class, akka.remote.RemoteProtocol.AddressProtocol.Builder.class); internal_static_DaemonMsgCreateProtocol_descriptor = diff --git a/akka-remote/src/main/protocol/RemoteProtocol.proto b/akka-remote/src/main/protocol/RemoteProtocol.proto index ddcfe26d1d..438a2a1e87 100644 --- a/akka-remote/src/main/protocol/RemoteProtocol.proto +++ b/akka-remote/src/main/protocol/RemoteProtocol.proto @@ -78,6 +78,7 @@ message AddressProtocol { required string system = 1; required string hostname = 2; required uint32 port = 3; + optional string protocol = 4; } /** diff --git a/akka-remote/src/main/resources/reference.conf b/akka-remote/src/main/resources/reference.conf index a70106a8b2..2a73beb844 100644 --- a/akka-remote/src/main/resources/reference.conf +++ b/akka-remote/src/main/resources/reference.conf @@ -53,6 +53,60 @@ akka { } } + remoting { + + # FIXME document + failure-detector { + threshold = 7.0 + max-sample-size = 100 + min-std-deviation = 100 ms + acceptable-heartbeat-pause = 3 s + } + + # FIXME document + writer-dispatcher { + mailbox-type = "akka.dispatch.UnboundedDequeBasedMailbox" + } + + # If this is "on", Akka will log all RemoteLifeCycleEvents at the level + # defined for each, if off then they are not logged. Failures to deserialize + # received messages also fall under this flag. + log-remote-lifecycle-events = off + + # FIXME document + heartbeat-interval = 1 s + + # FIXME document + wait-activity-enabled = on + + # FIXME document + backoff-interval = 1 s + + # FIXME document + secure-cookie = "" + + # FIXME document + require-cookie = off + + # FIXME document + shutdown-timeout = 5 s + + # FIXME document + startup-timeout = 5 s + + # FIXME document + retry-latch-closed-for = 0 s + + # FIXME document + retry-window = 3 s + + # FIXME document + maximum-retries-in-window = 5 + + # FIXME document + use-passive-connections = on + } + remote { # Which implementation of akka.remote.RemoteTransport to use diff --git a/akka-remote/src/main/scala/akka/remote/DefaultFailureDetectorRegistry.scala b/akka-remote/src/main/scala/akka/remote/DefaultFailureDetectorRegistry.scala new file mode 100644 index 0000000000..2b3a6150e4 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/DefaultFailureDetectorRegistry.scala @@ -0,0 +1,84 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ + +package akka.remote + +import akka.event.LoggingAdapter +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.tailrec +import scala.collection.immutable.Map + +/** + * A lock-less thread-safe implementation of [[akka.remote.FailureDetectorRegistry]]. + * + * @param detectorFactory + * By-name parameter that returns the failure detector instance to be used by a newly registered resource + * + */ +class DefaultFailureDetectorRegistry[A](val detectorFactory: () ⇒ FailureDetector) extends FailureDetectorRegistry[A] { + + private val table = new AtomicReference[Map[A, FailureDetector]](Map()) + + /** + * Returns true if the resource is considered to be up and healthy and returns false otherwise. For unregistered + * resources it returns true. + */ + final override def isAvailable(resource: A): Boolean = table.get.get(resource) match { + case Some(r) ⇒ r.isAvailable + case _ ⇒ true + } + + final override def heartbeat(resource: A): Unit = { + + // Second option parameter is there to avoid the unnecessary creation of failure detectors when a CAS loop happens + // Note, _one_ unnecessary detector might be created -- but no more. + @tailrec + def doHeartbeat(resource: A, detector: Option[FailureDetector]): Unit = { + val oldTable = table.get + + oldTable.get(resource) match { + case Some(failureDetector) ⇒ failureDetector.heartbeat() + case None ⇒ + val newDetector = detector getOrElse detectorFactory() + val newTable = oldTable + (resource -> newDetector) + if (!table.compareAndSet(oldTable, newTable)) + doHeartbeat(resource, Some(newDetector)) + else + newDetector.heartbeat() + } + } + + doHeartbeat(resource, None) + } + + final override def remove(resource: A): Unit = { + + @tailrec + def doRemove(resource: A): Unit = { + val oldTable = table.get + + if (oldTable.contains(resource)) { + val newTable = oldTable - resource + + // if we won the race then update else try again + if (!table.compareAndSet(oldTable, newTable)) doRemove(resource) // recur + } + } + + doRemove(resource) + } + + final override def reset(): Unit = { + + @tailrec + def doReset(): Unit = { + val oldTable = table.get + // if we won the race then update else try again + if (!table.compareAndSet(oldTable, Map.empty[A, FailureDetector])) doReset() // recur + } + + doReset() + } +} + diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala new file mode 100644 index 0000000000..9b72507593 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -0,0 +1,246 @@ +package akka.remote + +import akka.AkkaException +import akka.actor._ +import akka.dispatch.SystemMessage +import akka.event.LoggingAdapter +import akka.pattern.pipe +import akka.remote.EndpointManager.Send +import akka.remote.RemoteProtocol.MessageProtocol +import akka.remote.transport.AkkaPduCodec._ +import akka.remote.transport.AssociationHandle._ +import akka.remote.transport.{ AkkaPduCodec, Transport, AssociationHandle } +import akka.serialization.Serialization +import akka.util.ByteString +import java.net.URLEncoder +import scala.util.control.NonFatal +import akka.actor.SupervisorStrategy.{ Restart, Stop } + +trait InboundMessageDispatcher { + def dispatch(recipient: InternalActorRef, + recipientAddress: Address, + serializedMessage: MessageProtocol, + senderOption: Option[ActorRef]): Unit +} + +class DefaultMessageDispatcher(private val system: ExtendedActorSystem, + private val provider: RemoteActorRefProvider, + private val log: LoggingAdapter) extends InboundMessageDispatcher { + + private val remoteDaemon = provider.remoteDaemon + + override def dispatch(recipient: InternalActorRef, + recipientAddress: Address, + serializedMessage: MessageProtocol, + senderOption: Option[ActorRef]): Unit = { + + import provider.remoteSettings._ + + lazy val payload: AnyRef = MessageSerializer.deserialize(system, serializedMessage) + lazy val payloadClass: Class[_] = if (payload eq null) null else payload.getClass + val sender: ActorRef = senderOption.getOrElse(system.deadLetters) + val originalReceiver = recipient.path + + lazy val msgLog = "RemoteMessage: " + payload + " to " + recipient + "<+{" + originalReceiver + "} from " + sender + + recipient match { + + case `remoteDaemon` ⇒ + if (LogReceive) log.debug("received daemon message {}", msgLog) + payload match { + case m @ (_: DaemonMsg | _: Terminated) ⇒ + try remoteDaemon ! m catch { + case NonFatal(e) ⇒ log.error(e, "exception while processing remote command {} from {}", m, sender) + } + case x ⇒ log.debug("remoteDaemon received illegal message {} from {}", x, sender) + } + + case l @ (_: LocalRef | _: RepointableRef) if l.isLocal ⇒ + if (LogReceive) log.debug("received local message {}", msgLog) + payload match { + case msg: PossiblyHarmful if UntrustedMode ⇒ + log.debug("operating in UntrustedMode, dropping inbound PossiblyHarmful message of type {}", msg.getClass) + case msg: SystemMessage ⇒ l.sendSystemMessage(msg) + case msg ⇒ l.!(msg)(sender) + } + + case r @ (_: RemoteRef | _: RepointableRef) if !r.isLocal && !UntrustedMode ⇒ + if (LogReceive) log.debug("received remote-destined message {}", msgLog) + if (provider.transport.addresses(recipientAddress)) + // if it was originally addressed to us but is in fact remote from our point of view (i.e. remote-deployed) + r.!(payload)(sender) + else + log.error("dropping message {} for non-local recipient {} arriving at {} inbound addresses are {}", + payloadClass, r, recipientAddress, provider.transport.addresses) + + case r ⇒ log.error("dropping message {} for unknown recipient {} arriving at {} inbound addresses are {}", + payloadClass, r, recipientAddress, provider.transport.addresses) + + } + } + +} + +object EndpointWriter { + + case object BackoffTimer + + sealed trait State + case object Initializing extends State + case object Buffering extends State + case object Writing extends State +} + +class EndpointException(msg: String, cause: Throwable) extends AkkaException(msg, cause) +case class InvalidAssociation(localAddress: Address, remoteAddress: Address, cause: Throwable) + extends EndpointException("Invalid address: " + remoteAddress, cause) + +private[remote] class EndpointWriter( + handleOrActive: Option[AssociationHandle], + val localAddress: Address, + val remoteAddress: Address, + val transport: Transport, + val settings: RemotingSettings, + val codec: AkkaPduCodec) extends Actor with Stash with FSM[EndpointWriter.State, Unit] { + + import EndpointWriter._ + import context.dispatcher + + val extendedSystem: ExtendedActorSystem = context.system.asInstanceOf[ExtendedActorSystem] + var reader: ActorRef = null + var handle: AssociationHandle = handleOrActive.getOrElse(null) + var inbound = false + val eventPublisher = new EventPublisher(context.system, log, settings.LogLifecycleEvents) + + override val supervisorStrategy = OneForOneStrategy() { + case NonFatal(e) ⇒ + publishAndThrow(e) + Stop + } + + val msgDispatch = + new DefaultMessageDispatcher(extendedSystem, extendedSystem.provider.asInstanceOf[RemoteActorRefProvider], log) + + private def publishAndThrow(reason: Throwable): Nothing = { + eventPublisher.notifyListeners(AssociationErrorEvent(reason, localAddress, remoteAddress, inbound)) + throw reason + } + + private def publishAndThrow(message: String, cause: Throwable): Nothing = + publishAndThrow(new EndpointException(message, cause)) + + override def postRestart(reason: Throwable): Unit = { + handle = null // Wipe out the possibly injected handle + preStart() + } + + override def preStart(): Unit = { + if (handle eq null) { + transport.associate(remoteAddress) pipeTo self + inbound = false + startWith(Initializing, ()) + } else { + startReadEndpoint() + inbound = true + startWith(Writing, ()) + } + } + + when(Initializing) { + case Event(Send(msg, senderOption, recipient), _) ⇒ + stash() + stay + case Event(Transport.Invalid(e), _) ⇒ + log.error(e, "Tried to associate with invalid remote address " + remoteAddress + + ". Address is now quarantined, all messages to this address will be delivered to dead letters.") + publishAndThrow(new InvalidAssociation(localAddress, remoteAddress, e)) + + case Event(Transport.Fail(e), _) ⇒ publishAndThrow(s"Association failed with $remoteAddress", e) + case Event(Transport.Ready(inboundHandle), _) ⇒ + handle = inboundHandle + startReadEndpoint() + goto(Writing) + + } + + when(Buffering) { + case Event(Send(msg, senderOption, recipient), _) ⇒ + stash() + stay + + case Event(BackoffTimer, _) ⇒ goto(Writing) + } + + when(Writing) { + case Event(Send(msg, senderOption, recipient), _) ⇒ + val pdu = codec.constructMessagePdu(recipient.localAddressToUse, recipient, serializeMessage(msg), senderOption) + val success = try handle.write(pdu) catch { + case NonFatal(e) ⇒ publishAndThrow("Failed to write message to the transport", e) + } + if (success) stay else { + stash + goto(Buffering) + } + } + + whenUnhandled { + case Event(Terminated(r), _) if r == reader ⇒ stop() + } + + onTransition { + case Initializing -> Writing ⇒ + unstashAll() + eventPublisher.notifyListeners(AssociatedEvent(localAddress, remoteAddress, inbound)) + case Writing -> Buffering ⇒ setTimer("backoff-timer", BackoffTimer, settings.BackoffPeriod, false) + case Buffering -> Writing ⇒ + unstashAll() + cancelTimer("backoff-timer") + } + + onTermination { + case StopEvent(_, _, _) ⇒ if (handle ne null) { + handle.disassociate() + eventPublisher.notifyListeners(DisassociatedEvent(localAddress, remoteAddress, inbound)) + } + } + + private def startReadEndpoint(): Unit = { + reader = context.actorOf(Props(new EndpointReader(codec, msgDispatch)), + "endpointReader-" + URLEncoder.encode(remoteAddress.toString, "utf-8")) + handle.readHandlerPromise.success(reader) + context.watch(reader) + } + + private def serializeMessage(msg: Any): MessageProtocol = { + Serialization.currentTransportAddress.withValue(handle.localAddress) { + (MessageSerializer.serialize(extendedSystem, msg.asInstanceOf[AnyRef])) + } + } + +} + +private[remote] class EndpointReader( + val codec: AkkaPduCodec, + val msgDispatch: InboundMessageDispatcher) extends Actor { + + val provider = context.system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider] + + override def receive: Receive = { + case Disassociated ⇒ context.stop(self) + + // FIXME: Do 2 step deserialization (old-remoting must be removed first) + case InboundPayload(p) ⇒ decodePdu(p) match { + + case Message(recipient, recipientAddress, serializedMessage, senderOption) ⇒ + msgDispatch.dispatch(recipient, recipientAddress, serializedMessage, senderOption) + + case _ ⇒ + } + } + + private def decodePdu(pdu: ByteString): AkkaPdu = try { + codec.decodePdu(pdu, provider) + } catch { + case NonFatal(e) ⇒ throw new EndpointException("Error while decoding incoming Akka PDU", e) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/FailureDetector.scala b/akka-remote/src/main/scala/akka/remote/FailureDetector.scala new file mode 100644 index 0000000000..1eaaaf7258 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/FailureDetector.scala @@ -0,0 +1,35 @@ +package akka.remote + +import java.util.concurrent.TimeUnit._ + +/** + * A failure detector is a thread-safe mutable construct that registers heartbeat events of a resource and is able to + * decide the availability of that monitored resource. + */ +trait FailureDetector { + + /** + * Returns true if the resource is considered to be up and healthy and returns false otherwise. + */ + def isAvailable: Boolean + + /** + * Notifies the FailureDetector that a heartbeat arrived from the monitored resource. This causes the FailureDetector + * to update its state. + */ + def heartbeat(): Unit + +} + +object FailureDetector { + + /** + * Abstraction of a clock that returns time in milliseconds. Clock can only be used to measure elapsed + * time and is not related to any other notion of system or wall-clock time. + */ + trait Clock extends (() ⇒ Long) + + implicit val defaultClock = new Clock { + def apply() = NANOSECONDS.toMillis(System.nanoTime) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/FailureDetectorRegistry.scala b/akka-remote/src/main/scala/akka/remote/FailureDetectorRegistry.scala new file mode 100644 index 0000000000..c5eb2f54b2 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/FailureDetectorRegistry.scala @@ -0,0 +1,36 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ + +package akka.remote + +/** + * Interface for a registry of Akka failure detectors. New resources are implicitly registered when heartbeat is first + * called with the resource given as parameter. + * + * @tparam A + * The type of the key that identifies a resource to be monitored by a failure detector + */ +trait FailureDetectorRegistry[A] { + + /** + * Returns true if the resource is considered to be up and healthy and returns false otherwise. + */ + def isAvailable(resource: A): Boolean + + /** + * Records a heartbeat for a resource. If the resource is not yet registered (i.e. this is the first heartbeat) then + * it is automatially registered. + */ + def heartbeat(resource: A): Unit + + /** + * Removes the heartbeat management for a resource. + */ + def remove(resource: A): Unit + + /** + * Removes all resources and starts over. + */ + def reset(): Unit +} diff --git a/akka-remote/src/main/scala/akka/remote/PhiAccrualFailureDetector.scala b/akka-remote/src/main/scala/akka/remote/PhiAccrualFailureDetector.scala new file mode 100644 index 0000000000..f0252106c7 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/PhiAccrualFailureDetector.scala @@ -0,0 +1,207 @@ +package akka.remote + +import akka.remote.FailureDetector.Clock +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.tailrec +import scala.concurrent.duration.FiniteDuration + +/** + * Implementation of 'The Phi Accrual Failure Detector' by Hayashibara et al. as defined in their paper: + * [http://ddg.jaist.ac.jp/pub/HDY+04.pdf] + * + * The suspicion level of failure is given by a value called φ (phi). + * The basic idea of the φ failure detector is to express the value of φ on a scale that + * is dynamically adjusted to reflect current network conditions. A configurable + * threshold is used to decide if φ is considered to be a failure. + * + * The value of φ is calculated as: + * + * {{{ + * φ = -log10(1 - F(timeSinceLastHeartbeat) + * }}} + * where F is the cumulative distribution function of a normal distribution with mean + * and standard deviation estimated from historical heartbeat inter-arrival times. + * + * @param threshold A low threshold is prone to generate many wrong suspicions but ensures a quick detection in the event + * of a real crash. Conversely, a high threshold generates fewer mistakes but needs more time to detect + * actual crashes + * + * @param maxSampleSize Number of samples to use for calculation of mean and standard deviation of + * inter-arrival times. + * + * @param minStdDeviation Minimum standard deviation to use for the normal distribution used when calculating phi. + * Too low standard deviation might result in too much sensitivity for sudden, but normal, deviations + * in heartbeat inter arrival times. + * + * @param acceptableHeartbeatPause Duration corresponding to number of potentially lost/delayed + * heartbeats that will be accepted before considering it to be an anomaly. + * This margin is important to be able to survive sudden, occasional, pauses in heartbeat + * arrivals, due to for example garbage collect or network drop. + * + * @param firstHeartbeatEstimate Bootstrap the stats with heartbeats that corresponds to + * to this duration, with a with rather high standard deviation (since environment is unknown + * in the beginning) + * + * @param clock The clock, returning current time in milliseconds, but can be faked for testing + * purposes. It is only used for measuring intervals (duration). + */ +class PhiAccrualFailureDetector( + val threshold: Double, + val maxSampleSize: Int, + val minStdDeviation: FiniteDuration, + val acceptableHeartbeatPause: FiniteDuration, + val firstHeartbeatEstimate: FiniteDuration)( + implicit clock: Clock) extends FailureDetector { + + // guess statistics for first heartbeat, + // important so that connections with only one heartbeat becomes unavailable + private val firstHeartbeat: HeartbeatHistory = { + // bootstrap with 2 entries with rather high standard deviation + val mean = firstHeartbeatEstimate.toMillis + val stdDeviation = mean / 4 + HeartbeatHistory(maxSampleSize) :+ (mean - stdDeviation) :+ (mean + stdDeviation) + } + + private val acceptableHeartbeatPauseMillis = acceptableHeartbeatPause.toMillis + + /** + * Implement using optimistic lockless concurrency, all state is represented + * by this immutable case class and managed by an AtomicReference. + */ + private case class State( + history: HeartbeatHistory = firstHeartbeat, + timestamp: Option[Long] = None) + + private val state = new AtomicReference[State](State()) + + override def isAvailable: Boolean = phi < threshold + + @tailrec + final override def heartbeat(): Unit = { + + val timestamp = clock() + val oldState = state.get + + val newHistory = oldState.timestamp match { + case None ⇒ + // this is heartbeat from a new resource + // add starter records for this new resource + firstHeartbeat + case Some(latestTimestamp) ⇒ + // this is a known connection + val interval = timestamp - latestTimestamp + oldState.history :+ interval + } + + val newState = oldState.copy(history = newHistory, timestamp = Some(timestamp)) // record new timestamp + + // if we won the race then update else try again + if (!state.compareAndSet(oldState, newState)) heartbeat() // recur + } + + /** + * The suspicion level of the accrual failure detector. + * + * If a connection does not have any records in failure detector then it is + * considered healthy. + */ + def phi: Double = { + val oldState = state.get + val oldTimestamp = oldState.timestamp + + if (oldTimestamp.isEmpty) 0.0 // treat unmanaged connections, e.g. with zero heartbeats, as healthy connections + else { + val timeDiff = clock() - oldTimestamp.get + + val history = oldState.history + val mean = history.mean + val stdDeviation = ensureValidStdDeviation(history.stdDeviation) + + val φ = phi(timeDiff, mean + acceptableHeartbeatPauseMillis, stdDeviation) + + φ + } + } + + private[akka] def phi(timeDiff: Long, mean: Double, stdDeviation: Double): Double = { + val cdf = cumulativeDistributionFunction(timeDiff, mean, stdDeviation) + -math.log10(1.0 - cdf) + } + + private val minStdDeviationMillis = minStdDeviation.toMillis + + private def ensureValidStdDeviation(stdDeviation: Double): Double = math.max(stdDeviation, minStdDeviationMillis) + + /** + * Cumulative distribution function for N(mean, stdDeviation) normal distribution. + * This is an approximation defined in β Mathematics Handbook. + */ + private[akka] def cumulativeDistributionFunction(x: Double, mean: Double, stdDeviation: Double): Double = { + val y = (x - mean) / stdDeviation + // Cumulative distribution function for N(0, 1) + 1.0 / (1.0 + math.exp(-y * (1.5976 + 0.070566 * y * y))) + } +} + +private[akka] object HeartbeatHistory { + + /** + * Create an empty HeartbeatHistory, without any history. + * Can only be used as starting point for appending intervals. + * The stats (mean, variance, stdDeviation) are not defined for + * for empty HeartbeatHistory, i.e. throws AritmeticException. + */ + def apply(maxSampleSize: Int): HeartbeatHistory = HeartbeatHistory( + maxSampleSize = maxSampleSize, + intervals = IndexedSeq.empty, + intervalSum = 0L, + squaredIntervalSum = 0L) + +} + +/** + * Holds the heartbeat statistics for a specific node Address. + * It is capped by the number of samples specified in `maxSampleSize`. + * + * The stats (mean, variance, stdDeviation) are not defined for + * for empty HeartbeatHistory, i.e. throws AritmeticException. + */ +private[akka] case class HeartbeatHistory private ( + maxSampleSize: Int, + intervals: IndexedSeq[Long], + intervalSum: Long, + squaredIntervalSum: Long) { + + if (maxSampleSize < 1) + throw new IllegalArgumentException(s"maxSampleSize must be >= 1, got [$maxSampleSize]") + if (intervalSum < 0L) + throw new IllegalArgumentException(s"intervalSum must be >= 0, got [$intervalSum]") + if (squaredIntervalSum < 0L) + throw new IllegalArgumentException(s"squaredIntervalSum must be >= 0, got [$squaredIntervalSum]") + + def mean: Double = intervalSum.toDouble / intervals.size + + def variance: Double = (squaredIntervalSum.toDouble / intervals.size) - (mean * mean) + + def stdDeviation: Double = math.sqrt(variance) + + @tailrec + final def :+(interval: Long): HeartbeatHistory = { + if (intervals.size < maxSampleSize) + HeartbeatHistory( + maxSampleSize, + intervals = intervals :+ interval, + intervalSum = intervalSum + interval, + squaredIntervalSum = squaredIntervalSum + pow2(interval)) + else + dropOldest :+ interval // recur + } + + private def dropOldest: HeartbeatHistory = HeartbeatHistory( + maxSampleSize, + intervals = intervals drop 1, + intervalSum = intervalSum - intervals.head, + squaredIntervalSum = squaredIntervalSum - pow2(intervals.head)) + + private def pow2(x: Long) = x * x +} diff --git a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala index 1d9ad9edc2..dc49391a2f 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala @@ -77,7 +77,7 @@ class RemoteActorRefProvider( }).get } - _log = Logging(eventStream, "RemoteActorRefProvider(" + transport.address + ")") + _log = Logging(eventStream, "RemoteActorRefProvider") // this enables reception of remote requests _transport.start() @@ -108,15 +108,15 @@ class RemoteActorRefProvider( * address below “remote”, including the current system’s identification * as “sys@host:port” (typically; it will use whatever the remote * transport uses). This means that on a path up an actor tree each node - * change introduces one layer or “remote/sys@host:port/” within the URI. + * change introduces one layer or “remote/scheme/sys@host:port/” within the URI. * * Example: * - * akka://sys@home:1234/remote/sys@remote:6667/remote/sys@other:3333/user/a/b/c + * akka://sys@home:1234/remote/akka/sys@remote:6667/remote/akka/sys@other:3333/user/a/b/c * - * means that the logical parent originates from “sys@other:3333” with - * one child (may be “a” or “b”) being deployed on “sys@remote:6667” and - * finally either “b” or “c” being created on “sys@home:1234”, where + * means that the logical parent originates from “akka://sys@other:3333” with + * one child (may be “a” or “b”) being deployed on “akka://sys@remote:6667” and + * finally either “b” or “c” being created on “akka://sys@home:1234”, where * this whole thing actually resides. Thus, the logical path is * “/user/a/b/c” and the physical path contains all remote placement * information. @@ -129,7 +129,7 @@ class RemoteActorRefProvider( def lookupRemotes(p: Iterable[String]): Option[Deploy] = { p.headOption match { case None ⇒ None - case Some("remote") ⇒ lookupRemotes(p.drop(2)) + case Some("remote") ⇒ lookupRemotes(p.drop(3)) case Some("user") ⇒ deployer.lookup(p.drop(1)) case Some(_) ⇒ None } @@ -154,11 +154,13 @@ class RemoteActorRefProvider( Iterator(props.deploy) ++ deployment.iterator reduce ((a, b) ⇒ b withFallback a) match { case d @ Deploy(_, _, _, RemoteScope(addr)) ⇒ - if (addr == rootPath.address || addr == transport.address) { + if (addr == rootPath.address || transport.addresses(addr)) { local.actorOf(system, props, supervisor, path, false, deployment.headOption, false, async) } else { - val rpath = RootActorPath(addr) / "remote" / transport.address.hostPort / path.elements - new RemoteActorRef(this, transport, rpath, supervisor, Some(props), Some(d)) + val localAddress = transport.localAddressForRemote(addr) + val rpath = RootActorPath(addr) / "remote" / localAddress.protocol / localAddress.hostPort / path.elements + useActorOnNode(rpath, props, d, supervisor) + new RemoteActorRef(this, transport, localAddress, rpath, supervisor, Some(props), Some(d)) } case _ ⇒ local.actorOf(system, props, supervisor, path, systemService, deployment.headOption, false, async) @@ -167,13 +169,13 @@ class RemoteActorRefProvider( } def actorFor(path: ActorPath): InternalActorRef = - if (path.address == rootPath.address || path.address == transport.address) actorFor(rootGuardian, path.elements) - else new RemoteActorRef(this, transport, path, Nobody, props = None, deploy = None) + if (path.address == rootPath.address || transport.addresses(path.address)) actorFor(rootGuardian, path.elements) + else new RemoteActorRef(this, transport, transport.localAddressForRemote(path.address), path, Nobody, props = None, deploy = None) def actorFor(ref: InternalActorRef, path: String): InternalActorRef = path match { case ActorPathExtractor(address, elems) ⇒ - if (address == rootPath.address || address == transport.address) actorFor(rootGuardian, elems) - else new RemoteActorRef(this, transport, new RootActorPath(address) / elems, Nobody, props = None, deploy = None) + if (address == rootPath.address || transport.addresses(address)) actorFor(rootGuardian, elems) + else new RemoteActorRef(this, transport, transport.localAddressForRemote(address), new RootActorPath(address) / elems, Nobody, props = None, deploy = None) case _ ⇒ local.actorFor(ref, path) } @@ -190,12 +192,11 @@ class RemoteActorRefProvider( } def getExternalAddressFor(addr: Address): Option[Address] = { - val ta = transport.address val ra = rootPath.address addr match { - case `ta` | `ra` ⇒ Some(rootPath.address) - case Address("akka", _, Some(_), Some(_)) ⇒ Some(transport.address) - case _ ⇒ None + case a if (a eq ra) || transport.addresses(a) ⇒ Some(rootPath.address) + case Address(_, _, Some(_), Some(_)) ⇒ Some(transport.localAddressForRemote(addr)) + case _ ⇒ None } } } @@ -211,6 +212,7 @@ private[akka] trait RemoteRef extends ActorRefScope { private[akka] class RemoteActorRef private[akka] ( val provider: RemoteActorRefProvider, remote: RemoteTransport, + val localAddressToUse: Address, val path: ActorPath, val getParent: InternalActorRef, props: Option[Props], @@ -222,7 +224,7 @@ private[akka] class RemoteActorRef private[akka] ( s.headOption match { case None ⇒ this case Some("..") ⇒ getParent getChild name - case _ ⇒ new RemoteActorRef(provider, remote, path / s, Nobody, props = None, deploy = None) + case _ ⇒ new RemoteActorRef(provider, remote, localAddressToUse, path / s, Nobody, props = None, deploy = None) } } @@ -256,4 +258,4 @@ private[akka] class RemoteActorRef private[akka] ( @throws(classOf[java.io.ObjectStreamException]) private def writeReplace(): AnyRef = SerializedActorRef(path) -} +} \ No newline at end of file diff --git a/akka-remote/src/main/scala/akka/remote/RemoteTransport.scala b/akka-remote/src/main/scala/akka/remote/RemoteTransport.scala index 09db024caa..e4ee103735 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteTransport.scala @@ -176,7 +176,14 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re /** * Address to be used in RootActorPath of refs generated for this transport. */ - def address: Address + def addresses: Set[Address] + + /** + * Resolves the correct local address to be used for contacting the given remote address + * @param remote the remote address + * @return the local address to be used for the given remote address + */ + def localAddressForRemote(remote: Address): Address /** * Start up the transport, i.e. enable incoming connections. @@ -184,14 +191,14 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re def start(): Unit /** - * Shuts down a specific client connected to the supplied remote address returns true if successful + * Attempts to shut down a specific client connected to the supplied remote address */ - def shutdownClientConnection(address: Address): Boolean + def shutdownClientConnection(address: Address): Unit /** - * Restarts a specific client connected to the supplied remote address, but only if the client is not shut down + * Attempts to restart a specific client connected to the supplied remote address, but only if the client is not shut down */ - def restartClientConnection(address: Address): Boolean + def restartClientConnection(address: Address): Unit /** * Sends the given message to the recipient supplying the sender if any @@ -206,11 +213,6 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re if (logRemoteLifeCycleEvents) log.log(message.logLevel, "{}", message) } - /** - * Returns this RemoteTransports Address' textual representation - */ - override def toString: String = address.toString - /** * A Logger that can be used to log issues that may occur */ @@ -242,7 +244,7 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re * Serializes the ActorRef instance into a Protocol Buffers (protobuf) Message. */ def toRemoteActorRefProtocol(actor: ActorRef): ActorRefProtocol = - ActorRefProtocol.newBuilder.setPath(actor.path.toStringWithAddress(address)).build + ActorRefProtocol.newBuilder.setPath(actor.path.toStringWithAddress(addresses.head)).build /** * Returns a new RemoteMessageProtocol containing the serialized representation of the given parameters. @@ -251,7 +253,7 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re val messageBuilder = RemoteMessageProtocol.newBuilder.setRecipient(toRemoteActorRefProtocol(recipient)) if (senderOption.isDefined) messageBuilder.setSender(toRemoteActorRefProtocol(senderOption.get)) - Serialization.currentTransportAddress.withValue(address) { + Serialization.currentTransportAddress.withValue(addresses.head) { messageBuilder.setMessage(MessageSerializer.serialize(system, message.asInstanceOf[AnyRef])) } @@ -289,16 +291,16 @@ abstract class RemoteTransport(val system: ExtendedActorSystem, val provider: Re case r @ (_: RemoteRef | _: RepointableRef) if !r.isLocal && !useUntrustedMode ⇒ if (provider.remoteSettings.LogReceive) log.debug("received remote-destined message {}", remoteMessage) remoteMessage.originalReceiver match { - case AddressFromURIString(address) if address == provider.transport.address ⇒ + case AddressFromURIString(address) if provider.transport.addresses(address) ⇒ // if it was originally addressed to us but is in fact remote from our point of view (i.e. remote-deployed) r.!(remoteMessage.payload)(remoteMessage.sender) case r ⇒ - log.debug("dropping message {} for non-local recipient {} arriving at {} inbound address is {}", - remoteMessage.payloadClass, r, address, provider.transport.address) + log.debug("dropping message {} for non-local recipient {} arriving at {} inbound addresses are {}", + remoteMessage.payloadClass, r, addresses, provider.transport.addresses) } case r ⇒ - log.debug("dropping message {} for unknown recipient {} arriving at {} inbound address is {}", - remoteMessage.payloadClass, r, address, provider.transport.address) + log.debug("dropping message {} for unknown recipient {} arriving at {} inbound addresses are {}", + remoteMessage.payloadClass, r, addresses, provider.transport.addresses) } } } diff --git a/akka-remote/src/main/scala/akka/remote/Remoting.scala b/akka-remote/src/main/scala/akka/remote/Remoting.scala new file mode 100644 index 0000000000..22a6cdd0b0 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/Remoting.scala @@ -0,0 +1,377 @@ +package akka.remote + +import akka.actor.SupervisorStrategy._ +import akka.actor._ +import akka.event.{ Logging, LoggingAdapter } +import akka.pattern.gracefulStop +import akka.remote.EndpointManager.Listen +import akka.remote.EndpointManager.Send +import akka.remote.transport.Transport.InboundAssociation +import akka.remote.transport._ +import akka.util.Timeout +import com.typesafe.config.Config +import scala.collection.immutable.{ Seq, HashMap } +import scala.concurrent.duration._ +import scala.concurrent.{ Promise, Await, Future } +import scala.util.control.NonFatal +import java.net.URLEncoder +import java.util.concurrent.TimeoutException +import scala.util.{ Failure, Success } + +class RemotingSettings(config: Config) { + + import config._ + import scala.collection.JavaConverters._ + + val LogLifecycleEvents: Boolean = getBoolean("akka.remoting.log-remote-lifecycle-events") + + val ShutdownTimeout: FiniteDuration = Duration(getMilliseconds("akka.remoting.shutdown-timeout"), MILLISECONDS) + + val StartupTimeout: FiniteDuration = Duration(getMilliseconds("akka.remoting.startup-timeout"), MILLISECONDS) + + val RetryLatchClosedFor: Long = getMilliseconds("akka.remoting.retry-latch-closed-for") + + val UsePassiveConnections: Boolean = getBoolean("akka.remoting.use-passive-connections") + + val MaximumRetriesInWindow: Int = getInt("akka.remoting.maximum-retries-in-window") + + val RetryWindow: FiniteDuration = Duration(getMilliseconds("akka.remoting.retry-window"), MILLISECONDS) + + val BackoffPeriod: FiniteDuration = + Duration(getMilliseconds("akka.remoting.backoff-interval"), MILLISECONDS) + + val Transports: List[(String, Config)] = + config.getConfigList("akka.remoting.transports").asScala.map { + conf ⇒ (conf.getString("transport-class"), conf.getConfig("settings")) + }.toList +} + +private[remote] object Remoting { + + val EndpointManagerName = "remoteTransportHeadActor" + + def localAddressForRemote(transportMapping: Map[String, Set[(Transport, Address)]], remote: Address): Address = { + + transportMapping.get(remote.protocol) match { + case Some(transports) ⇒ + val responsibleTransports = transports.filter(_._1.isResponsibleFor(remote)) + + responsibleTransports.size match { + case 0 ⇒ + throw new RemoteTransportException( + s"No transport is responsible for address: ${remote} although protocol ${remote.protocol} is available." + + " Make sure at least one transport is configured to be responsible for the address.", + null) + + case 1 ⇒ + responsibleTransports.head._2 + + case _ ⇒ + throw new RemoteTransportException( + s"Multiple transports are available for ${remote}: ${responsibleTransports.mkString(",")}. " + + "Remoting cannot decide which transport to use to reach the remote system. Change your configuration " + + "so that only one transport is responsible for the address.", + null) + } + case None ⇒ throw new RemoteTransportException(s"No transport is loaded for protocol: ${remote.protocol}", null) + } + } + +} + +private[remote] class Remoting(_system: ExtendedActorSystem, _provider: RemoteActorRefProvider) extends RemoteTransport(_system, _provider) { + + @volatile private var endpointManager: ActorRef = _ + @volatile var transportMapping: Map[String, Set[(Transport, Address)]] = _ + @volatile var addresses: Set[Address] = _ + private val settings = new RemotingSettings(provider.remoteSettings.config) + + override def localAddressForRemote(remote: Address): Address = Remoting.localAddressForRemote(transportMapping, remote) + + val log: LoggingAdapter = Logging(system.eventStream, "Remoting") + val eventPublisher = new EventPublisher(system, log, settings.LogLifecycleEvents) + + private def notifyError(msg: String, cause: Throwable): Unit = + eventPublisher.notifyListeners(RemotingErrorEvent(new RemoteTransportException(msg, cause))) + + override def shutdown(): Unit = { + if (endpointManager != null) { + try { + val stopped: Future[Boolean] = gracefulStop(endpointManager, settings.ShutdownTimeout)(system) + + if (Await.result(stopped, settings.ShutdownTimeout)) { + eventPublisher.notifyListeners(RemotingShutdownEvent) + } + + } catch { + case e: TimeoutException ⇒ notifyError("Shutdown timed out.", e) + case NonFatal(e) ⇒ notifyError("Shutdown failed.", e) + } finally { + endpointManager = null + } + } + } + + // Start assumes that it cannot be followed by another start() without having a shutdown() first + override def start(): Unit = { + if (endpointManager eq null) { + log.info("Starting remoting") + endpointManager = system.asInstanceOf[ActorSystemImpl].systemActorOf( + Props(new EndpointManager(provider.remoteSettings.config, log)), Remoting.EndpointManagerName) + + implicit val timeout = new Timeout(settings.StartupTimeout) + + try { + val addressesPromise: Promise[Set[(Transport, Address)]] = Promise() + endpointManager ! Listen(addressesPromise) + + val transports: Set[(Transport, Address)] = Await.result(addressesPromise.future, timeout.duration) + transportMapping = transports.groupBy { case (transport, _) ⇒ transport.schemeIdentifier }.mapValues { + _.toSet + } + + addresses = transports.map { _._2 }.toSet + eventPublisher.notifyListeners(RemotingListenEvent(addresses)) + + } catch { + case e: TimeoutException ⇒ notifyError("Startup timed out", e) + case NonFatal(e) ⇒ notifyError("Startup failed", e) + } + + } else { + log.warning("Remoting was already started. Ignoring start attempt.") + } + } + + // TODO: this is called in RemoteActorRefProvider to handle the lifecycle of connections (clients) + // which is not how things work in the new remoting + override def shutdownClientConnection(address: Address): Unit = { + // Ignore + } + + // TODO: this is never called anywhere, should be taken out from RemoteTransport API + override def restartClientConnection(address: Address): Unit = { + // Ignore + } + + override def send(message: Any, senderOption: Option[ActorRef], recipient: RemoteActorRef): Unit = { + endpointManager.tell(Send(message, senderOption, recipient), sender = Actor.noSender) + } + + // Not used anywhere only to keep compatibility with RemoteTransport interface + protected def useUntrustedMode: Boolean = provider.remoteSettings.UntrustedMode + + // Not used anywhere only to keep compatibility with RemoteTransport interface + protected def logRemoteLifeCycleEvents: Boolean = provider.remoteSettings.LogRemoteLifeCycleEvents + +} + +private[remote] object EndpointManager { + + sealed trait RemotingCommand + case class Listen(addressesPromise: Promise[Set[(Transport, Address)]]) extends RemotingCommand + + case class Send(message: Any, senderOption: Option[ActorRef], recipient: RemoteActorRef) extends RemotingCommand { + override def toString = s"Remote message $senderOption -> $recipient" + } + + sealed trait EndpointPolicy + case class Pass(endpoint: ActorRef) extends EndpointPolicy + case class Latched(timeOfFailure: Long) extends EndpointPolicy + case class Quarantined(reason: Throwable) extends EndpointPolicy + + case object Prune + + // Not threadsafe -- only to be used in HeadActor + private[EndpointManager] class EndpointRegistry { + @volatile private var addressToEndpointAndPolicy = HashMap[Address, EndpointPolicy]() + @volatile private var endpointToAddress = HashMap[ActorRef, Address]() + + def getEndpointWithPolicy(address: Address): Option[EndpointPolicy] = addressToEndpointAndPolicy.get(address) + + def prune(pruneAge: Long): Unit = { + addressToEndpointAndPolicy = addressToEndpointAndPolicy.filter { + case (_, Pass(_)) ⇒ true + case (_, Latched(timeOfFailure)) ⇒ timeOfFailure + pruneAge > System.nanoTime() + } + } + + def registerEndpoint(address: Address, endpoint: ActorRef): ActorRef = { + addressToEndpointAndPolicy = addressToEndpointAndPolicy + (address -> Pass(endpoint)) + endpointToAddress = endpointToAddress + (endpoint -> address) + endpoint + } + + def markFailed(endpoint: ActorRef, timeOfFailure: Long): Unit = { + addressToEndpointAndPolicy += endpointToAddress(endpoint) -> Latched(timeOfFailure) + endpointToAddress = endpointToAddress - endpoint + } + + def markQuarantine(address: Address, reason: Throwable): Unit = + addressToEndpointAndPolicy += address -> Quarantined(reason) + + def removeIfNotLatched(endpoint: ActorRef): Unit = { + endpointToAddress.get(endpoint) foreach { address ⇒ + addressToEndpointAndPolicy.get(address) foreach { policy ⇒ + policy match { + case Pass(_) ⇒ + addressToEndpointAndPolicy = addressToEndpointAndPolicy - address + endpointToAddress = endpointToAddress - endpoint + case _ ⇒ + } + } + } + } + } +} + +private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends Actor { + + import EndpointManager._ + import context.dispatcher + + val settings = new RemotingSettings(conf) + val extendedSystem = context.system.asInstanceOf[ExtendedActorSystem] + var endpointId: Long = 0L + + val eventPublisher = new EventPublisher(context.system, log, settings.LogLifecycleEvents) + + // Mapping between addresses and endpoint actors. If passive connections are turned off, incoming connections + // will be not part of this map! + val endpoints = new EndpointRegistry + // Mapping between transports and the local addresses they listen to + var transportMapping: Map[Address, Transport] = Map() + + val retryLatchEnabled = settings.RetryLatchClosedFor > 0L + val pruneInterval: Long = if (retryLatchEnabled) settings.RetryLatchClosedFor * 2L else 0L + val pruneTimerCancellable: Option[Cancellable] = if (retryLatchEnabled) + Some(context.system.scheduler.schedule(pruneInterval milliseconds, pruneInterval milliseconds, self, Prune)) + else None + + override val supervisorStrategy = OneForOneStrategy(settings.MaximumRetriesInWindow, settings.RetryWindow) { + case InvalidAssociation(localAddress, remoteAddress, e) ⇒ + endpoints.markQuarantine(remoteAddress, e) + Stop + + case NonFatal(e) ⇒ + if (!retryLatchEnabled) + // This strategy keeps all the messages in the stash of the endpoint so restart will transfer the queue + // to the restarted endpoint -- thus no messages are lost + Restart + else { + // This strategy throws away all the messages enqueued in the endpoint (in its stash), registers the time of failure, + // keeps throwing away messages until the retry latch becomes open (time specified in RetryLatchClosedFor) + endpoints.markFailed(sender, System.nanoTime()) + Stop + } + } + + def receive = { + case Listen(addressesPromise) ⇒ try initializeTransports(addressesPromise) catch { + case NonFatal(e) ⇒ + addressesPromise.failure(e) + context.stop(self) + } + } + + val accepting: Receive = { + case s @ Send(message, senderOption, recipientRef) ⇒ + val recipientAddress = recipientRef.path.address + + endpoints.getEndpointWithPolicy(recipientAddress) match { + case Some(Pass(endpoint)) ⇒ endpoint ! s + case Some(Latched(timeOfFailure)) ⇒ if (retryLatchOpen(timeOfFailure)) + createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) ! s + else extendedSystem.deadLetters ! message + case Some(Quarantined(_)) ⇒ extendedSystem.deadLetters ! message + case None ⇒ createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) ! s + + } + + case InboundAssociation(handle) ⇒ + val endpoint = createEndpoint(handle.remoteAddress, handle.localAddress, Some(handle)) + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + if (settings.UsePassiveConnections) endpoints.registerEndpoint(handle.localAddress, endpoint) + case Terminated(endpoint) ⇒ endpoints.removeIfNotLatched(endpoint) + case Prune ⇒ endpoints.prune(settings.RetryLatchClosedFor) + } + + private def initializeTransports(addressesPromise: Promise[Set[(Transport, Address)]]): Unit = { + val transports = for ((fqn, config) ← settings.Transports) yield { + + val args = Seq(classOf[ExtendedActorSystem] -> context.system, classOf[Config] -> config) + + val wrappedTransport = context.system.asInstanceOf[ActorSystemImpl].dynamicAccess + .createInstanceFor[Transport](fqn, args).recover({ + + case exception ⇒ throw new IllegalArgumentException( + (s"Cannot instantiate transport [$fqn]. " + + "Make sure it extends [akka.remote.transport.Transport] and has constructor with " + + "[akka.actor.ExtendedActorSystem] and [com.typesafe.config.Config] parameters"), exception) + + }).get + + new AkkaProtocolTransport(wrappedTransport, context.system, new AkkaProtocolSettings(conf), AkkaPduProtobufCodec) + + } + + val listens: Future[Seq[(Transport, (Address, Promise[ActorRef]))]] = Future.sequence( + transports.map { transport ⇒ transport.listen.map { transport -> _ } }) + + listens.onComplete { + case Success(results) ⇒ + val transportsAndAddresses = (for ((transport, (address, promise)) ← results) yield { + promise.success(self) + transport -> address + }).toSet + addressesPromise.success(transportsAndAddresses) + + context.become(accepting) + + transportMapping = HashMap() ++ results.groupBy { case (_, (transportAddress, _)) ⇒ transportAddress }.map { + case (a, t) ⇒ + if (t.size > 1) + throw new RemoteTransportException(s"There are more than one transports listening on local address $a", null) + + a -> t.head._1 + } + + case Failure(reason) ⇒ addressesPromise.failure(reason) + } + } + + private def createEndpoint(remoteAddress: Address, + localAddress: Address, + handleOption: Option[AssociationHandle]): ActorRef = { + assert(transportMapping.contains(localAddress)) + val id = endpointId + endpointId += 1L + + val endpoint = context.actorOf(Props( + new EndpointWriter( + handleOption, + localAddress, + remoteAddress, + transportMapping(localAddress), + settings, + AkkaPduProtobufCodec)) + .withDispatcher("akka.remoting.writer-dispatcher"), + "endpointWriter-" + URLEncoder.encode(remoteAddress.toString, "utf-8") + "-" + endpointId) + + endpoints.registerEndpoint(remoteAddress, endpoint) + } + + private def retryLatchOpen(timeOfFailure: Long): Boolean = (timeOfFailure + settings.RetryLatchClosedFor) < System.nanoTime() + + override def postStop(): Unit = { + pruneTimerCancellable.foreach { _.cancel() } + transportMapping.values foreach { transport ⇒ + try transport.shutdown() + catch { + case NonFatal(e) ⇒ + log.error(e, s"Unable to shut down the underlying Transport: $transport") + } + } + } + +} \ No newline at end of file diff --git a/akka-remote/src/main/scala/akka/remote/RemotingLifecycle.scala b/akka-remote/src/main/scala/akka/remote/RemotingLifecycle.scala new file mode 100644 index 0000000000..3990eca79d --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/RemotingLifecycle.scala @@ -0,0 +1,71 @@ +package akka.remote + +import akka.event.{ LoggingAdapter, Logging } +import akka.actor.{ ActorSystem, Address } +import scala.beans.BeanProperty +import java.util.{ Set ⇒ JSet } +import scala.collection.JavaConverters.setAsJavaSetConverter + +trait RemotingLifecycleEvent extends Serializable { + def logLevel: Logging.LogLevel +} + +trait AssociationEvent extends RemotingLifecycleEvent { + def localAddress: Address + def remoteAddress: Address + def inbound: Boolean + protected def eventName: String + final def getRemoteAddress: Address = remoteAddress + final def getLocalAddress: Address = localAddress + final def isInbound: Boolean = inbound + override def toString: String = s"$eventName [$localAddress]${if (inbound) " <- " else " -> "}[$remoteAddress]" +} + +case class AssociatedEvent( + localAddress: Address, + remoteAddress: Address, + inbound: Boolean) extends AssociationEvent { + protected override val eventName: String = "Associated" + override def logLevel: Logging.LogLevel = Logging.DebugLevel +} + +case class DisassociatedEvent( + localAddress: Address, + remoteAddress: Address, + inbound: Boolean) extends AssociationEvent { + protected override val eventName: String = "Disassociated" + override def logLevel: Logging.LogLevel = Logging.DebugLevel +} + +case class AssociationErrorEvent( + @BeanProperty cause: Throwable, + localAddress: Address, + remoteAddress: Address, + inbound: Boolean) extends AssociationEvent { + protected override val eventName: String = "AssociationError" + override def logLevel: Logging.LogLevel = Logging.ErrorLevel + override def toString: String = s"${super.toString}: Error[${Logging.stackTraceFor(cause)}]" +} + +case class RemotingListenEvent(listenAddresses: Set[Address]) extends RemotingLifecycleEvent { + final def getListenAddresses: JSet[Address] = listenAddresses.asJava + override def logLevel: Logging.LogLevel = Logging.InfoLevel + override def toString: String = "Remoting now listens on addresses: " + listenAddresses.mkString("[", ", ", "]") +} + +case object RemotingShutdownEvent extends RemotingLifecycleEvent { + override def logLevel: Logging.LogLevel = Logging.InfoLevel + override val toString: String = "Remoting shut down" +} + +case class RemotingErrorEvent(@BeanProperty cause: Throwable) extends RemotingLifecycleEvent { + override def logLevel: Logging.LogLevel = Logging.ErrorLevel + override def toString: String = s"Remoting error: [${Logging.stackTraceFor(cause)}]" +} + +class EventPublisher(system: ActorSystem, log: LoggingAdapter, logEvents: Boolean) { + def notifyListeners(message: RemotingLifecycleEvent): Unit = { + system.eventStream.publish(message) + if (logEvents) log.log(message.logLevel, "{}", message) + } +} \ No newline at end of file diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index 6e36c63024..6d2679e251 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -37,6 +37,9 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider val settings = new NettySettings(remoteSettings.config.getConfig("akka.remote.netty"), remoteSettings.systemName) + // Workaround to emulate the support of multiple local addresses + override def localAddressForRemote(remote: Address): Address = addresses.head + // TODO replace by system.scheduler val timer: HashedWheelTimer = new HashedWheelTimer(system.threadFactory) @@ -73,7 +76,7 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider * actually dispatches the received messages to the local target actors). */ def defaultStack(withTimeout: Boolean, isClient: Boolean): Seq[ChannelHandler] = - (if (settings.EnableSSL) List(NettySSLSupport(settings, NettyRemoteTransport.this.log, isClient)) else Nil) ::: + (if (settings.EnableSSL) List(NettySSLSupport(settings.SslSettings, NettyRemoteTransport.this.log, isClient)) else Nil) ::: (if (withTimeout) List(timeout) else Nil) ::: msgFormat ::: authenticator ::: @@ -162,7 +165,7 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider * the normal one, e.g. for inserting security hooks. Get this transport’s * address from `this.address`. */ - protected def createClient(recipient: Address): RemoteClient = new ActiveRemoteClient(this, recipient, address) + protected def createClient(recipient: Address): RemoteClient = new ActiveRemoteClient(this, recipient, addresses.head) // the address is set in start() or from the RemoteServerHandler, whichever comes first private val _address = new AtomicReference[Address] @@ -174,9 +177,11 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider _address.compareAndSet(null, Address("akka", remoteSettings.systemName, settings.Hostname, addr.getPort)) } + // Workaround to emulate the support of multiple local addresses + def addresses = Set(address) def address = _address.get - lazy val log = Logging(system.eventStream, "NettyRemoteTransport(" + address + ")") + lazy val log = Logging(system.eventStream, "NettyRemoteTransport(" + addresses + ")") def start(): Unit = { server.start() @@ -271,7 +276,7 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider } } - def shutdownClientConnection(remoteAddress: Address): Boolean = { + def shutdownClientConnection(remoteAddress: Address): Unit = { clientsLock.writeLock().lock() try { remoteClients.remove(remoteAddress) match { @@ -283,7 +288,7 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider } } - def restartClientConnection(remoteAddress: Address): Boolean = { + def restartClientConnection(remoteAddress: Address): Unit = { clientsLock.readLock().lock() try { remoteClients.get(remoteAddress) match { diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettySSLSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettySSLSupport.scala index bc4f39dca4..1d4fac813f 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettySSLSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettySSLSupport.scala @@ -11,6 +11,60 @@ import akka.event.LoggingAdapter import java.io.{ IOException, FileNotFoundException, FileInputStream } import akka.remote.security.provider.AkkaProvider import java.security._ +import com.typesafe.config.Config +import scala.collection.JavaConverters._ +import scala.Some +import akka.ConfigurationException + +private[akka] class SslSettings(config: Config) { + import config._ + + val SSLKeyStore = getString("key-store") match { + case "" ⇒ None + case keyStore ⇒ Some(keyStore) + } + + val SSLTrustStore = getString("trust-store") match { + case "" ⇒ None + case trustStore ⇒ Some(trustStore) + } + + val SSLKeyStorePassword = getString("key-store-password") match { + case "" ⇒ None + case password ⇒ Some(password) + } + + val SSLTrustStorePassword = getString("trust-store-password") match { + case "" ⇒ None + case password ⇒ Some(password) + } + + val SSLEnabledAlgorithms = iterableAsScalaIterableConverter(getStringList("enabled-algorithms")).asScala.toSet[String] + + val SSLProtocol = getString("protocol") match { + case "" ⇒ None + case protocol ⇒ Some(protocol) + } + + val SSLRandomSource = getString("sha1prng-random-source") match { + case "" ⇒ None + case path ⇒ Some(path) + } + + val SSLRandomNumberGenerator = getString("random-number-generator") match { + case "" ⇒ None + case rng ⇒ Some(rng) + } + + if (SSLProtocol.isEmpty) throw new ConfigurationException( + "Configuration option 'akka.remote.netty.ssl.enable is turned on but no protocol is defined in 'akka.remote.netty.ssl.protocol'.") + if (SSLKeyStore.isEmpty && SSLTrustStore.isEmpty) throw new ConfigurationException( + "Configuration option 'akka.remote.netty.ssl.enable is turned on but no key/trust store is defined in 'akka.remote.netty.ssl.key-store' / 'akka.remote.netty.ssl.trust-store'.") + if (SSLKeyStore.isDefined && SSLKeyStorePassword.isEmpty) throw new ConfigurationException( + "Configuration option 'akka.remote.netty.ssl.key-store' is defined but no key-store password is defined in 'akka.remote.netty.ssl.key-store-password'.") + if (SSLTrustStore.isDefined && SSLTrustStorePassword.isEmpty) throw new ConfigurationException( + "Configuration option 'akka.remote.netty.ssl.trust-store' is defined but no trust-store password is defined in 'akka.remote.netty.ssl.trust-store-password'.") +} /** * Used for adding SSL support to Netty pipeline @@ -23,7 +77,7 @@ private[akka] object NettySSLSupport { /** * Construct a SSLHandler which can be inserted into a Netty server/client pipeline */ - def apply(settings: NettySettings, log: LoggingAdapter, isClient: Boolean): SslHandler = + def apply(settings: SslSettings, log: LoggingAdapter, isClient: Boolean): SslHandler = if (isClient) initializeClientSSL(settings, log) else initializeServerSSL(settings, log) def initializeCustomSecureRandom(rngName: Option[String], sourceOfRandomness: Option[String], log: LoggingAdapter): SecureRandom = { @@ -57,10 +111,10 @@ private[akka] object NettySSLSupport { rng } - def initializeClientSSL(settings: NettySettings, log: LoggingAdapter): SslHandler = { + def initializeClientSSL(settings: SslSettings, log: LoggingAdapter): SslHandler = { log.debug("Client SSL is enabled, initialising ...") - def constructClientContext(settings: NettySettings, log: LoggingAdapter, trustStorePath: String, trustStorePassword: String, protocol: String): Option[SSLContext] = + def constructClientContext(settings: SslSettings, log: LoggingAdapter, trustStorePath: String, trustStorePassword: String, protocol: String): Option[SSLContext] = try { val rng = initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log) val trustManagers: Array[TrustManager] = { @@ -106,10 +160,10 @@ private[akka] object NettySSLSupport { } } - def initializeServerSSL(settings: NettySettings, log: LoggingAdapter): SslHandler = { + def initializeServerSSL(settings: SslSettings, log: LoggingAdapter): SslHandler = { log.debug("Server SSL is enabled, initialising ...") - def constructServerContext(settings: NettySettings, log: LoggingAdapter, keyStorePath: String, keyStorePassword: String, protocol: String): Option[SSLContext] = + def constructServerContext(settings: SslSettings, log: LoggingAdapter, keyStorePath: String, keyStorePassword: String, protocol: String): Option[SSLContext] = try { val rng = initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log) val factory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) diff --git a/akka-remote/src/main/scala/akka/remote/netty/Settings.scala b/akka-remote/src/main/scala/akka/remote/netty/Settings.scala index c9fb4aff9a..932417391a 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Settings.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Settings.scala @@ -88,55 +88,8 @@ private[akka] class NettySettings(config: Config, val systemName: String) { case sz ⇒ sz } - val SSLKeyStore = getString("ssl.key-store") match { - case "" ⇒ None - case keyStore ⇒ Some(keyStore) - } + val SslSettings = new SslSettings(config.getConfig("ssl")) - val SSLTrustStore = getString("ssl.trust-store") match { - case "" ⇒ None - case trustStore ⇒ Some(trustStore) - } + val EnableSSL = getBoolean("ssl.enable") - val SSLKeyStorePassword = getString("ssl.key-store-password") match { - case "" ⇒ None - case password ⇒ Some(password) - } - - val SSLTrustStorePassword = getString("ssl.trust-store-password") match { - case "" ⇒ None - case password ⇒ Some(password) - } - - val SSLEnabledAlgorithms = iterableAsScalaIterableConverter(getStringList("ssl.enabled-algorithms")).asScala.toSet[String] - - val SSLProtocol = getString("ssl.protocol") match { - case "" ⇒ None - case protocol ⇒ Some(protocol) - } - - val SSLRandomSource = getString("ssl.sha1prng-random-source") match { - case "" ⇒ None - case path ⇒ Some(path) - } - - val SSLRandomNumberGenerator = getString("ssl.random-number-generator") match { - case "" ⇒ None - case rng ⇒ Some(rng) - } - - val EnableSSL = { - val enableSSL = getBoolean("ssl.enable") - if (enableSSL) { - if (SSLProtocol.isEmpty) throw new ConfigurationException( - "Configuration option 'akka.remote.netty.ssl.enable is turned on but no protocol is defined in 'akka.remote.netty.ssl.protocol'.") - if (SSLKeyStore.isEmpty && SSLTrustStore.isEmpty) throw new ConfigurationException( - "Configuration option 'akka.remote.netty.ssl.enable is turned on but no key/trust store is defined in 'akka.remote.netty.ssl.key-store' / 'akka.remote.netty.ssl.trust-store'.") - if (SSLKeyStore.isDefined && SSLKeyStorePassword.isEmpty) throw new ConfigurationException( - "Configuration option 'akka.remote.netty.ssl.key-store' is defined but no key-store password is defined in 'akka.remote.netty.ssl.key-store-password'.") - if (SSLTrustStore.isDefined && SSLTrustStorePassword.isEmpty) throw new ConfigurationException( - "Configuration option 'akka.remote.netty.ssl.trust-store' is defined but no trust-store password is defined in 'akka.remote.netty.ssl.trust-store-password'.") - } - enableSSL - } } diff --git a/akka-remote/src/main/scala/akka/remote/transport/AkkaPduCodec.scala b/akka-remote/src/main/scala/akka/remote/transport/AkkaPduCodec.scala new file mode 100644 index 0000000000..76364f2f60 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/AkkaPduCodec.scala @@ -0,0 +1,149 @@ +package akka.remote.transport + +import akka.AkkaException +import akka.actor.{ AddressFromURIString, InternalActorRef, Address, ActorRef } +import akka.remote.RemoteProtocol._ +import akka.remote.transport.AkkaPduCodec._ +import akka.remote.{ RemoteActorRefProvider, RemoteProtocol } +import akka.util.ByteString +import com.google.protobuf.InvalidProtocolBufferException + +class PduCodecException(msg: String, cause: Throwable) extends AkkaException(msg, cause) + +private[remote] object AkkaPduCodec { + + /** + * Trait that represents decoded Akka PDUs + */ + sealed trait AkkaPdu + + case class Associate(cookie: Option[String], origin: Address) extends AkkaPdu + case object Disassociate extends AkkaPdu + case object Heartbeat extends AkkaPdu + case class Message(recipient: InternalActorRef, + recipientAddress: Address, + serializedMessage: MessageProtocol, + sender: Option[ActorRef]) extends AkkaPdu +} + +/** + * A Codec that is able to convert Akka PDUs from and to [[akka.util.ByteString]]s. + */ +private[remote] trait AkkaPduCodec { + + def constructMessagePdu( + localAddress: Address, + recipient: ActorRef, + serializedMessage: MessageProtocol, + senderOption: Option[ActorRef]): ByteString + + def constructAssociate(cookie: Option[String], origin: Address): ByteString + + def constructDisassociate: ByteString + + def constructHeartbeat: ByteString + + def decodePdu(raw: ByteString, provider: RemoteActorRefProvider): AkkaPdu // Effective enough? + +} + +private[remote] object AkkaPduProtobufCodec extends AkkaPduCodec { + + override def constructMessagePdu( + localAddress: Address, + recipient: ActorRef, + serializedMessage: MessageProtocol, + senderOption: Option[ActorRef]): ByteString = { + + val messageBuilder = RemoteMessageProtocol.newBuilder + + messageBuilder.setRecipient(serializeActorRef(recipient.path.address, recipient)) + senderOption foreach { ref ⇒ messageBuilder.setSender(serializeActorRef(localAddress, ref)) } + messageBuilder.setMessage(serializedMessage) + + akkaRemoteProtocolToByteString(AkkaRemoteProtocol.newBuilder().setMessage(messageBuilder.build).build) + } + + override def constructAssociate(cookie: Option[String], origin: Address): ByteString = + constructControlMessagePdu(RemoteProtocol.CommandType.CONNECT, cookie, Some(origin)) + + override val constructDisassociate: ByteString = + constructControlMessagePdu(RemoteProtocol.CommandType.SHUTDOWN, None, None) + + override val constructHeartbeat: ByteString = + constructControlMessagePdu(RemoteProtocol.CommandType.HEARTBEAT, None, None) + + override def decodePdu(raw: ByteString, provider: RemoteActorRefProvider): AkkaPdu = { + try { + val pdu = AkkaRemoteProtocol.parseFrom(raw.toArray) + + if (pdu.hasMessage) { + decodeMessage(pdu.getMessage, provider) + } else if (pdu.hasInstruction) { + decodeControlPdu(pdu.getInstruction) + } else { + throw new PduCodecException("Error decoding Akka PDU: Neither message nor control message were contained", null) + } + } catch { + case e: InvalidProtocolBufferException ⇒ throw new PduCodecException("Decoding PDU failed.", e) + } + } + + private def decodeMessage(msgPdu: RemoteMessageProtocol, provider: RemoteActorRefProvider): Message = { + Message( + recipient = provider.actorFor(provider.rootGuardian, msgPdu.getRecipient.getPath), + recipientAddress = AddressFromURIString(msgPdu.getRecipient.getPath), + serializedMessage = msgPdu.getMessage, + sender = if (msgPdu.hasSender) Some(provider.actorFor(provider.rootGuardian, msgPdu.getSender.getPath)) else None) + } + + private def decodeControlPdu(controlPdu: RemoteControlProtocol): AkkaPdu = { + val cookie = if (controlPdu.hasCookie) Some(controlPdu.getCookie) else None + + controlPdu.getCommandType match { + case CommandType.CONNECT if controlPdu.hasOrigin ⇒ Associate(cookie, decodeAddress(controlPdu.getOrigin)) + case CommandType.SHUTDOWN ⇒ Disassociate + case CommandType.HEARTBEAT ⇒ Heartbeat + case _ ⇒ throw new PduCodecException("Decoding of control PDU failed: format invalid", null) + } + } + + private def decodeAddress(encodedAddress: AddressProtocol): Address = + Address(encodedAddress.getProtocol, encodedAddress.getSystem, encodedAddress.getHostname, encodedAddress.getPort) + + private def constructControlMessagePdu( + code: RemoteProtocol.CommandType, + cookie: Option[String], + origin: Option[Address]): ByteString = { + + val controlMessageBuilder = RemoteControlProtocol.newBuilder() + + controlMessageBuilder.setCommandType(code) + cookie foreach { controlMessageBuilder.setCookie(_) } + for (originAddress ← origin; serialized ← serializeAddress(originAddress)) + controlMessageBuilder.setOrigin(serialized) + + akkaRemoteProtocolToByteString(AkkaRemoteProtocol.newBuilder().setInstruction(controlMessageBuilder.build).build) + } + + private def akkaRemoteProtocolToByteString(pdu: AkkaRemoteProtocol): ByteString = ByteString(pdu.toByteArray) + + private def serializeActorRef(defaultAddress: Address, ref: ActorRef): ActorRefProtocol = { + val fullActorRefString: String = if (ref.path.address.host.isDefined) + ref.path.toString + else + ref.path.toStringWithAddress(defaultAddress) + + ActorRefProtocol.newBuilder.setPath(fullActorRefString).build() + } + + private def serializeAddress(address: Address): Option[AddressProtocol] = { + for (host ← address.host; port ← address.port) yield AddressProtocol.newBuilder + .setHostname(host) + .setPort(port) + .setSystem(address.system) + .setProtocol(address.protocol) + .build() + } + +} diff --git a/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala b/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala new file mode 100644 index 0000000000..d002792f0c --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/AkkaProtocolTransport.scala @@ -0,0 +1,548 @@ +package akka.remote.transport + +import akka.AkkaException +import akka.actor.SupervisorStrategy.Stop +import akka.actor._ +import akka.pattern.pipe +import akka.remote.transport.AkkaPduCodec._ +import akka.remote.transport.AkkaProtocolTransport._ +import akka.remote.transport.AssociationHandle._ +import akka.remote.transport.ProtocolStateActor._ +import akka.remote.transport.Transport._ +import akka.remote.{ PhiAccrualFailureDetector, FailureDetector, RemoteActorRefProvider } +import akka.util.ByteString +import com.typesafe.config.Config +import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS } +import scala.concurrent.{ Future, Promise } +import scala.util.control.NonFatal +import scala.util.{ Success, Failure } +import java.net.URLEncoder +import scala.collection.immutable.Queue + +class AkkaProtocolException(msg: String, cause: Throwable) extends AkkaException(msg, cause) + +private[remote] class AkkaProtocolSettings(config: Config) { + + import config._ + + val FailureDetectorThreshold: Double = getDouble("akka.remoting.failure-detector.threshold") + + val FailureDetectorMaxSampleSize: Int = getInt("akka.remoting.failure-detector.max-sample-size") + + val FailureDetectorStdDeviation: FiniteDuration = + Duration(getMilliseconds("akka.remoting.failure-detector.min-std-deviation"), MILLISECONDS) + + val AcceptableHeartBeatPause: FiniteDuration = + Duration(getMilliseconds("akka.remoting.failure-detector.acceptable-heartbeat-pause"), MILLISECONDS) + + val HeartBeatInterval: FiniteDuration = + Duration(getMilliseconds("akka.remoting.heartbeat-interval"), MILLISECONDS) + + val WaitActivityEnabled: Boolean = getBoolean("akka.remoting.wait-activity-enabled") + + val RequireCookie: Boolean = getBoolean("akka.remoting.require-cookie") + + val SecureCookie: String = getString("akka.remoting.secure-cookie") +} + +private[remote] object AkkaProtocolTransport { + val AkkaScheme: String = "akka" + val AkkaOverhead: Int = 0 //Don't know yet + val UniqueId = new java.util.concurrent.atomic.AtomicInteger(0) + + sealed trait TransportOperation + case class HandlerRegistered(handler: ActorRef) extends TransportOperation + case class AssociateUnderlying(remoteAddress: Address, statusPromise: Promise[Status]) extends TransportOperation + case class ListenUnderlying(listenPromise: Promise[(Address, Promise[ActorRef])]) extends TransportOperation + case object DisassociateUnderlying extends TransportOperation + + def augmentScheme(originalScheme: String): String = s"$originalScheme.$AkkaScheme" + + def augmentScheme(address: Address): Address = address.copy(protocol = augmentScheme(address.protocol)) + + def removeScheme(scheme: String): String = if (scheme.endsWith(s".$AkkaScheme")) + scheme.take(scheme.length - AkkaScheme.length - 1) + else scheme + + def removeScheme(address: Address): Address = address.copy(protocol = removeScheme(address.protocol)) +} + +/** + * Implementation of the Akka protocol as a Trasnsport that wraps an underlying Transport instance. + * + * Features provided by this transport are: + * - Soft-state associations via the use of heartbeats and failure detectors + * - Secure-cookie handling + * - Transparent origin address handling + * - Fire-And-Forget vs. implicit ack based handshake (controllable via wait-activity-enabled configuration option) + * - pluggable codecs to encode and decode Akka PDUs + * + * It is not possible to load this transport dynamically using the configuration of remoting, because it does not + * expose a constructor with [[com.typesafe.config.Config]] and [[akka.actor.ExtendedActorSystem]] parameters. + * This transport is instead loaded automatically by [[akka.remote.Remoting]] to wrap all the dynamically loaded + * transports. + * + * @param wrappedTransport + * the underlying transport that will be used for communication + * @param system + * the actor system + * @param settings + * the configuration options of the Akka protocol + * @param codec + * the codec that will be used to encode/decode Akka PDUs + */ +private[remote] class AkkaProtocolTransport( + private val wrappedTransport: Transport, + private val system: ActorSystem, + private val settings: AkkaProtocolSettings, + private val codec: AkkaPduCodec) extends Transport { + + override val schemeIdentifier: String = augmentScheme(wrappedTransport.schemeIdentifier) + + override def isResponsibleFor(address: Address): Boolean = wrappedTransport.isResponsibleFor(removeScheme(address)) + + //TODO: make this the child of someone more appropriate + private val manager = system.asInstanceOf[ActorSystemImpl].systemActorOf( + Props(new AkkaProtocolManager(wrappedTransport, settings)), + s"akkaprotocolmanager.${wrappedTransport.schemeIdentifier}${UniqueId.getAndIncrement}") + + override val maximumPayloadBytes: Int = wrappedTransport.maximumPayloadBytes - AkkaProtocolTransport.AkkaOverhead + + override def listen: Future[(Address, Promise[ActorRef])] = { + // Prepare a future, and pass its promise to the manager + val listenPromise: Promise[(Address, Promise[ActorRef])] = Promise() + + manager ! ListenUnderlying(listenPromise) + + listenPromise.future + } + + override def associate(remoteAddress: akka.actor.Address): Future[Status] = { + // Prepare a future, and pass its promise to the manager + val statusPromise: Promise[Status] = Promise() + + manager ! AssociateUnderlying(remoteAddress, statusPromise) + + statusPromise.future + } + + override def shutdown(): Unit = { + manager ! PoisonPill + } +} + +private[transport] class AkkaProtocolManager(private val wrappedTransport: Transport, + private val settings: AkkaProtocolSettings) extends Actor { + + import context.dispatcher + + // The AkkaProtocolTransport does not handle the recovery of associations, this task is implemented in the + // remoting itself. Hence the strategy Stop. + override val supervisorStrategy = OneForOneStrategy() { + case NonFatal(_) ⇒ Stop + } + + private var nextId = 0L + + private val associationHandlerPromise: Promise[ActorRef] = Promise() + associationHandlerPromise.future.map { HandlerRegistered(_) } pipeTo self + + @volatile var localAddress: Address = _ + + private var associationHandler: ActorRef = _ + + def receive: Receive = { + case ListenUnderlying(listenPromise) ⇒ + val listenFuture = wrappedTransport.listen + + // - Receive the address and promise from original transport + // - then register ourselves as listeners + // - then complete the exposed promise with the modified contents + listenFuture.onComplete { + case Success((address, wrappedTransportHandlerPromise)) ⇒ + // Register ourselves as the handler for the wrapped transport's listen call + wrappedTransportHandlerPromise.success(self) + localAddress = address + // Pipe the result to the original caller + listenPromise.success((augmentScheme(address), associationHandlerPromise)) + case Failure(reason) ⇒ listenPromise.failure(reason) + } + + case HandlerRegistered(handler) ⇒ + associationHandler = handler + context.become(ready) + + // Block inbound associations until handler is registered + case InboundAssociation(handle) ⇒ handle.disassociate() + } + + private def actorNameFor(remoteAddress: Address): String = { + nextId += 1 + "akkaProtocol-" + URLEncoder.encode(remoteAddress.toString, "utf-8") + "-" + nextId + } + + private def ready: Receive = { + case InboundAssociation(handle) ⇒ + context.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + associationHandler, + settings, + AkkaPduProtobufCodec, + createFailureDetector())), actorNameFor(handle.remoteAddress)) + + case AssociateUnderlying(remoteAddress, statusPromise) ⇒ + context.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + wrappedTransport, + settings, + AkkaPduProtobufCodec, + createFailureDetector())), actorNameFor(remoteAddress)) + } + + private def createFailureDetector(): PhiAccrualFailureDetector = new PhiAccrualFailureDetector( + settings.FailureDetectorThreshold, + settings.FailureDetectorMaxSampleSize, + settings.FailureDetectorStdDeviation, + settings.AcceptableHeartBeatPause, + settings.HeartBeatInterval) + + override def postStop() { + wrappedTransport.shutdown() + } + +} + +private[transport] class AkkaProtocolHandle( + val localAddress: Address, + val remoteAddress: Address, + val readHandlerPromise: Promise[ActorRef], + private val wrappedHandle: AssociationHandle, + private val stateActor: ActorRef, + private val codec: AkkaPduCodec) + extends AssociationHandle { + + // FIXME: This is currently a hack! The caller should not know anything about the format of the Akka protocol + // but here it does. This is temporary and will be fixed. + override def write(payload: ByteString): Boolean = wrappedHandle.write(payload) + + override def disassociate(): Unit = stateActor ! DisassociateUnderlying + +} + +private[transport] object ProtocolStateActor { + sealed trait AssociationState + case object Closed extends AssociationState + case object WaitActivity extends AssociationState + case object Open extends AssociationState + + case object HeartbeatTimer + + sealed trait ProtocolStateData + trait InitialProtocolStateData extends ProtocolStateData + + // Nor the underlying, nor the provided transport is associated + case class OutboundUnassociated(remoteAddress: Address, statusPromise: Promise[Status], transport: Transport) + extends InitialProtocolStateData + + // The underlying transport is associated, but the handshake of the akka protocol is not yet finished + case class OutboundUnderlyingAssociated(statusPromise: Promise[Status], wrappedHandle: AssociationHandle) + extends ProtocolStateData + + // The underlying transport is associated, but the handshake of the akka protocol is not yet finished + case class InboundUnassociated(associationHandler: ActorRef, wrappedHandle: AssociationHandle) + extends InitialProtocolStateData + + // Both transports are associated, but the handler for the handle has not yet been provided + case class AssociatedWaitHandler(handlerFuture: Future[ActorRef], wrappedHandle: AssociationHandle, queue: Queue[ByteString]) + extends ProtocolStateData + + case class HandlerReady(handler: ActorRef, wrappedHandle: AssociationHandle) + extends ProtocolStateData + + case object TimeoutReason +} + +private[transport] class ProtocolStateActor(initialData: InitialProtocolStateData, + private val localAddress: Address, + private val settings: AkkaProtocolSettings, + private val codec: AkkaPduCodec, + private val failureDetector: FailureDetector) + extends Actor with FSM[AssociationState, ProtocolStateData] { + + import ProtocolStateActor._ + import context.dispatcher + + // Outbound case + def this(localAddress: Address, + remoteAddress: Address, + statusPromise: Promise[Status], + transport: Transport, + settings: AkkaProtocolSettings, + codec: AkkaPduCodec, + failureDetector: FailureDetector) = { + this(OutboundUnassociated(remoteAddress, statusPromise, transport), localAddress, settings, codec, failureDetector) + } + + // Inbound case + def this(localAddress: Address, + wrappedHandle: AssociationHandle, + associationHandler: ActorRef, + settings: AkkaProtocolSettings, + codec: AkkaPduCodec, + failureDetector: FailureDetector) = { + this(InboundUnassociated(associationHandler, wrappedHandle), localAddress, settings, codec, failureDetector) + } + + // FIXME: This may break with ClusterActorRefProvider if it does not extends RemoteActorRefProvider + val provider = context.system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider] + + initialData match { + case d: OutboundUnassociated ⇒ + d.transport.associate(removeScheme(d.remoteAddress)) pipeTo self + startWith(Closed, d) + + case d: InboundUnassociated ⇒ + d.wrappedHandle.readHandlerPromise.success(self) + startWith(Closed, d) + } + + when(Closed) { + + // Transport layer events for outbound associations + case Event(s @ Invalid(_), OutboundUnassociated(_, statusPromise, _)) ⇒ + statusPromise.success(s) + stop() + + case Event(s @ Fail(_), OutboundUnassociated(_, statusPromise, _)) ⇒ + statusPromise.success(s) + stop() + + case Event(Ready(wrappedHandle), OutboundUnassociated(_, statusPromise, _)) ⇒ + wrappedHandle.readHandlerPromise.success(self) + sendAssociate(wrappedHandle) + failureDetector.heartbeat() + initTimers() + + if (settings.WaitActivityEnabled) { + goto(WaitActivity) using OutboundUnderlyingAssociated(statusPromise, wrappedHandle) + } else { + goto(Open) using AssociatedWaitHandler(notifyOutboundHandler(wrappedHandle, statusPromise), wrappedHandle, Queue.empty) + } + + // Events for inbound associations + case Event(InboundPayload(p), InboundUnassociated(associationHandler, wrappedHandle)) ⇒ + decodePdu(p) match { + // After receiving Disassociate we MUST NOT send back a Disassociate (loop) + case Disassociate ⇒ stop() + + // Incoming association -- implicitly ACK by a heartbeat + case Associate(cookieOption, origin) ⇒ + if (!settings.RequireCookie || cookieOption.getOrElse("") == settings.SecureCookie) { + sendHeartbeat(wrappedHandle) + + failureDetector.heartbeat() + initTimers() + goto(Open) using AssociatedWaitHandler(notifyInboundHandler(wrappedHandle, origin, associationHandler), wrappedHandle, Queue.empty) + } else { + stop() + } + + // Got a stray message -- explicitly reset the association (force remote endpoint to reassociate) + case _ ⇒ + sendDisassociate(wrappedHandle) + stop() + + } + + case Event(DisassociateUnderlying, _) ⇒ + stop() + + case _ ⇒ stay() + + } + + // Timeout of this state is implicitly handled by the failure detector + when(WaitActivity) { + case Event(Disassociated, OutboundUnderlyingAssociated(_, _)) ⇒ + stop() + + case Event(InboundPayload(p), OutboundUnderlyingAssociated(statusPromise, wrappedHandle)) ⇒ + decodePdu(p) match { + case Disassociate ⇒ + stop() + + // Any other activity is considered an implicit acknowledgement of the association + case Message(recipient, recipientAddress, serializedMessage, senderOption) ⇒ + sendHeartbeat(wrappedHandle) + goto(Open) using + AssociatedWaitHandler(notifyOutboundHandler(wrappedHandle, statusPromise), wrappedHandle, Queue(p)) + + case Heartbeat ⇒ + sendHeartbeat(wrappedHandle) + failureDetector.heartbeat() + goto(Open) using + AssociatedWaitHandler(notifyOutboundHandler(wrappedHandle, statusPromise), wrappedHandle, Queue.empty) + + case _ ⇒ goto(Open) using + AssociatedWaitHandler(notifyOutboundHandler(wrappedHandle, statusPromise), wrappedHandle, Queue.empty) + } + + case Event(HeartbeatTimer, OutboundUnderlyingAssociated(_, wrappedHandle)) ⇒ handleTimers(wrappedHandle) + + } + + when(Open) { + case Event(Disassociated, _) ⇒ + stop() + + case Event(InboundPayload(p), AssociatedWaitHandler(handlerFuture, wrappedHandle, queue)) ⇒ + decodePdu(p) match { + case Disassociate ⇒ + stop() + + case Heartbeat ⇒ failureDetector.heartbeat(); stay() + + case Message(recipient, recipientAddress, serializedMessage, senderOption) ⇒ + // Queue message until handler is registered + stay() using AssociatedWaitHandler(handlerFuture, wrappedHandle, queue :+ p) + + case _ ⇒ stay() + } + + case Event(InboundPayload(p), HandlerReady(handler, _)) ⇒ + decodePdu(p) match { + case Disassociate ⇒ + stop() + + case Heartbeat ⇒ failureDetector.heartbeat(); stay() + + case Message(recipient, recipientAddress, serializedMessage, senderOption) ⇒ + handler ! InboundPayload(p) + stay() + + case _ ⇒ stay() + } + + case Event(HeartbeatTimer, AssociatedWaitHandler(_, wrappedHandle, _)) ⇒ handleTimers(wrappedHandle) + case Event(HeartbeatTimer, HandlerReady(_, wrappedHandle)) ⇒ handleTimers(wrappedHandle) + + case Event(DisassociateUnderlying, HandlerReady(handler, wrappedHandle)) ⇒ + sendDisassociate(wrappedHandle) + stop() + + case Event(HandlerRegistered(ref), AssociatedWaitHandler(_, wrappedHandle, queue)) ⇒ + queue.foreach { ref ! InboundPayload(_) } + stay() using HandlerReady(ref, wrappedHandle) + } + + private def initTimers(): Unit = { + setTimer("heartbeat-timer", HeartbeatTimer, settings.HeartBeatInterval, repeat = true) + } + + private def handleTimers(wrappedHandle: AssociationHandle): State = { + if (failureDetector.isAvailable) { + sendHeartbeat(wrappedHandle) + stay() + } else { + // send disassociate just to be sure + sendDisassociate(wrappedHandle) + stop(FSM.Failure(TimeoutReason)) + } + } + + override def postStop(): Unit = { + cancelTimer("heartbeat-timer") + super.postStop() // Pass to onTermination + } + + onTermination { + case StopEvent(_, _, OutboundUnassociated(remoteAddress, statusPromise, transport)) ⇒ + statusPromise.trySuccess(Fail(new AkkaProtocolException("Transport disassociated before handshake finished", null))) + + case StopEvent(reason, _, OutboundUnderlyingAssociated(statusPromise, wrappedHandle)) ⇒ + val msg = reason match { + case FSM.Failure(TimeoutReason) ⇒ "No response from remote. Handshake timed out." + case _ ⇒ "Remote endpoint disassociated before handshake finished" + } + statusPromise.trySuccess(Fail(new AkkaProtocolException(msg, null))) + wrappedHandle.disassociate() + + case StopEvent(_, _, AssociatedWaitHandler(handlerFuture, wrappedHandle, queue)) ⇒ + // Invalidate exposed but still unfinished promise. The underlying association disappeared, so after + // registration immediately signal a disassociate + handlerFuture.onSuccess { + case handler: ActorRef ⇒ handler ! Disassociated + } + + case StopEvent(_, _, HandlerReady(handler, wrappedHandle)) ⇒ + handler ! Disassociated + wrappedHandle.disassociate() + + case StopEvent(_, _, InboundUnassociated(_, wrappedHandle)) ⇒ + wrappedHandle.disassociate() + } + + private def notifyOutboundHandler(wrappedHandle: AssociationHandle, statusPromise: Promise[Status]): Future[ActorRef] = { + val readHandlerPromise: Promise[ActorRef] = Promise() + readHandlerPromise.future.map { HandlerRegistered(_) } pipeTo self + + val exposedHandle = + new AkkaProtocolHandle( + augmentScheme(localAddress), + augmentScheme(wrappedHandle.remoteAddress), + readHandlerPromise, + wrappedHandle, + self, + codec) + + statusPromise.success(Ready(exposedHandle)) + readHandlerPromise.future + } + + private def notifyInboundHandler(wrappedHandle: AssociationHandle, originAddress: Address, associationHandler: ActorRef): Future[ActorRef] = { + val readHandlerPromise: Promise[ActorRef] = Promise() + readHandlerPromise.future.map { HandlerRegistered(_) } pipeTo self + + val exposedHandle = + new AkkaProtocolHandle( + augmentScheme(localAddress), + augmentScheme(originAddress), + readHandlerPromise, + wrappedHandle, + self, + codec) + + associationHandler ! InboundAssociation(exposedHandle) + readHandlerPromise.future + } + + // Helper method for exception translation + private def ape[T](errorMsg: String)(body: ⇒ T): T = try body catch { + case NonFatal(e) ⇒ throw new AkkaProtocolException(errorMsg, e) + } + + private def decodePdu(pdu: ByteString): AkkaPdu = ape("Error while decoding incoming Akka PDU of length: " + pdu.length) { + codec.decodePdu(pdu, provider) + } + + // Neither heartbeats neither disassociate cares about backing off if write fails: + // - Missing heartbeats are not critical + // - Disassociate messages are not guaranteed anyway + private def sendHeartbeat(wrappedHandle: AssociationHandle): Unit = ape("Error writing HEARTBEAT to transport") { + wrappedHandle.write(codec.constructHeartbeat) + } + + private def sendDisassociate(wrappedHandle: AssociationHandle): Unit = ape("Error writing DISASSOCIATE to transport") { + wrappedHandle.write(codec.constructDisassociate) + } + + // Associate should be the first message, so backoff is not needed + private def sendAssociate(wrappedHandle: AssociationHandle): Unit = ape("Error writing ASSOCIATE to transport") { + val cookie = if (settings.RequireCookie) Some(settings.SecureCookie) else None + wrappedHandle.write(codec.constructAssociate(cookie, localAddress)) + } + +} diff --git a/akka-remote/src/main/scala/akka/remote/transport/Transport.scala b/akka-remote/src/main/scala/akka/remote/transport/Transport.scala new file mode 100644 index 0000000000..3ef6908f49 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/Transport.scala @@ -0,0 +1,207 @@ +package akka.remote.transport + +import concurrent.{ Promise, Future } +import akka.actor.{ ActorRef, Address } +import akka.util.ByteString + +object Transport { + + /** + * Represents fine grained status of an association attempt. + */ + sealed trait Status + + /** + * Indicates that the association setup request is invalid, and it is impossible to recover (malformed IP address, + * hostname, etc.). Invalid association requests are impossible to recover. + */ + case class Invalid(cause: Throwable) extends Status + + /** + * The association setup has failed, but no information can be provided about the probability of the success of a + * setup retry. + * + * @param cause Cause of the failure + */ + case class Fail(cause: Throwable) extends Status + + /** + * No detectable errors happened during association. Generally a status of Ready does not guarantee that the + * association was successful. For example in the case of UDP, the transport MAY return Ready immediately after an + * association setup was requested. + * + * @param association + * The handle for the created association. + */ + case class Ready(association: AssociationHandle) extends Status + + /** + * Message sent to an actor registered to a transport (via the Promise returned by + * [[akka.remote.transport.Transport.listen]]) when an inbound association request arrives. + * + * @param association + * The handle for the inbound association. + */ + case class InboundAssociation(association: AssociationHandle) + +} + +/** + * An SPI layer for implementing asynchronous transport mechanisms. The transport is responsible for initializing the + * underlying transport mechanism and setting up logical links between transport entities. + * + * Transport implementations that are loaded dynamically by the remoting must have a constructor that accepts a + * [[com.typesafe.config.Config]] and an [[akka.actor.ExtendedActorSystem]] as parameters. + */ +trait Transport { + import akka.remote.transport.Transport._ + + /** + * Returns a string that will be used as the scheme part of the URLs corresponding to this transport + * @return the scheme string + */ + def schemeIdentifier: String + + /** + * A function that decides whether the specific transport instance is responsible for delivering + * to a given address. The function must be thread-safe and non-blocking. + * + * The purpose of this function is to resolve cases when the scheme part of an URL is not enough to resolve + * the correct transport i.e. multiple instances of the same transport implementation are loaded. These cases arise when + * - the same transport, but with different configurations is used for different remote systems + * - a transport is able to serve one address only (hardware protocols, e.g. Serial port) and multiple + * instances are needed to be loaded for different endpoints. + * + * @return whether the transport instance is responsible to serve communications to the given address. + */ + def isResponsibleFor(address: Address): Boolean + + /** + * Defines the maximum size of payload this transport is able to deliver. All transports MUST support at least + * 32kBytes (32000 octets) of payload, but some MAY support larger sizes. + * @return + */ + def maximumPayloadBytes: Int + + /** + * Asynchronously attempts to setup the transport layer to listen and accept incoming associations. The result of the + * attempt is wrapped by a Future returned by this method. The pair contained in the future contains a Promise for an + * ActorRef. By completing this Promise with an ActorRef, that ActorRef becomes responsible for handling incoming + * associations. Until the Promise is not completed, no associations are processed. + * + * @return + * A Future containing a pair of the bound local address and a Promise of an ActorRef that must be fulfilled + * by the consumer of the future. + */ + def listen: Future[(Address, Promise[ActorRef])] + + /** + * Asynchronously opens a logical duplex link between two Transport Entities over a network. It could be backed by a + * real transport-layer connection (TCP), more lightweight connections provided over datagram protocols (UDP with + * additional services), substreams of multiplexed connections (SCTP) or physical links (serial port). + * + * This call returns a fine-grained status indication of the attempt wrapped in a Future. See + * [[akka.remote.transport.Transport.Status]] for details. + * + * @param remoteAddress + * The address of the remote transport entity. + * @return + * A status instance representing failure or a success containing an [[akka.remote.transport.AssociationHandle]] + */ + def associate(remoteAddress: Address): Future[Status] + + /** + * Shuts down the transport layer and releases all the corresponding resources. Shutdown is asynchronous, may be + * called multiple times and does not return a success indication. + * + * The transport SHOULD try flushing pending writes before becoming completely closed. + */ + def shutdown(): Unit + +} + +object AssociationHandle { + + /** + * Trait for events that the registered actor for an [[akka.remote.transport.AssociationHandle]] might receive. + */ + sealed trait AssociationEvent + + /** + * Message sent to the actor registered to an association (via the Promise returned by + * [[akka.remote.transport.AssociationHandle.readHandlerPromise]]) when an inbound payload arrives. + * + * @param payload + * The raw bytes that were sent by the remote endpoint. + */ + case class InboundPayload(payload: ByteString) extends AssociationEvent + + /** + * Message sent to te actor registered to an association + */ + case object Disassociated extends AssociationEvent + +} + +/** + * An SPI layer for abstracting over logical links (associations) created by [[akka.remote.transport.Transport]]. + * Handles are responsible for providing an API for sending and receiving from the underlying channel. + * + * To register an actor for processing incoming payload data, the actor must be registered by completing the Promise + * returned by [[akka.remote.transport.AssociationHandle#readHandlerPromise]]. Incoming data is not processed until + * this registration takes place. + */ +trait AssociationHandle { + + /** + * Address of the local endpoint. + * + * @return + * Address of the local endpoint. + */ + def localAddress: Address + + /** + * Address of the remote endpoint. + * + * @return + * Address of the remote endpoint. + */ + def remoteAddress: Address + + /** + * The Promise returned by this call must be completed with an [[akka.actor.ActorRef]] to register an actor + * responsible for handling incoming payload. + * + * @return + * Promise of the ActorRef of the actor responsible for handling incoming data. + */ + def readHandlerPromise: Promise[ActorRef] + + /** + * Asynchronously sends the specified payload to the remote endpoint. This method must be thread-safe as it might + * be called from different threads. This method must not block. + * + * Writes guarantee ordering of messages, but not their reception. The call to write returns with + * a Boolean indicating if the channel was ready for writes or not. A return value of false indicates that the + * channel is not yet ready for delivery (e.g.: the write buffer is full) and the sender needs to wait + * until the channel becomes ready again. Returning false also means that the current write was dropped (this is + * guaranteed to ensure duplication-free delivery). + * + * @param payload + * The payload to be delivered to the remote endpoint. + * @return + * Boolean indicating the availability of the association for subsequent writes. + */ + def write(payload: ByteString): Boolean + + /** + * Closes the underlying transport link, if needed. Some transport may not need an explicit teardown (UDP) and + * some transports may not support it (hardware connections). Remote endpoint of the channel or connection ''may'' + * be notified, but this is not guaranteed. + * + */ + def disassociate(): Unit + +} + diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/NettyHelpers.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyHelpers.scala new file mode 100644 index 0000000000..0c768e9902 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyHelpers.scala @@ -0,0 +1,78 @@ +package akka.remote.transport.netty + +import akka.AkkaException +import java.nio.channels.ClosedChannelException +import org.jboss.netty.channel._ +import scala.util.control.NonFatal + +private[netty] trait NettyHelpers { + + protected def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {} + + protected def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {} + + protected def onOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {} + + protected def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {} + + protected def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {} + + final protected def transformException(ctx: ChannelHandlerContext, ev: ExceptionEvent): Unit = { + val cause = if (ev.getCause ne null) ev.getCause else new AkkaException("Unknown cause") + cause match { + case _: ClosedChannelException ⇒ // Ignore + case NonFatal(e) ⇒ onException(ctx, ev) + case e: Throwable ⇒ throw e // Rethrow fatals + } + } +} + +private[netty] trait NettyServerHelpers extends SimpleChannelUpstreamHandler with NettyHelpers { + + final override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = { + super.messageReceived(ctx, e) + onMessage(ctx, e) + } + + final override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = transformException(ctx, e) + + final override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelConnected(ctx, e) + onConnect(ctx, e) + } + + final override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelOpen(ctx, e) + onOpen(ctx, e) + } + + final override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelDisconnected(ctx, e) + onDisconnect(ctx, e) + } +} + +private[netty] trait NettyClientHelpers extends SimpleChannelHandler with NettyHelpers { + final override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = { + super.messageReceived(ctx, e) + onMessage(ctx, e) + } + + final override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = transformException(ctx, e) + + final override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelConnected(ctx, e) + onConnect(ctx, e) + } + + final override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelOpen(ctx, e) + onOpen(ctx, e) + } + + final override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = { + super.channelDisconnected(ctx, e) + onDisconnect(ctx, e) + } +} + diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala new file mode 100644 index 0000000000..b83c49df1d --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala @@ -0,0 +1,366 @@ +package akka.remote.transport.netty + +import akka.ConfigurationException +import akka.actor.{ Address, ExtendedActorSystem, ActorRef } +import akka.event.Logging +import akka.remote.netty.{ SslSettings, NettySSLSupport, DefaultDisposableChannelGroup } +import akka.remote.transport.Transport._ +import akka.remote.transport.netty.NettyTransportSettings.{ Udp, Tcp, Mode } +import akka.remote.transport.{ AssociationHandle, Transport } +import com.typesafe.config.Config +import java.net.{ UnknownHostException, SocketAddress, InetAddress, InetSocketAddress } +import java.util.concurrent.{ ConcurrentHashMap, Executor, Executors } +import org.jboss.netty.bootstrap.{ ConnectionlessBootstrap, Bootstrap, ClientBootstrap, ServerBootstrap } +import org.jboss.netty.buffer.ChannelBuffer +import org.jboss.netty.channel._ +import org.jboss.netty.channel.group.{ ChannelGroupFuture, ChannelGroupFutureListener } +import org.jboss.netty.channel.socket.nio.{ NioDatagramChannelFactory, NioServerSocketChannelFactory, NioClientSocketChannelFactory } +import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender } +import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS } +import scala.concurrent.{ ExecutionContext, Promise, Future } +import scala.util.Random +import scala.util.control.NonFatal + +object NettyTransportSettings { + sealed trait Mode + case object Tcp extends Mode { override def toString = "tcp" } + case object Udp extends Mode { override def toString = "udp" } +} + +class NettyTransportException(msg: String, cause: Throwable) extends RuntimeException(msg, cause) + +class NettyTransportSettings(config: Config) { + + import config._ + + val TransportMode: Mode = getString("transport-protocol") match { + case "tcp" ⇒ Tcp + case "udp" ⇒ Udp + case s @ _ ⇒ throw new ConfigurationException("Unknown transport: " + s) + } + + val EnableSsl: Boolean = if (getBoolean("enable-ssl") && TransportMode == Udp) + throw new ConfigurationException("UDP transport does not support SSL") + else getBoolean("enable-ssl") + + val UseDispatcherForIo: Option[String] = getString("use-dispatcher-for-io") match { + case "" | null ⇒ None + case dispatcher ⇒ Some(dispatcher) + } + + private[this] def optionSize(s: String): Option[Int] = getBytes(s).toInt match { + case 0 ⇒ None + case x if x < 0 ⇒ + throw new ConfigurationException(s"Setting '$s' must be 0 or positive (and fit in an Int)") + case other ⇒ Some(other) + } + + val ConnectionTimeout: FiniteDuration = Duration(getMilliseconds("connection-timeout"), MILLISECONDS) + + val WriteBufferHighWaterMark: Option[Int] = optionSize("write-buffer-high-water-mark") + + val WriteBufferLowWaterMark: Option[Int] = optionSize("write-buffer-low-water-mark") + + val SendBufferSize: Option[Int] = optionSize("send-buffer-size") + + val ReceiveBufferSize: Option[Int] = optionSize("receive-buffer-size") + + val Backlog: Int = getInt("backlog") + + val Hostname: String = getString("hostname") match { + case "" ⇒ InetAddress.getLocalHost.getHostAddress + case value ⇒ value + } + + @deprecated("WARNING: This should only be used by professionals.", "2.0") + val PortSelector: Int = getInt("port") + + val SslSettings: Option[SslSettings] = if (EnableSsl) Some(new SslSettings(config.getConfig("ssl"))) else None + +} + +trait HasTransport { + protected val transport: NettyTransport +} + +trait CommonHandlers extends NettyHelpers with HasTransport { + import transport.executionContext + + final override def onOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = transport.channels.add(e.getChannel) + + protected def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle + + protected def registerReader(channel: Channel, readerRef: ActorRef, msg: ChannelBuffer, remoteSocketAddress: InetSocketAddress): Unit + + final protected def init(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer)(op: (AssociationHandle ⇒ Any)): Unit = { + import transport._ + (addressFromSocketAddress(channel.getLocalAddress), addressFromSocketAddress(remoteSocketAddress)) match { + case (Some(localAddress), Some(remoteAddress)) ⇒ + val handle = createHandle(channel, localAddress, remoteAddress) + handle.readHandlerPromise.future.onSuccess { + case readerRef: ActorRef ⇒ + registerReader(channel, readerRef, msg, remoteSocketAddress.asInstanceOf[InetSocketAddress]) + channel.setReadable(true) + } + op(handle) + + case _ ⇒ NettyTransport.gracefulClose(channel) + } + } +} + +abstract class ServerHandler(protected final val transport: NettyTransport, + private final val associationHandlerFuture: Future[ActorRef]) + extends NettyServerHelpers with CommonHandlers with HasTransport { + import transport.executionContext + + final protected def initInbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = { + channel.setReadable(false) + associationHandlerFuture.onSuccess { + case ref: ActorRef ⇒ init(channel, remoteSocketAddress, msg) { ref ! InboundAssociation(_) } + } + } + +} + +abstract class ClientHandler(protected final val transport: NettyTransport, + private final val statusPromise: Promise[Status]) + extends NettyClientHelpers with CommonHandlers with HasTransport { + + final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = { + channel.setReadable(false) + init(channel, remoteSocketAddress, msg) { handle ⇒ statusPromise.success(Ready(handle)) } + } + +} + +private[transport] object NettyTransport { + val FrameLengthFieldLength = 4 + def gracefulClose(channel: Channel): Unit = channel.disconnect().addListener(ChannelFutureListener.CLOSE) + +} + +class NettyTransport(private val settings: NettyTransportSettings, private val system: ExtendedActorSystem) extends Transport { + + def this(system: ExtendedActorSystem, conf: Config) = this(new NettyTransportSettings(conf), system) + + import NettyTransport._ + import settings._ + implicit val executionContext: ExecutionContext = system.dispatcher + + override val schemeIdentifier: String = TransportMode + (if (EnableSsl) ".ssl" else "") + override val maximumPayloadBytes: Int = 32000 + + private final val isDatagram: Boolean = TransportMode == Udp + + @volatile private var localAddress: Address = _ + @volatile private var masterChannel: Channel = _ + + private val log = Logging(system, this.getClass) + + final val udpConnectionTable = new ConcurrentHashMap[SocketAddress, ActorRef]() + + val channels = new DefaultDisposableChannelGroup("netty-transport-" + Random.nextString(20)) + + private def executor: Executor = UseDispatcherForIo match { + case Some(dispatcherName) ⇒ system.dispatchers.lookup(dispatcherName) + case None ⇒ Executors.newCachedThreadPool() // FIXME: apply patch from #2659 when available + } + + private val clientChannelFactory: ChannelFactory = TransportMode match { + case Tcp ⇒ new NioClientSocketChannelFactory(executor, executor) + case Udp ⇒ new NioDatagramChannelFactory(executor) + } + + private val serverChannelFactory: ChannelFactory = TransportMode match { + case Tcp ⇒ new NioServerSocketChannelFactory(executor, executor) + case Udp ⇒ new NioDatagramChannelFactory(executor) + } + + private def newPipeline: DefaultChannelPipeline = { + val pipeline = new DefaultChannelPipeline + + if (!isDatagram) { + pipeline.addLast("FrameDecoder", new LengthFieldBasedFrameDecoder( + maximumPayloadBytes, + 0, + FrameLengthFieldLength, + 0, + FrameLengthFieldLength, // Strip the header + true)) + pipeline.addLast("FrameEncoder", new LengthFieldPrepender(FrameLengthFieldLength)) + } + + pipeline + } + + private val associationHandlerPromise: Promise[ActorRef] = Promise() + private val serverPipelineFactory: ChannelPipelineFactory = new ChannelPipelineFactory { + override def getPipeline: ChannelPipeline = { + val pipeline = newPipeline + if (EnableSsl) pipeline.addFirst("SslHandler", NettySSLSupport(settings.SslSettings.get, log, false)) + val handler = if (isDatagram) new UdpServerHandler(NettyTransport.this, associationHandlerPromise.future) + else new TcpServerHandler(NettyTransport.this, associationHandlerPromise.future) + pipeline.addLast("ServerHandler", handler) + pipeline + } + } + + private def clientPipelineFactory(statusPromise: Promise[Status]): ChannelPipelineFactory = new ChannelPipelineFactory { + override def getPipeline: ChannelPipeline = { + val pipeline = newPipeline + if (EnableSsl) pipeline.addFirst("SslHandler", NettySSLSupport(settings.SslSettings.get, log, true)) + val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, statusPromise) + else new TcpClientHandler(NettyTransport.this, statusPromise) + pipeline.addLast("clienthandler", handler) + pipeline + } + } + + private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = { + bootstrap.setPipelineFactory(pipelineFactory) + bootstrap.setOption("backlog", settings.Backlog) + bootstrap.setOption("tcpNoDelay", true) + bootstrap.setOption("child.keepAlive", true) + bootstrap.setOption("reuseAddress", true) + if (isDatagram) bootstrap.setOption("receiveBufferSizePredictorFactory", new FixedReceiveBufferSizePredictorFactory(ReceiveBufferSize.get)) + settings.ReceiveBufferSize.foreach(sz ⇒ bootstrap.setOption("receiveBufferSize", sz)) + settings.SendBufferSize.foreach(sz ⇒ bootstrap.setOption("sendBufferSize", sz)) + settings.WriteBufferHighWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferHighWaterMark", sz)) + settings.WriteBufferLowWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferLowWaterMark", sz)) + bootstrap + } + + private val inboundBootstrap: Bootstrap = settings.TransportMode match { + case Tcp ⇒ setupBootstrap(new ServerBootstrap(serverChannelFactory), serverPipelineFactory) + case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory) + } + + private def outboundBootstrap(statusPromise: Promise[Status]): ClientBootstrap = { + val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(statusPromise)) + bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis) + bootstrap + } + + override def isResponsibleFor(address: Address): Boolean = true //TODO: Add configurable subnet filtering + + def addressFromSocketAddress(addr: SocketAddress, + systemName: Option[String] = None, + hostName: Option[String] = None): Option[Address] = { + addr match { + case sa: InetSocketAddress ⇒ + Some(Address(schemeIdentifier, systemName.getOrElse(""), hostName.getOrElse(sa.getHostString), sa.getPort)) + + case _ ⇒ None + } + } + + def addressToSocketAddress(addr: Address): InetSocketAddress = + new InetSocketAddress(InetAddress.getByName(addr.host.get), addr.port.get) + + override def listen: Future[(Address, Promise[ActorRef])] = { + val listenPromise: Promise[(Address, Promise[ActorRef])] = Promise() + + try { + masterChannel = inboundBootstrap match { + case b: ServerBootstrap ⇒ b.bind(new InetSocketAddress(InetAddress.getByName(settings.Hostname), settings.PortSelector)) + case b: ConnectionlessBootstrap ⇒ + b.bind(new InetSocketAddress(InetAddress.getByName(settings.Hostname), settings.PortSelector)) + } + + // Block reads until a handler actor is registered + masterChannel.setReadable(false) + channels.add(masterChannel) + + addressFromSocketAddress(masterChannel.getLocalAddress, Some(system.name), Some(settings.Hostname)) match { + case Some(address) ⇒ + val handlerPromise: Promise[ActorRef] = Promise() + listenPromise.success((address, handlerPromise)) + localAddress = address + handlerPromise.future.onSuccess { + case ref: ActorRef ⇒ + associationHandlerPromise.success(ref) + masterChannel.setReadable(true) + } + + case None ⇒ + listenPromise.failure( + new NettyTransportException(s"Unknown local address type ${masterChannel.getLocalAddress.getClass}", null)) + } + + } catch { + case NonFatal(e) ⇒ listenPromise.failure(e) + } + + listenPromise.future + } + + override def associate(remoteAddress: Address): Future[Status] = { + val statusPromise: Promise[Status] = Promise() + + if (!masterChannel.isBound) statusPromise.success(Fail(new NettyTransportException("Transport is not bound", null))) + + try { + if (!isDatagram) { + val connectFuture = outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress)) + + connectFuture.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) + statusPromise.failure(future.getCause) + else if (future.isCancelled) + statusPromise.failure(new NettyTransportException("Connection was cancelled", null)) + + } + }) + + } else { + val connectFuture = outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress)) + + connectFuture.addListener(new ChannelFutureListener { + def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) + statusPromise.failure(future.getCause) + else if (future.isCancelled) + statusPromise.failure(new NettyTransportException("Connection was cancelled", null)) + else { + val handle: UdpAssociationHandle = new UdpAssociationHandle(localAddress, remoteAddress, future.getChannel, NettyTransport.this) + + future.getChannel.getRemoteAddress match { + case addr: InetSocketAddress ⇒ + statusPromise.success(Ready(handle)) + handle.readHandlerPromise.future.onSuccess { + case ref: ActorRef ⇒ udpConnectionTable.put(addr, ref) + } + case a @ _ ⇒ statusPromise.success(Fail( + new NettyTransportException("Unknown remote address type " + a.getClass, null))) + } + } + } + }) + } + + } catch { + + case e @ (_: UnknownHostException | _: SecurityException | _: IllegalArgumentException) ⇒ + statusPromise.success(Invalid(e)) + + case NonFatal(e) ⇒ + statusPromise.success(Fail(e)) + } + + statusPromise.future + } + + override def shutdown(): Unit = { + channels.unbind() + channels.disconnect().addListener(new ChannelGroupFutureListener { + def operationComplete(future: ChannelGroupFuture) { + channels.close() + inboundBootstrap.releaseExternalResources() + } + }) + } + +} + diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala new file mode 100644 index 0000000000..4bc4920e0c --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala @@ -0,0 +1,73 @@ +package akka.remote.transport.netty + +import akka.actor.{ Address, ActorRef } +import akka.remote.transport.AssociationHandle +import akka.remote.transport.AssociationHandle.{ Disassociated, InboundPayload } +import akka.remote.transport.Transport.Status +import akka.util.ByteString +import java.net.InetSocketAddress +import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer } +import org.jboss.netty.channel._ +import scala.concurrent.{ Future, Promise } + +object ChannelLocalActor extends ChannelLocal[Option[ActorRef]] { + override def initialValue(channel: Channel): Option[ActorRef] = None + def trySend(channel: Channel, msg: Any): Unit = get(channel) foreach { _ ! msg } +} + +trait TcpHandlers extends CommonHandlers with HasTransport { + + import ChannelLocalActor._ + + override def registerReader(channel: Channel, + readerRef: ActorRef, + msg: ChannelBuffer, + remoteSocketAddress: InetSocketAddress): Unit = ChannelLocalActor.set(channel, Some(readerRef)) + + override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle = + new TcpAssociationHandle(localAddress, remoteAddress, channel) + + override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + trySend(e.getChannel, Disassociated) + } + + override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) { + trySend(e.getChannel, InboundPayload(ByteString(e.getMessage.asInstanceOf[ChannelBuffer].array()))) + } + + override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent) { + trySend(e.getChannel, Disassociated) + e.getChannel.close() // No graceful close here -- force TCP reset + } +} + +class TcpServerHandler(_transport: NettyTransport, _associationHandlerFuture: Future[ActorRef]) + extends ServerHandler(_transport, _associationHandlerFuture) with TcpHandlers { + + override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + initInbound(e.getChannel, e.getChannel.getRemoteAddress, null) + } + +} + +class TcpClientHandler(_transport: NettyTransport, _statusPromise: Promise[Status]) + extends ClientHandler(_transport, _statusPromise) with TcpHandlers { + + override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null) + } + +} + +class TcpAssociationHandle(val localAddress: Address, val remoteAddress: Address, private val channel: Channel) + extends AssociationHandle { + + override val readHandlerPromise: Promise[ActorRef] = Promise() + + override def write(payload: ByteString): Boolean = if (channel.isWritable && channel.isOpen) { + channel.write(ChannelBuffers.wrappedBuffer(payload.asByteBuffer)) + true + } else false + + override def disassociate(): Unit = NettyTransport.gracefulClose(channel) +} diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala new file mode 100644 index 0000000000..8e92a57980 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala @@ -0,0 +1,82 @@ +package akka.remote.transport.netty + +import akka.actor.{ ActorRef, Address } +import akka.remote.transport.AssociationHandle +import akka.remote.transport.AssociationHandle.InboundPayload +import akka.remote.transport.Transport.Status +import akka.util.ByteString +import java.net.{ SocketAddress, InetAddress, InetSocketAddress } +import org.jboss.netty.buffer.{ ChannelBuffer, ChannelBuffers } +import org.jboss.netty.channel._ +import scala.concurrent.{ Future, Promise } + +trait UdpHandlers extends CommonHandlers with HasTransport { + + override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle = + new UdpAssociationHandle(localAddress, remoteAddress, channel, transport) + + override def registerReader(channel: Channel, + readerRef: ActorRef, + msg: ChannelBuffer, + remoteSocketAddress: InetSocketAddress): Unit = { + val oldReader: ActorRef = transport.udpConnectionTable.putIfAbsent(remoteSocketAddress, readerRef) + if (oldReader ne null) { + throw new NettyTransportException(s"Reader $readerRef attempted to register for remote address $remoteSocketAddress" + + s" but $oldReader was already registered.", null) + } + readerRef ! InboundPayload(ByteString(msg.array())) + } + + override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) { + if (e.getRemoteAddress.isInstanceOf[InetSocketAddress]) { + val inetSocketAddress: InetSocketAddress = e.getRemoteAddress.asInstanceOf[InetSocketAddress] + if (!transport.udpConnectionTable.containsKey(inetSocketAddress)) { + e.getChannel.setReadable(false) + initUdp(e.getChannel, e.getRemoteAddress, e.getMessage.asInstanceOf[ChannelBuffer]) + + } else { + val reader = transport.udpConnectionTable.get(inetSocketAddress) + reader ! InboundPayload(ByteString(e.getMessage.asInstanceOf[ChannelBuffer].array())) + } + } + } + + def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit +} + +class UdpServerHandler(_transport: NettyTransport, _associationHandlerFuture: Future[ActorRef]) + extends ServerHandler(_transport, _associationHandlerFuture) with UdpHandlers { + + override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = + initInbound(channel, remoteSocketAddress, msg) +} + +class UdpClientHandler(_transport: NettyTransport, _statusPromise: Promise[Status]) + extends ClientHandler(_transport, _statusPromise) with UdpHandlers { + + override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = + initOutbound(channel, remoteSocketAddress, msg) +} + +class UdpAssociationHandle(val localAddress: Address, + val remoteAddress: Address, + private val channel: Channel, + private val transport: NettyTransport) extends AssociationHandle { + + override val readHandlerPromise: Promise[ActorRef] = Promise() + + override def write(payload: ByteString): Boolean = { + if (!channel.isConnected) + channel.connect(new InetSocketAddress(InetAddress.getByName(remoteAddress.host.get), remoteAddress.port.get)) + + if (channel.isWritable && channel.isOpen) { + channel.write(ChannelBuffers.wrappedBuffer(payload.asByteBuffer)) + true + } else false + } + + override def disassociate(): Unit = { + channel.close() + transport.udpConnectionTable.remove(transport.addressToSocketAddress(remoteAddress)) + } +} \ No newline at end of file diff --git a/akka-remote/src/test/scala/akka/remote/AccrualFailureDetectorSpec.scala b/akka-remote/src/test/scala/akka/remote/AccrualFailureDetectorSpec.scala new file mode 100644 index 0000000000..3daa799bbf --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/AccrualFailureDetectorSpec.scala @@ -0,0 +1,215 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ + +package akka.remote + +import akka.testkit.AkkaSpec +import scala.collection.immutable.TreeMap +import scala.concurrent.duration._ +import akka.remote.FailureDetector.Clock + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class AccrualFailureDetectorSpec extends AkkaSpec("akka.loglevel = INFO") { + + "An AccrualFailureDetector" must { + + def fakeTimeGenerator(timeIntervals: Seq[Long]): Clock = new Clock { + @volatile var times = timeIntervals.tail.foldLeft(List[Long](timeIntervals.head))((acc, c) ⇒ acc ::: List[Long](acc.last + c)) + override def apply(): Long = { + val currentTime = times.head + times = times.tail + currentTime + } + } + + def createFailureDetector( + threshold: Double = 8.0, + maxSampleSize: Int = 1000, + minStdDeviation: FiniteDuration = 10.millis, + acceptableLostDuration: FiniteDuration = Duration.Zero, + firstHeartbeatEstimate: FiniteDuration = 1.second, + clock: Clock = FailureDetector.defaultClock) = + new PhiAccrualFailureDetector( + threshold, + maxSampleSize, + minStdDeviation, + acceptableLostDuration, + firstHeartbeatEstimate = firstHeartbeatEstimate)(clock = clock) + + "use good enough cumulative distribution function" in { + val fd = createFailureDetector() + fd.cumulativeDistributionFunction(0.0, 0, 1) must be(0.5 plusOrMinus (0.001)) + fd.cumulativeDistributionFunction(0.6, 0, 1) must be(0.7257 plusOrMinus (0.001)) + fd.cumulativeDistributionFunction(1.5, 0, 1) must be(0.9332 plusOrMinus (0.001)) + fd.cumulativeDistributionFunction(2.0, 0, 1) must be(0.97725 plusOrMinus (0.01)) + fd.cumulativeDistributionFunction(2.5, 0, 1) must be(0.9379 plusOrMinus (0.1)) + fd.cumulativeDistributionFunction(3.5, 0, 1) must be(0.99977 plusOrMinus (0.1)) + fd.cumulativeDistributionFunction(4.0, 0, 1) must be(0.99997 plusOrMinus (0.1)) + + for (x :: y :: Nil ← (0.0 to 4.0 by 0.1).toList.sliding(2)) { + fd.cumulativeDistributionFunction(x, 0, 1) must be < ( + fd.cumulativeDistributionFunction(y, 0, 1)) + } + + fd.cumulativeDistributionFunction(2.2, 2.0, 0.3) must be(0.7475 plusOrMinus (0.001)) + } + + "return realistic phi values" in { + val fd = createFailureDetector() + val test = TreeMap(0 -> 0.0, 500 -> 0.1, 1000 -> 0.3, 1200 -> 1.6, 1400 -> 4.7, 1600 -> 10.8, 1700 -> 15.3) + for ((timeDiff, expectedPhi) ← test) { + fd.phi(timeDiff = timeDiff, mean = 1000.0, stdDeviation = 100.0) must be(expectedPhi plusOrMinus (0.1)) + } + + // larger stdDeviation results => lower phi + fd.phi(timeDiff = 1100, mean = 1000.0, stdDeviation = 500.0) must be < ( + fd.phi(timeDiff = 1100, mean = 1000.0, stdDeviation = 100.0)) + } + + "return phi value of 0.0 on startup for each address, when no heartbeats" in { + val fd = createFailureDetector() + fd.phi must be(0.0) + fd.phi must be(0.0) + } + + "return phi based on guess when only one heartbeat" in { + val timeInterval = List[Long](0, 1000, 1000, 1000, 1000) + val fd = createFailureDetector(firstHeartbeatEstimate = 1.seconds, + clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() + fd.phi must be(0.3 plusOrMinus 0.2) + fd.phi must be(4.5 plusOrMinus 0.3) + fd.phi must be > (15.0) + } + + "return phi value using first interval after second heartbeat" in { + val timeInterval = List[Long](0, 100, 100, 100) + val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() + fd.phi must be > (0.0) + fd.heartbeat() + fd.phi must be > (0.0) + } + + "mark node as available after a series of successful heartbeats" in { + val timeInterval = List[Long](0, 1000, 100, 100) + val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + + fd.isAvailable must be(true) + } + + "mark node as dead if heartbeat are missed" in { + val timeInterval = List[Long](0, 1000, 100, 100, 7000) + val fd = createFailureDetector(threshold = 3, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() //0 + fd.heartbeat() //1000 + fd.heartbeat() //1100 + + fd.isAvailable must be(true) //1200 + fd.isAvailable must be(false) //8200 + } + + "mark node as available if it starts heartbeat again after being marked dead due to detection of failure" in { + val timeInterval = List[Long](0, 1000, 100, 1100, 7000, 100, 1000, 100, 100) + val fd = createFailureDetector(threshold = 3, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() //0 + fd.heartbeat() //1000 + fd.heartbeat() //1100 + fd.isAvailable must be(true) //1200 + fd.isAvailable must be(false) //8200 + fd.heartbeat() //8300 + fd.heartbeat() //9300 + fd.heartbeat() //9400 + + fd.isAvailable must be(true) //9500 + } + + "accept some configured missing heartbeats" in { + val timeInterval = List[Long](0, 1000, 1000, 1000, 4000, 1000, 1000) + val fd = createFailureDetector(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.isAvailable must be(true) + fd.heartbeat() + fd.isAvailable must be(true) + } + + "fail after configured acceptable missing heartbeats" in { + val timeInterval = List[Long](0, 1000, 1000, 1000, 1000, 1000, 500, 500, 5000) + val fd = createFailureDetector(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.heartbeat() + fd.isAvailable must be(true) + fd.heartbeat() + fd.isAvailable must be(false) + } + + "use maxSampleSize heartbeats" in { + val timeInterval = List[Long](0, 100, 100, 100, 100, 600, 1000, 1000, 1000, 1000, 1000) + val fd = createFailureDetector(maxSampleSize = 3, clock = fakeTimeGenerator(timeInterval)) + + // 100 ms interval + fd.heartbeat() //0 + fd.heartbeat() //100 + fd.heartbeat() //200 + fd.heartbeat() //300 + val phi1 = fd.phi //400 + // 1000 ms interval, should become same phi when 100 ms intervals have been dropped + fd.heartbeat() //1000 + fd.heartbeat() //2000 + fd.heartbeat() //3000 + fd.heartbeat() //4000 + val phi2 = fd.phi //5000 + phi2 must be(phi1.plusOrMinus(0.001)) + } + + } + + "Statistics for heartbeats" must { + + "calculate correct mean and variance" in { + val samples = Seq(100, 200, 125, 340, 130) + val stats = (HeartbeatHistory(maxSampleSize = 20) /: samples) { + (stats, value) ⇒ stats :+ value + } + stats.mean must be(179.0 plusOrMinus 0.00001) + stats.variance must be(7584.0 plusOrMinus 0.00001) + } + + "have 0.0 variance for one sample" in { + (HeartbeatHistory(600) :+ 1000L).variance must be(0.0 plusOrMinus 0.00001) + } + + "be capped by the specified maxSampleSize" in { + val history3 = HeartbeatHistory(maxSampleSize = 3) :+ 100 :+ 110 :+ 90 + history3.mean must be(100.0 plusOrMinus 0.00001) + history3.variance must be(66.6666667 plusOrMinus 0.00001) + + val history4 = history3 :+ 140 + history4.mean must be(113.333333 plusOrMinus 0.00001) + history4.variance must be(422.222222 plusOrMinus 0.00001) + + val history5 = history4 :+ 80 + history5.mean must be(103.333333 plusOrMinus 0.00001) + history5.variance must be(688.88888889 plusOrMinus 0.00001) + + } + } +} diff --git a/akka-remote/src/test/scala/akka/remote/FailureDetectorRegistrySpec.scala b/akka-remote/src/test/scala/akka/remote/FailureDetectorRegistrySpec.scala new file mode 100644 index 0000000000..871d37accf --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/FailureDetectorRegistrySpec.scala @@ -0,0 +1,135 @@ +package akka.remote + +import akka.remote.FailureDetector.Clock +import scala.concurrent.duration._ +import akka.testkit.AkkaSpec + +class FailureDetectorRegistrySpec extends AkkaSpec("akka.loglevel = INFO") { + + def fakeTimeGenerator(timeIntervals: Seq[Long]): Clock = new Clock { + @volatile var times = timeIntervals.tail.foldLeft(List[Long](timeIntervals.head))((acc, c) ⇒ acc ::: List[Long](acc.last + c)) + override def apply(): Long = { + val currentTime = times.head + times = times.tail + currentTime + } + } + + def createFailureDetector( + threshold: Double = 8.0, + maxSampleSize: Int = 1000, + minStdDeviation: FiniteDuration = 10.millis, + acceptableLostDuration: FiniteDuration = Duration.Zero, + firstHeartbeatEstimate: FiniteDuration = 1.second, + clock: Clock = FailureDetector.defaultClock) = + new PhiAccrualFailureDetector( + threshold, + maxSampleSize, + minStdDeviation, + acceptableLostDuration, + firstHeartbeatEstimate = firstHeartbeatEstimate)(clock = clock) + + def createFailureDetectorRegistry(threshold: Double = 8.0, + maxSampleSize: Int = 1000, + minStdDeviation: FiniteDuration = 10.millis, + acceptableLostDuration: FiniteDuration = Duration.Zero, + firstHeartbeatEstimate: FiniteDuration = 1.second, + clock: Clock = FailureDetector.defaultClock): FailureDetectorRegistry[String] = { + new DefaultFailureDetectorRegistry[String](() ⇒ createFailureDetector( + threshold, + maxSampleSize, + minStdDeviation, + acceptableLostDuration, + firstHeartbeatEstimate, + clock)) + } + + "mark node as available after a series of successful heartbeats" in { + val timeInterval = List[Long](0, 1000, 100, 100) + val fd = createFailureDetectorRegistry(clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + + fd.isAvailable("resource1") must be(true) + } + + "mark node as dead if heartbeat are missed" in { + val timeInterval = List[Long](0, 1000, 100, 100, 4000, 3000) + val fd = createFailureDetectorRegistry(threshold = 3, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") //0 + fd.heartbeat("resource1") //1000 + fd.heartbeat("resource1") //1100 + + fd.isAvailable("resource1") must be(true) //1200 + fd.heartbeat("resource2") //5200, but unrelated resource + fd.isAvailable("resource1") must be(false) //8200 + } + + "accept some configured missing heartbeats" in { + val timeInterval = List[Long](0, 1000, 1000, 1000, 4000, 1000, 1000) + val fd = createFailureDetectorRegistry(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.isAvailable("resource1") must be(true) + fd.heartbeat("resource1") + fd.isAvailable("resource1") must be(true) + } + + "fail after configured acceptable missing heartbeats" in { + val timeInterval = List[Long](0, 1000, 1000, 1000, 1000, 1000, 500, 500, 5000) + val fd = createFailureDetectorRegistry(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.isAvailable("resource1") must be(true) + fd.heartbeat("resource1") + fd.isAvailable("resource1") must be(false) + } + + "mark node as available after explicit removal of connection" in { + val timeInterval = List[Long](0, 1000, 100, 100, 100) + val fd = createFailureDetectorRegistry(clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.heartbeat("resource1") + fd.isAvailable("resource1") must be(true) + fd.remove("resource1") + + fd.isAvailable("resource1") must be(true) + } + + "mark node as available after explicit removal of connection and receiving heartbeat again" in { + val timeInterval = List[Long](0, 1000, 100, 1100, 1100, 1100, 1100, 1100, 100) + val fd = createFailureDetectorRegistry(clock = fakeTimeGenerator(timeInterval)) + + fd.heartbeat("resource1") //0 + + fd.heartbeat("resource1") //1000 + fd.heartbeat("resource1") //1100 + + fd.isAvailable("resource1") must be(true) //2200 + + fd.remove("resource1") + + fd.isAvailable("resource1") must be(true) //3300 + + // it receives heartbeat from an explicitly removed node + fd.heartbeat("resource1") //4400 + fd.heartbeat("resource1") //5500 + fd.heartbeat("resource1") //6600 + + fd.isAvailable("resource1") must be(true) //6700 + } + +} diff --git a/akka-remote/src/test/scala/akka/remote/RemoteCommunicationSpec.scala b/akka-remote/src/test/scala/akka/remote/RemoteCommunicationSpec.scala index 962fad88fc..dd16edade0 100644 --- a/akka-remote/src/test/scala/akka/remote/RemoteCommunicationSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/RemoteCommunicationSpec.scala @@ -100,7 +100,7 @@ akka { "create and supervise children on remote node" in { val r = system.actorOf(Props[Echo], "blub") - r.path.toString must be === "akka://remote-sys@localhost:12346/remote/RemoteCommunicationSpec@localhost:12345/user/blub" + r.path.toString must be === "akka://remote-sys@localhost:12346/remote/akka/RemoteCommunicationSpec@localhost:12345/user/blub" r ! 42 expectMsg(42) EventFilter[Exception]("crash", occurrences = 1).intercept { diff --git a/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala b/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala new file mode 100644 index 0000000000..caa7fe3293 --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala @@ -0,0 +1,301 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ +package akka.remote + +import akka.actor._ +import akka.pattern.ask +import akka.testkit._ +import com.typesafe.config._ +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.duration._ +import akka.remote.transport.AssociationRegistry + +object RemotingSpec { + class Echo extends Actor { + var target: ActorRef = context.system.deadLetters + + def receive = { + case (p: Props, n: String) ⇒ sender ! context.actorOf(Props[Echo], n) + case ex: Exception ⇒ throw ex + case s: String ⇒ sender ! context.actorFor(s) + case x ⇒ target = sender; sender ! x + } + + override def preStart() {} + override def preRestart(cause: Throwable, msg: Option[Any]) { + target ! "preRestart" + } + override def postRestart(cause: Throwable) {} + override def postStop() { + target ! "postStop" + } + } + + val cfg: Config = ConfigFactory parseString (""" + common-transport-settings { + log-transport-events = true + connection-timeout = 120s + use-dispatcher-for-io = "" + write-buffer-high-water-mark = 0b + write-buffer-low-water-mark = 0b + send-buffer-size = 32000b + receive-buffer-size = 32000b + backlog = 4096 + hostname = localhost + enable-ssl = false + } + + common-ssl-settings { + key-store = "%s" + trust-store = "%s" + key-store-password = "changeme" + trust-store-password = "changeme" + protocol = "TLSv1" + random-number-generator = "AES128CounterSecureRNG" + enabled-algorithms = [TLS_RSA_WITH_AES_128_CBC_SHA] + sha1prng-random-source = "/dev/./urandom" + } + + akka { + actor.provider = "akka.remote.RemoteActorRefProvider" + remote.transport = "akka.remote.Remoting" + + remoting.retry-latch-closed-for = 1 s + remoting.log-remote-lifecycle-events = on + + remoting.transports = [ + { + transport-class = "akka.remote.transport.TestTransport" + settings { + registry-key = aX33k0jWKg + local-address = "test://RemotingSpec@localhost:12345" + maximum-payload-bytes = 32000 bytes + scheme-identifier = test + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = tcp + port = 12345 + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = udp + port = 12345 + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = tcp + enable-ssl = true + port = 23456 + ssl = ${common-ssl-settings} + } + } + ] + + actor.deployment { + /blub.remote = "test.akka://remote-sys@localhost:12346" + /gonk.remote = "tcp.akka://remote-sys@localhost:12346" + /zagzag.remote = "udp.akka://remote-sys@localhost:12346" + /roghtaar.remote = "tcp.ssl.akka://remote-sys@localhost:23457" + /looker/child.remote = "test.akka://remote-sys@localhost:12346" + /looker/child/grandchild.remote = "test.akka://RemotingSpec@localhost:12345" + } + } +""".format( + getClass.getClassLoader.getResource("keystore").getPath, + getClass.getClassLoader.getResource("truststore").getPath)) + +} + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with DefaultTimeout { + + import RemotingSpec._ + + val conf = ConfigFactory.parseString( + """ + akka.remote.netty.port=12346 + akka.remoting.transports = [ + { + transport-class = "akka.remote.transport.TestTransport" + settings { + registry-key = aX33k0jWKg + local-address = "test://remote-sys@localhost:12346" + maximum-payload-bytes = 32000 bytes + scheme-identifier = test + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = tcp + port = 12346 + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = udp + port = 12346 + } + }, + { + transport-class = "akka.remote.transport.netty.NettyTransport" + settings = ${common-transport-settings} + settings { + transport-protocol = tcp + enable-ssl = true + port = 23457 + ssl = ${common-ssl-settings} + } + } + ] + """).withFallback(system.settings.config).resolve() + val other = ActorSystem("remote-sys", conf) + + val remote = other.actorOf(Props(new Actor { + def receive = { + case "ping" ⇒ sender ! (("pong", sender)) + } + }), "echo") + + val here = system.actorFor("test.akka://remote-sys@localhost:12346/user/echo") + + override def atTermination() { + other.shutdown() + AssociationRegistry.clear() + } + + "Remoting" must { + + "support remote look-ups" in { + here ! "ping" + expectMsgPF() { + case ("pong", s: AnyRef) if s eq testActor ⇒ true + } + } + + "send error message for wrong address" in { + EventFilter.error(start = "AssociationError", occurrences = 1).intercept { + system.actorFor("test.akka://remotesys@localhost:12346/user/echo") ! "ping" + } + } + + "support ask" in { + Await.result(here ? "ping", timeout.duration) match { + case ("pong", s: akka.pattern.PromiseActorRef) ⇒ // good + case m ⇒ fail(m + " was not (pong, AskActorRef)") + } + } + + "send dead letters on remote if actor does not exist" in { + EventFilter.warning(pattern = "dead.*buh", occurrences = 1).intercept { + system.actorFor("test.akka://remote-sys@localhost:12346/does/not/exist") ! "buh" + }(other) + } + + "create and supervise children on remote node" in { + val r = system.actorOf(Props[Echo], "blub") + r.path.toString must be === "test.akka://remote-sys@localhost:12346/remote/test.akka/RemotingSpec@localhost:12345/user/blub" + r ! 42 + expectMsg(42) + EventFilter[Exception]("crash", occurrences = 1).intercept { + r ! new Exception("crash") + }(other) + expectMsg("preRestart") + r ! 42 + expectMsg(42) + system.stop(r) + expectMsg("postStop") + } + + "look-up actors across node boundaries" in { + val l = system.actorOf(Props(new Actor { + def receive = { + case (p: Props, n: String) ⇒ sender ! context.actorOf(p, n) + case s: String ⇒ sender ! context.actorFor(s) + } + }), "looker") + l ! (Props[Echo], "child") + val r = expectMsgType[ActorRef] + r ! (Props[Echo], "grandchild") + val remref = expectMsgType[ActorRef] + remref.asInstanceOf[ActorRefScope].isLocal must be(true) + val myref = system.actorFor(system / "looker" / "child" / "grandchild") + myref.isInstanceOf[RemoteActorRef] must be(true) + myref ! 43 + expectMsg(43) + lastSender must be theSameInstanceAs remref + r.asInstanceOf[RemoteActorRef].getParent must be(l) + system.actorFor("/user/looker/child") must be theSameInstanceAs r + Await.result(l ? "child/..", timeout.duration).asInstanceOf[AnyRef] must be theSameInstanceAs l + Await.result(system.actorFor(system / "looker" / "child") ? "..", timeout.duration).asInstanceOf[AnyRef] must be theSameInstanceAs l + } + + "not fail ask across node boundaries" in { + import system.dispatcher + val f = for (_ ← 1 to 1000) yield here ? "ping" mapTo manifest[(String, ActorRef)] + Await.result(Future.sequence(f), remaining).map(_._1).toSet must be(Set("pong")) + } + + "be able to use multiple transports and use the appropriate one (TCP)" in { + val r = system.actorOf(Props[Echo], "gonk") + r.path.toString must be === "tcp.akka://remote-sys@localhost:12346/remote/tcp.akka/RemotingSpec@localhost:12345/user/gonk" + r ! 42 + expectMsg(42) + EventFilter[Exception]("crash", occurrences = 1).intercept { + r ! new Exception("crash") + }(other) + expectMsg("preRestart") + r ! 42 + expectMsg(42) + system.stop(r) + expectMsg("postStop") + } + + "be able to use multiple transports and use the appropriate one (UDP)" in { + val r = system.actorOf(Props[Echo], "zagzag") + r.path.toString must be === "udp.akka://remote-sys@localhost:12346/remote/udp.akka/RemotingSpec@localhost:12345/user/zagzag" + r ! 42 + expectMsg(10 seconds, 42) + EventFilter[Exception]("crash", occurrences = 1).intercept { + r ! new Exception("crash") + }(other) + expectMsg("preRestart") + r ! 42 + expectMsg(42) + system.stop(r) + expectMsg("postStop") + } + + "be able to use multiple transports and use the appropriate one (SSL)" in { + val r = system.actorOf(Props[Echo], "roghtaar") + r.path.toString must be === "tcp.ssl.akka://remote-sys@localhost:23457/remote/tcp.ssl.akka/RemotingSpec@localhost:23456/user/roghtaar" + r ! 42 + expectMsg(10 seconds, 42) + EventFilter[Exception]("crash", occurrences = 1).intercept { + r ! new Exception("crash") + }(other) + expectMsg("preRestart") + r ! 42 + expectMsg(42) + system.stop(r) + expectMsg("postStop") + } + + } + +} diff --git a/akka-remote/src/test/scala/akka/remote/Ticket1978CommunicationSpec.scala b/akka-remote/src/test/scala/akka/remote/Ticket1978CommunicationSpec.scala index c194fe1fa6..8b4fc8a444 100644 --- a/akka-remote/src/test/scala/akka/remote/Ticket1978CommunicationSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/Ticket1978CommunicationSpec.scala @@ -60,17 +60,21 @@ object Configuration { val fullConfig = config.withFallback(AkkaSpec.testConf).withFallback(ConfigFactory.load).getConfig("akka.remote.netty") val settings = new NettySettings(fullConfig, "placeholder") - val rng = NettySSLSupport.initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, NoLogging) + val rng = NettySSLSupport.initializeCustomSecureRandom(settings.SslSettings.SSLRandomNumberGenerator, + settings.SslSettings.SSLRandomSource, NoLogging) rng.nextInt() // Has to work - settings.SSLRandomNumberGenerator foreach { sRng ⇒ rng.getAlgorithm == sRng || (throw new NoSuchAlgorithmException(sRng)) } + settings.SslSettings.SSLRandomNumberGenerator foreach { + sRng ⇒ rng.getAlgorithm == sRng || (throw new NoSuchAlgorithmException(sRng)) + } - val engine = NettySSLSupport.initializeClientSSL(settings, NoLogging).getEngine + val engine = NettySSLSupport.initializeClientSSL(settings.SslSettings, NoLogging).getEngine val gotAllSupported = enabled.toSet -- engine.getSupportedCipherSuites.toSet val gotAllEnabled = enabled.toSet -- engine.getEnabledCipherSuites.toSet gotAllSupported.isEmpty || (throw new IllegalArgumentException("Cipher Suite not supported: " + gotAllSupported)) gotAllEnabled.isEmpty || (throw new IllegalArgumentException("Cipher Suite not enabled: " + gotAllEnabled)) - engine.getSupportedProtocols.contains(settings.SSLProtocol.get) || (throw new IllegalArgumentException("Protocol not supported: " + settings.SSLProtocol.get)) + engine.getSupportedProtocols.contains(settings.SslSettings.SSLProtocol.get) || + (throw new IllegalArgumentException("Protocol not supported: " + settings.SslSettings.SSLProtocol.get)) CipherConfig(true, config, cipher, localPort, remotePort) } catch { @@ -131,7 +135,7 @@ abstract class Ticket1978CommunicationSpec(val cipherConfig: CipherConfig) exten ("-") must { if (cipherConfig.runTest) { val ignoreMe = other.actorOf(Props(new Actor { def receive = { case ("ping", x) ⇒ sender ! ((("pong", x), sender)) } }), "echo") - val otherAddress = other.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.address + val otherAddress = other.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.addresses.head "support tell" in { val here = system.actorFor(otherAddress.toString + "/user/echo") diff --git a/akka-remote/src/test/scala/akka/remote/Ticket1978ConfigSpec.scala b/akka-remote/src/test/scala/akka/remote/Ticket1978ConfigSpec.scala index e088ae3362..a1836c9d47 100644 --- a/akka-remote/src/test/scala/akka/remote/Ticket1978ConfigSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/Ticket1978ConfigSpec.scala @@ -28,14 +28,14 @@ akka { import settings._ EnableSSL must be(false) - SSLKeyStore must be(Some("keystore")) - SSLKeyStorePassword must be(Some("changeme")) - SSLTrustStore must be(Some("truststore")) - SSLTrustStorePassword must be(Some("changeme")) - SSLProtocol must be(Some("TLSv1")) - SSLEnabledAlgorithms must be(Set("TLS_RSA_WITH_AES_128_CBC_SHA")) - SSLRandomSource must be(None) - SSLRandomNumberGenerator must be(None) + SslSettings.SSLKeyStore must be(Some("keystore")) + SslSettings.SSLKeyStorePassword must be(Some("changeme")) + SslSettings.SSLTrustStore must be(Some("truststore")) + SslSettings.SSLTrustStorePassword must be(Some("changeme")) + SslSettings.SSLProtocol must be(Some("TLSv1")) + SslSettings.SSLEnabledAlgorithms must be(Set("TLS_RSA_WITH_AES_128_CBC_SHA")) + SslSettings.SSLRandomSource must be(None) + SslSettings.SSLRandomNumberGenerator must be(None) } } } diff --git a/akka-remote/src/test/scala/akka/remote/UntrustedSpec.scala b/akka-remote/src/test/scala/akka/remote/UntrustedSpec.scala index 58ace1bb7c..e2f9f1e6ee 100644 --- a/akka-remote/src/test/scala/akka/remote/UntrustedSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/UntrustedSpec.scala @@ -34,7 +34,7 @@ akka.loglevel = DEBUG akka.actor.provider = akka.remote.RemoteActorRefProvider akka.remote.netty.port = 0 """)) - val addr = system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.address + val addr = system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.addresses.head val target1 = other.actorFor(RootActorPath(addr) / "remote") val target2 = other.actorFor(RootActorPath(addr) / testActor.path.elements) diff --git a/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala b/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala new file mode 100644 index 0000000000..a92f98bc6c --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/transport/AkkaProtocolSpec.scala @@ -0,0 +1,471 @@ +package akka.remote.transport + +import akka.actor.{ ExtendedActorSystem, Address, Props } +import akka.remote.transport.AkkaPduCodec.{ Disassociate, Associate, Heartbeat } +import akka.remote.transport.AkkaProtocolSpec.TestFailureDetector +import akka.remote.transport.AssociationHandle.{ Disassociated, InboundPayload } +import akka.remote.transport.TestTransport._ +import akka.remote.transport.Transport._ +import akka.remote.{ RemoteProtocol, RemoteActorRefProvider, FailureDetector } +import akka.testkit.{ ImplicitSender, AkkaSpec } +import akka.util.ByteString +import com.google.protobuf.{ ByteString ⇒ PByteString } +import com.typesafe.config.ConfigFactory +import scala.concurrent.duration._ +import scala.concurrent.{ Await, Promise } + +object AkkaProtocolSpec { + + class TestFailureDetector extends FailureDetector { + @volatile var isAvailable: Boolean = true + + @volatile var called: Boolean = false + + def heartbeat(): Unit = called = true + } + +} + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class AkkaProtocolSpec extends AkkaSpec("""akka.actor.provider = "akka.remote.RemoteActorRefProvider" """) with ImplicitSender { + + val conf = ConfigFactory.parseString( + """ + | akka.remoting { + | + | failure-detector { + | threshold = 7.0 + | max-sample-size = 100 + | min-std-deviation = 100 ms + | acceptable-heartbeat-pause = 3 s + | } + | + | heartbeat-interval = 0.1 s + | + | wait-activity-enabled = on + | + | backoff-interval = 1 s + | + | require-cookie = off + | + | secure-cookie = "abcde" + | + | shutdown-timeout = 5 s + | + | startup-timeout = 5 s + | + | retry-latch-closed-for = 0 s + | + | use-passive-connections = on + | } + """.stripMargin) + + val localAddress = Address("test", "testsystem", "testhost", 1234) + val localAkkaAddress = Address("test.akka", "testsystem", "testhost", 1234) + + val remoteAddress = Address("test", "testsystem2", "testhost2", 1234) + val remoteAkkaAddress = Address("test.akka", "testsystem2", "testhost2", 1234) + + val codec = AkkaPduProtobufCodec + + val provider = system.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider] + + val testMsg = RemoteProtocol.MessageProtocol.newBuilder().setSerializerId(0).setMessage(PByteString.copyFromUtf8("foo")).build + val testMsgPdu: ByteString = codec.constructMessagePdu(localAkkaAddress, self, testMsg, None) + + def testHeartbeat = InboundPayload(codec.constructHeartbeat) + def testPayload = InboundPayload(testMsgPdu) + + def testDisassociate = InboundPayload(codec.constructDisassociate) + def testAssociate(cookie: Option[String]) = InboundPayload(codec.constructAssociate(cookie, remoteAkkaAddress)) + + def collaborators = { + val registry = new AssociationRegistry + val transport: TestTransport = new TestTransport(localAddress, registry) + val handle: TestAssociationHandle = new TestAssociationHandle(localAddress, remoteAddress, transport, true) + + // silently drop writes -- we do not have another endpoint under test, so nobody to forward to + transport.writeBehavior.pushConstant(true) + (new TestFailureDetector, registry, transport, handle) + } + + def lastActivityIsHeartbeat(registry: AssociationRegistry) = if (registry.logSnapshot.isEmpty) false else registry.logSnapshot.last match { + case WriteAttempt(sender, recipient, payload) if sender == localAddress && recipient == remoteAddress ⇒ + codec.decodePdu(payload, provider) match { + case Heartbeat ⇒ true + case _ ⇒ false + } + case _ ⇒ false + } + + def lastActivityIsAssociate(registry: AssociationRegistry, cookie: Option[String]) = if (registry.logSnapshot.isEmpty) false else registry.logSnapshot.last match { + case WriteAttempt(sender, recipient, payload) if sender == localAddress && recipient == remoteAddress ⇒ + codec.decodePdu(payload, provider) match { + case Associate(c, origin) if c == cookie && origin == localAddress ⇒ true + case _ ⇒ false + } + case _ ⇒ false + } + + def lastActivityIsDisassociate(registry: AssociationRegistry) = if (registry.logSnapshot.isEmpty) false else registry.logSnapshot.last match { + case WriteAttempt(sender, recipient, payload) if sender == localAddress && recipient == remoteAddress ⇒ + codec.decodePdu(payload, provider) match { + case Disassociate ⇒ true + case _ ⇒ false + } + case _ ⇒ false + } + + "ProtocolStateActor" must { + + "register itself as reader on injecteted handles" in { + val (failureDetector, _, _, handle) = collaborators + + system.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + self, + new AkkaProtocolSettings(conf), + codec, + failureDetector))) + + awaitCond(handle.readHandlerPromise.isCompleted) + } + + "in inbound mode accept payload after Associate PDU received" in { + val (failureDetector, registry, _, handle) = collaborators + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + self, + new AkkaProtocolSettings(conf), + codec, + failureDetector))) + + reader ! testAssociate(None) + + awaitCond(failureDetector.called) + + val wrappedHandle = expectMsgPF() { + case InboundAssociation(h) ⇒ h + } + + wrappedHandle.readHandlerPromise.success(self) + + failureDetector.called must be(true) + + // Heartbeat was sent in response to Associate + awaitCond(lastActivityIsHeartbeat(registry)) + + reader ! testPayload + + expectMsgPF() { + case InboundPayload(p) ⇒ p must be === testMsgPdu + } + } + + "in inbound mode disassociate when an unexpected message arrives instead of Associate" in { + val (failureDetector, registry, _, handle) = collaborators + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + self, + new AkkaProtocolSettings(conf), + codec, + failureDetector))) + + // a stray message will force a disassociate + reader ! testHeartbeat + + // this associate will now be ignored + reader ! testAssociate(None) + + awaitCond(registry.logSnapshot.exists { + case DisassociateAttempt(requester, remote) ⇒ true + case _ ⇒ false + }) + } + + "serve the handle as soon as possible if WaitActivity is turned off" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.wait-activity-enabled = off").withFallback(conf)), + codec, + failureDetector))) + + Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + + case _ ⇒ fail() + } + + lastActivityIsAssociate(registry, None) must be(true) + failureDetector.called must be(true) + + } + + "in outbound mode with WaitActivity delay readiness until activity detected" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(conf), + codec, + failureDetector))) + + awaitCond(lastActivityIsAssociate(registry, None)) + failureDetector.called must be(true) + + // keeps sending heartbeats + awaitCond(lastActivityIsHeartbeat(registry)) + + statusPromise.isCompleted must be(false) + + // finish connection by sending back a payload + reader ! testPayload + + Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + + case _ ⇒ fail() + } + + } + + "ignore incoming associations with wrong cookie" in { + val (failureDetector, registry, _, handle) = collaborators + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + self, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.require-cookie = on").withFallback(conf)), + codec, + failureDetector))) + + reader ! testAssociate(Some("xyzzy")) + + awaitCond(registry.logSnapshot.exists { + case DisassociateAttempt(requester, remote) ⇒ true + case _ ⇒ false + }) + } + + "accept incoming associations with correct cookie" in { + val (failureDetector, registry, _, handle) = collaborators + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + handle, + self, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.require-cookie = on").withFallback(conf)), + codec, + failureDetector))) + + // Send the correct cookie + reader ! testAssociate(Some("abcde")) + + val wrappedHandle = expectMsgPF() { + case InboundAssociation(h) ⇒ h + } + + wrappedHandle.readHandlerPromise.success(self) + + failureDetector.called must be(true) + + // Heartbeat was sent in response to Associate + awaitCond(lastActivityIsHeartbeat(registry)) + } + + "send cookie in Associate PDU if configured to do so" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(ConfigFactory.parseString( + """ + | akka.remoting.require-cookie = on + | akka.remoting.wait-activity-enabled = off + """.stripMargin).withFallback(conf)), + codec, + failureDetector))) + + Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + + case _ ⇒ fail() + } + + lastActivityIsAssociate(registry, Some("abcde")) must be(true) + } + + "handle explicit disassociate messages" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.wait-activity-enabled = off").withFallback(conf)), + codec, + failureDetector))) + + val wrappedHandle = Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + h + + case _ ⇒ fail() + } + + wrappedHandle.readHandlerPromise.success(self) + + lastActivityIsAssociate(registry, None) must be(true) + + reader ! testDisassociate + + expectMsg(Disassociated) + } + + "handle transport level disassociations" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + val reader = system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(conf), + codec, + failureDetector))) + + awaitCond(lastActivityIsAssociate(registry, None)) + + // Finish association with a heartbeat -- pushes state out of WaitActivity + reader ! testHeartbeat + + val wrappedHandle = Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + h + + case _ ⇒ fail() + } + + wrappedHandle.readHandlerPromise.success(self) + + Thread.sleep(100) + + reader ! Disassociated + + expectMsg(Disassociated) + } + + "disassociate when failure detector signals failure" in { + val (failureDetector, registry, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.wait-activity-enabled = off").withFallback(conf)), + codec, + failureDetector))) + + val wrappedHandle = Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + h + + case _ ⇒ fail() + } + + wrappedHandle.readHandlerPromise.success(self) + + lastActivityIsAssociate(registry, None) must be(true) + + //wait for one heartbeat + awaitCond(lastActivityIsHeartbeat(registry)) + + failureDetector.isAvailable = false + + expectMsg(Disassociated) + } + + "handle correctly when the handler is registered only after the association is already closed" in { + val (failureDetector, _, transport, handle) = collaborators + transport.associateBehavior.pushConstant(Transport.Ready(handle)) + + val statusPromise: Promise[Status] = Promise() + + val stateActor = system.actorOf(Props(new ProtocolStateActor( + localAddress, + remoteAddress, + statusPromise, + transport, + new AkkaProtocolSettings(ConfigFactory.parseString("akka.remoting.wait-activity-enabled = off").withFallback(conf)), + codec, + failureDetector))) + + val wrappedHandle = Await.result(statusPromise.future, 3 seconds) match { + case Transport.Ready(h) ⇒ + h.remoteAddress must be === remoteAkkaAddress + h.localAddress must be === localAkkaAddress + h + + case _ ⇒ fail() + } + + stateActor ! Disassociated + + wrappedHandle.readHandlerPromise.success(self) + + expectMsg(Disassociated) + + } + + } + +} diff --git a/akka-remote/src/test/scala/akka/remote/transport/SwitchableLoggedBehaviorSpec.scala b/akka-remote/src/test/scala/akka/remote/transport/SwitchableLoggedBehaviorSpec.scala new file mode 100644 index 0000000000..54fb1194b2 --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/transport/SwitchableLoggedBehaviorSpec.scala @@ -0,0 +1,104 @@ +package akka.remote.transport + +import akka.testkit.{ DefaultTimeout, AkkaSpec } +import akka.remote.transport.TestTransport.SwitchableLoggedBehavior +import scala.concurrent.{ Await, Promise } +import scala.util.Failure +import akka.AkkaException + +object SwitchableLoggedBehaviorSpec { + object TestException extends AkkaException("Test exception") +} + +class SwitchableLoggedBehaviorSpec extends AkkaSpec with DefaultTimeout { + import akka.remote.transport.SwitchableLoggedBehaviorSpec._ + + private def defaultBehavior = new SwitchableLoggedBehavior[Unit, Int]((_) ⇒ Promise.successful(3).future, (_) ⇒ ()) + + "A SwitchableLoggedBehavior" must { + + "execute default behavior" in { + val behavior = defaultBehavior + + Await.result(behavior(), timeout.duration) == 3 must be(true) + } + + "be able to push generic behavior" in { + val behavior = defaultBehavior + + behavior.push((_) ⇒ Promise.successful(4).future) + Await.result(behavior(), timeout.duration) must be(4) + + behavior.push((_) ⇒ Promise.failed(TestException).future) + behavior().value match { + case Some(Failure(e)) if e eq TestException ⇒ + case _ ⇒ fail("Expected exception") + } + } + + "be able to push constant behavior" in { + val behavior = defaultBehavior + behavior.pushConstant(5) + + Await.result(behavior(), timeout.duration) must be(5) + Await.result(behavior(), timeout.duration) must be(5) + } + + "be able to push failure behavior" in { + val behavior = defaultBehavior + behavior.pushError(TestException) + + behavior().value match { + case Some(Failure(e)) if e eq TestException ⇒ + case _ ⇒ fail("Expected exception") + } + } + + "be able to push and pop behavior" in { + val behavior = defaultBehavior + + behavior.pushConstant(5) + Await.result(behavior(), timeout.duration) must be(5) + + behavior.pushConstant(7) + Await.result(behavior(), timeout.duration) must be(7) + + behavior.pop() + Await.result(behavior(), timeout.duration) must be(5) + + behavior.pop() + Await.result(behavior(), timeout.duration) must be(3) + + } + + "protect the default behavior from popped out" in { + val behavior = defaultBehavior + behavior.pop() + behavior.pop() + behavior.pop() + + Await.result(behavior(), timeout.duration) must be(3) + } + + "enable delayed completition" in { + val behavior = defaultBehavior + val controlPromise = behavior.pushDelayed + val f = behavior() + + f.isCompleted must be(false) + controlPromise.success(()) + + awaitCond(f.isCompleted) + } + + "log calls and parametrers" in { + val logPromise = Promise[Int]() + val behavior = new SwitchableLoggedBehavior[Int, Int]((i) ⇒ Promise.successful(3).future, (i) ⇒ logPromise.success(i)) + + behavior(11) + Await.result(logPromise.future, timeout.duration) must be(11) + } + + } + +} diff --git a/akka-remote/src/test/scala/akka/remote/transport/TestTransport.scala b/akka-remote/src/test/scala/akka/remote/transport/TestTransport.scala new file mode 100644 index 0000000000..d49b40c444 --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/transport/TestTransport.scala @@ -0,0 +1,454 @@ +package akka.remote.transport + +import TestTransport._ +import akka.actor._ +import akka.remote.transport.AssociationHandle._ +import akka.remote.transport.Transport._ +import akka.util.ByteString +import com.typesafe.config.Config +import java.util.concurrent.{ CopyOnWriteArrayList, ConcurrentHashMap } +import scala.concurrent.duration._ +import scala.concurrent.{ Await, Future, Promise } + +// Default EC is used, but this is just a test utility -- please forgive... +import scala.concurrent.ExecutionContext.Implicits.global + +object TestTransport { + + type Behavior[A, B] = (A) ⇒ Future[B] + + /** + * Test utility to make behavior of functions that return some Future[B] controllable from tests. This tool is able + * to overwrite default behavior with any generic behavior, including failure, and exposes control to the timing of + * the completition of the returned future. + * + * The utility is implemented as a stack of behaviors, where the behavior on the top of the stack represents the + * currently active behavior. The bottom of the stack always contains the defaultBehavior which can not be popped + * out. + * + * @param defaultBehavior + * The original behavior that might be overwritten. It is always possible to restore this behavior + * + * @param logCallback + * Function that will be called independently of the current active behavior + * + * @tparam A + * Parameter type of the wrapped function. If it takes multiple parameters it must be wrapped in a tuple. + * + * @tparam B + * Type parameter of the future that the original function returns. + */ + class SwitchableLoggedBehavior[A, B](defaultBehavior: Behavior[A, B], logCallback: (A) ⇒ Unit) extends Behavior[A, B] { + + private val behaviorStack = new CopyOnWriteArrayList[Behavior[A, B]]() + behaviorStack.add(0, defaultBehavior) + + /** + * Changes the current behavior to the provided one. + * + * @param behavior + * Function that takes a parameter type A and returns a Future[B]. + */ + def push(behavior: Behavior[A, B]): Unit = { + behaviorStack.add(0, behavior) + } + + /** + * Changes the behavior to return a completed future with the given constant value. + * + * @param c + * The constant the future will be completed with. + */ + def pushConstant(c: B): Unit = push { + (x) ⇒ Promise.successful(c).future + } + + /** + * Changes the current behavior to return a failed future containing the given Throwable. + * + * @param e + * The throwable the failed future will contain. + */ + def pushError(e: Throwable): Unit = push { + (x) ⇒ Promise.failed(e).future + } + + /** + * Enables control of the completion of the previously active behavior. Wraps the previous behavior in a new + * one, returns a control promise that starts the original behavior after the control promise is completed. + * + * @return + * A promise, which delays the completion of the original future until after this promise is completed. + */ + def pushDelayed: Promise[Unit] = { + val controlPromise: Promise[Unit] = Promise() + val originalBehavior = currentBehavior + + push( + (params: A) ⇒ for (delayed ← controlPromise.future; original ← originalBehavior(params)) yield original) + + controlPromise + } + + /** + * Restores the previous behavior. + */ + def pop(): Unit = { + if (behaviorStack.size > 1) { + behaviorStack.remove(0) + } + } + + private def currentBehavior = behaviorStack.get(0) + + /** + * Applies the current behavior, and invokes the callback. + * + * @param params + * The parameters of this behavior. + * @return + * The result of this behavior wrapped in a future. + */ + def apply(params: A): Future[B] = { + logCallback(params) + currentBehavior(params) + } + } + + /** + * Base trait for activities that are logged by [[akka.remote.transport.TestTransport]]. + */ + sealed trait Activity + + case class ListenAttempt(boundAddress: Address) extends Activity + case class AssociateAttempt(localAddress: Address, remoteAddress: Address) extends Activity + case class ShutdownAttempt(boundAddress: Address) extends Activity + case class WriteAttempt(sender: Address, recipient: Address, payload: ByteString) extends Activity + case class DisassociateAttempt(requester: Address, remote: Address) extends Activity + + /** + * Shared state among [[akka.remote.transport.TestTransport]] instances. Coordinates the transports and the means + * of communication between them. + */ + class AssociationRegistry { + + private val activityLog = new CopyOnWriteArrayList[Activity]() + private val transportTable = new ConcurrentHashMap[Address, (TestTransport, ActorRef)]() + private val handlersTable = new ConcurrentHashMap[(Address, Address), Future[(ActorRef, ActorRef)]]() + + /** + * Logs a transport activity. + * + * @param activity Activity to be logged. + */ + def logActivity(activity: Activity): Unit = { + activityLog.add(activity) + } + + /** + * Takes a thread-safe snapshot of the current state of the activity log. + * + * @return Collection containing activities ordered left-to-right according to time (first element is earliest). + */ + def logSnapshot: Seq[Activity] = { + var result = List[Activity]() + + val it = activityLog.iterator() + while (it.hasNext) result ::= it.next() + + result.reverse + } + + /** + * Clears the activity log. + */ + def clearLog(): Unit = { + activityLog.clear() + } + + /** + * Records a mapping between an address and the corresponding (transport, actor) pair. + * + * @param transport + * The transport that is to be registered. The address of this transport will be used as key. + * @param responsibleActor + * The actor that will handle the events for the given transport. + */ + def registerTransport(transport: TestTransport, responsibleActor: ActorRef): Unit = { + transportTable.put(transport.localAddress, (transport, responsibleActor)) + } + + /** + * Indicates if all given transports were successfully registered. No associations can be established between + * transports that are not yet registered. + * + * @param transports + * The transports that participate in the test case. + * @return + * True if all transports are successfully registered. + */ + def transportsReady(transports: TestTransport*): Boolean = { + transports forall { + t ⇒ transportTable.containsKey(t.localAddress) + } + } + + /** + * Registers a Future of two actors corresponding to the two endpoints of an association. + * + * @param key + * Ordered pair of addresses representing an association. First element must be the address of the initiator. + * @param readHandlers + * The future containing the actors that will be responsible for handling the events of the two endpoints of the + * association. Elements in the pair must be in the same order as the addresses in the key parameter. + */ + def registerHandlePair(key: (Address, Address), readHandlers: Future[(ActorRef, ActorRef)]): Unit = { + handlersTable.put(key, readHandlers) + } + + /** + * Removes an association. + * @param key + * Ordered pair of addresses representing an association. First element is the address of the initiator. + * @return + * The original entries. + */ + def deregisterAssociation(key: (Address, Address)): Option[Future[(ActorRef, ActorRef)]] = + Option(handlersTable.remove(key)) + + /** + * Tests if an association was registered. + * + * @param initiatorAddress The initiator of the association. + * @param remoteAddress The other address of the association. + * + * @return True if there is an association for the given addresses. + */ + def existsAssociation(initiatorAddress: Address, remoteAddress: Address): Boolean = { + handlersTable.containsKey((initiatorAddress, remoteAddress)) + } + + /** + * Returns the event handler actor corresponding to the remote endpoint of the given local handle. In other words + * it returns the actor that will receive InboundPayload events when {{{write()}}} is called on the given handle. + * + * @param localHandle The handle + * @return The option that contains the Future for the handler actor if exists. + */ + def getRemoteReadHandlerFor(localHandle: TestAssociationHandle): Option[Future[ActorRef]] = { + Option(handlersTable.get(localHandle.key)) map { + case pairFuture: Future[(ActorRef, ActorRef)] ⇒ if (localHandle.inbound) { + pairFuture.map { _._1 } + } else { + pairFuture.map { _._2 } + } + } + } + + /** + * Returns the Transport bound to the given address. + * + * @param address The address bound to the transport. + * @return The transport if exists. + */ + def transportFor(address: Address): Option[(TestTransport, ActorRef)] = Option(transportTable.get(address)) + + /** + * Resets the state of the registry. ''Warning!'' This method is not atomic. + */ + def reset(): Unit = { + clearLog() + transportTable.clear() + handlersTable.clear() + } + } + +} + +/* + NOTE: This is a global shared state between different actor systems. The purpose of this class is to allow dynamically + loaded TestTransports to set up a shared AssociationRegistry. Extensions could not be used for this purpose, as the injection + of the shared instance must happen during the startup time of the actor system. Association registries are looked + up via a string key. Until we find a better way to inject an AssociationRegistry to multiple actor systems it is + strongly recommended to use long, randomly generated strings to key the registry to avoid interference between tests. + */ +object AssociationRegistry { + private final val registries = scala.collection.mutable.Map[String, AssociationRegistry]() + + def get(key: String): AssociationRegistry = this.synchronized { + registries.getOrElseUpdate(key, new AssociationRegistry) + } + + def clear(): Unit = this.synchronized { registries.clear() } +} + +/** + * Transport implementation to be used for testing. + * + * The TestTransport is basically a shared memory between actor systems. The TestTransport could be programmed to + * emulate different failure modes of a Transport implementation. TestTransport keeps a log of the activities it was + * requested to do. This class is not optimized for performace and MUST not be used as an in-memory transport in + * production systems. + */ +class TestTransport( + val localAddress: Address, + final val registry: AssociationRegistry, + val maximumPayloadBytes: Int = 32000, + val schemeIdentifier: String = "test") extends Transport { + + def this(system: ExtendedActorSystem, conf: Config) = { + this( + AddressFromURIString(conf.getString("local-address")), + AssociationRegistry.get(conf.getString("registry-key")), + conf.getBytes("maximum-payload-bytes").toInt, + conf.getString("scheme-identifier")) + } + + import akka.remote.transport.TestTransport._ + + override def isResponsibleFor(address: Address): Boolean = true + + private val actorPromise = Promise[ActorRef]() + + private def defaultListen: Future[(Address, Promise[ActorRef])] = { + actorPromise.future.onSuccess { + case actorRef: ActorRef ⇒ registry.registerTransport(this, actorRef) + } + Promise.successful((localAddress, actorPromise)).future + } + + private def defaultAssociate(remoteAddress: Address): Future[Status] = { + registry.transportFor(remoteAddress) match { + + case Some((remoteTransport, actor)) ⇒ + val (localHandle, remoteHandle) = createHandlePair(remoteTransport, remoteAddress) + + val bothSides: Future[(ActorRef, ActorRef)] = for ( + actor1 ← localHandle.readHandlerPromise.future; + actor2 ← remoteHandle.readHandlerPromise.future + ) yield (actor1, actor2) + + registry.registerHandlePair(localHandle.key, bothSides) + actor ! InboundAssociation(remoteHandle) + + Promise.successful(Ready(localHandle)).future + + case None ⇒ + Promise.successful(Fail(new IllegalArgumentException(s"No registered transport: $remoteAddress"))).future + } + } + + private def createHandlePair(remoteTransport: TestTransport, remoteAddress: Address): (TestAssociationHandle, TestAssociationHandle) = { + val localHandle = new TestAssociationHandle(localAddress, remoteAddress, this, inbound = false) + val remoteHandle = new TestAssociationHandle(remoteAddress, localAddress, remoteTransport, inbound = true) + + (localHandle, remoteHandle) + } + + private def defaultShutdown: Future[Unit] = Promise.successful(()).future + + /** + * The [[akka.remote.transport.TestTransport.SwitchableLoggedBehavior]] for the listen() method. + */ + val listenBehavior = new SwitchableLoggedBehavior[Unit, (Address, Promise[ActorRef])]( + (_) ⇒ defaultListen, + (_) ⇒ registry.logActivity(ListenAttempt(localAddress))) + + /** + * The [[akka.remote.transport.TestTransport.SwitchableLoggedBehavior]] for the associate() method. + */ + val associateBehavior = new SwitchableLoggedBehavior[Address, Status]( + defaultAssociate _, + (remoteAddress) ⇒ registry.logActivity(AssociateAttempt(localAddress, remoteAddress))) + + /** + * The [[akka.remote.transport.TestTransport.SwitchableLoggedBehavior]] for the shutdown() method. + */ + val shutdownBehavior = new SwitchableLoggedBehavior[Unit, Unit]( + (_) ⇒ defaultShutdown, + (_) ⇒ registry.logActivity(ShutdownAttempt(localAddress))) + + override def listen: Future[(Address, Promise[ActorRef])] = listenBehavior() + override def associate(remoteAddress: Address): Future[Status] = associateBehavior(remoteAddress) + override def shutdown(): Unit = shutdownBehavior() + + private def defaultWrite(params: (TestAssociationHandle, ByteString)): Future[Boolean] = { + registry.getRemoteReadHandlerFor(params._1) match { + case Some(futureActor) ⇒ + val writePromise = Promise[Boolean]() + futureActor.onSuccess { + case actor ⇒ actor ! InboundPayload(params._2); writePromise.success(true) + } + writePromise.future + case None ⇒ + Promise.failed(new IllegalStateException("No association present")).future + } + } + + private def defaultDisassociate(handle: TestAssociationHandle): Future[Unit] = { + registry.deregisterAssociation(handle.key).foreach { + case f: Future[(ActorRef, ActorRef)] ⇒ f.onSuccess { + case (handler1, handler2) ⇒ + val handler = if (handle.inbound) handler2 else handler1 + handler ! Disassociated + } + + } + Promise.successful(()).future + } + + /** + * The [[akka.remote.transport.TestTransport.SwitchableLoggedBehavior]] for the write() method on handles. All + * handle calls pass through this call. Please note, that write operations return a Boolean synchronously, so + * altering the behavior via pushDelayed will turn write to a blocking operation -- use of pushDelayed therefore + * is not recommended. + */ + val writeBehavior = new SwitchableLoggedBehavior[(TestAssociationHandle, ByteString), Boolean]( + defaultBehavior = { + defaultWrite _ + }, + logCallback = { + case (handle, payload) ⇒ + registry.logActivity(WriteAttempt(handle.localAddress, handle.remoteAddress, payload)) + }) + + /** + * The [[akka.remote.transport.TestTransport.SwitchableLoggedBehavior]] for the disassociate() method on handles. All + * handle calls pass through this call. + */ + val disassociateBehavior = new SwitchableLoggedBehavior[TestAssociationHandle, Unit]( + defaultBehavior = { + defaultDisassociate _ + }, + logCallback = { + (handle) ⇒ + registry.logActivity(DisassociateAttempt(handle.localAddress, handle.remoteAddress)) + }) + + private[akka] def write(handle: TestAssociationHandle, payload: ByteString): Boolean = + Await.result(writeBehavior((handle, payload)), 3 seconds) + + private[akka] def disassociate(handle: TestAssociationHandle): Unit = disassociateBehavior(handle) + + override def toString: String = s"TestTransport($localAddress)" + +} + +case class TestAssociationHandle( + localAddress: Address, + remoteAddress: Address, + transport: TestTransport, + inbound: Boolean) extends AssociationHandle { + + override val readHandlerPromise: Promise[ActorRef] = Promise() + + override def write(payload: ByteString): Boolean = transport.write(this, payload) + + override def disassociate(): Unit = transport.disassociate(this) + + /** + * Key used in [[akka.remote.transport.TestTransport.AssociationRegistry]] to identify associations. Contains an + * ordered pair of addresses, where the first element of the pair is always the initiator of the association. + */ + val key = if (!inbound) (localAddress, remoteAddress) else (remoteAddress, localAddress) +} diff --git a/akka-remote/src/test/scala/akka/remote/transport/TestTransportSpec.scala b/akka-remote/src/test/scala/akka/remote/transport/TestTransportSpec.scala new file mode 100644 index 0000000000..c4f51191c4 --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/transport/TestTransportSpec.scala @@ -0,0 +1,140 @@ +package akka.remote.transport + +import akka.testkit._ +import scala.concurrent._ +import akka.actor.Address +import akka.remote.transport.Transport._ +import akka.remote.transport.TestTransport._ +import akka.util.ByteString +import akka.remote.transport.AssociationHandle.{ Disassociated, InboundPayload } + +class TestTransportSpec extends AkkaSpec with DefaultTimeout with ImplicitSender { + + val addressA: Address = Address("akka", "testsytemA", "testhostA", 4321) + val addressB: Address = Address("akka", "testsytemB", "testhostB", 5432) + val nonExistingAddress = Address("akka", "nosystem", "nohost", 0) + + "TestTransport" must { + + "return an Address and promise when listen is called and log calls" in { + val registry = new AssociationRegistry + var transportA = new TestTransport(addressA, registry) + + val result = Await.result(transportA.listen, timeout.duration) + + result._1 must be(addressA) + result._2 must not be null + + registry.logSnapshot.exists { + case ListenAttempt(address) ⇒ address == addressA + case _ ⇒ false + } must be(true) + } + + "associate successfully with another TestTransport and log" in { + val registry = new AssociationRegistry + var transportA = new TestTransport(addressA, registry) + var transportB = new TestTransport(addressB, registry) + + // Must complete the returned promise to receive events + Await.result(transportA.listen, timeout.duration)._2.success(self) + Await.result(transportB.listen, timeout.duration)._2.success(self) + + awaitCond(registry.transportsReady(transportA, transportB)) + + transportA.associate(addressB) + expectMsgPF(timeout.duration, "Expect InboundAssociation from A") { + case InboundAssociation(handle) if handle.remoteAddress == addressA ⇒ + } + + registry.logSnapshot.contains(AssociateAttempt(addressA, addressB)) must be(true) + } + + "fail to associate with nonexisting address" in { + val registry = new AssociationRegistry + var transportA = new TestTransport(addressA, registry) + + Await.result(transportA.listen, timeout.duration)._2.success(self) + Await.result(transportA.associate(nonExistingAddress), timeout.duration) match { + case Fail(_) ⇒ + case _ ⇒ fail() + } + } + + "emulate sending PDUs and logs write" in { + val registry = new AssociationRegistry + var transportA = new TestTransport(addressA, registry) + var transportB = new TestTransport(addressB, registry) + + Await.result(transportA.listen, timeout.duration)._2.success(self) + Await.result(transportB.listen, timeout.duration)._2.success(self) + + awaitCond(registry.transportsReady(transportA, transportB)) + + val associate: Future[Status] = transportA.associate(addressB) + val handleB = expectMsgPF(timeout.duration, "Expect InboundAssociation from A") { + case InboundAssociation(handle) if handle.remoteAddress == addressA ⇒ handle + } + + val Ready(handleA) = Await.result(associate, timeout.duration) + + // Initialize handles + handleA.readHandlerPromise.success(self) + handleB.readHandlerPromise.success(self) + + val akkaPDU = ByteString("AkkaPDU") + + awaitCond(registry.existsAssociation(addressA, addressB)) + + handleA.write(akkaPDU) + expectMsgPF(timeout.duration, "Expect InboundPayload from A") { + case InboundPayload(payload) if payload == akkaPDU ⇒ + } + + registry.logSnapshot.exists { + case WriteAttempt(sender, recipient, payload) ⇒ + sender == addressA && recipient == addressB && payload == akkaPDU + case _ ⇒ false + } must be(true) + } + + "emulate disassociation and log it" in { + val registry = new AssociationRegistry + var transportA = new TestTransport(addressA, registry) + var transportB = new TestTransport(addressB, registry) + + Await.result(transportA.listen, timeout.duration)._2.success(self) + Await.result(transportB.listen, timeout.duration)._2.success(self) + + awaitCond(registry.transportsReady(transportA, transportB)) + + val associate: Future[Status] = transportA.associate(addressB) + val handleB: AssociationHandle = expectMsgPF(timeout.duration, "Expect InboundAssociation from A") { + case InboundAssociation(handle) if handle.remoteAddress == addressA ⇒ handle + } + + val Ready(handleA) = Await.result(associate, timeout.duration) + + // Initialize handles + handleA.readHandlerPromise.success(self) + handleB.readHandlerPromise.success(self) + + awaitCond(registry.existsAssociation(addressA, addressB)) + + handleA.disassociate() + + expectMsgPF(timeout.duration) { + case Disassociated ⇒ + } + + awaitCond(!registry.existsAssociation(addressA, addressB)) + + registry.logSnapshot exists { + case DisassociateAttempt(requester, remote) if requester == addressA && remote == addressB ⇒ true + case _ ⇒ false + } must be(true) + } + + } + +}