diff --git a/akka-actor/src/main/scala/akka/remoteinterface/RemoteInterface.scala b/akka-actor/src/main/scala/akka/remoteinterface/RemoteInterface.scala index 7764469b62..7a50a3556c 100644 --- a/akka-actor/src/main/scala/akka/remoteinterface/RemoteInterface.scala +++ b/akka-actor/src/main/scala/akka/remoteinterface/RemoteInterface.scala @@ -379,31 +379,42 @@ trait RemoteClientModule extends RemoteModule { self: RemoteModule => def clientManagedActorOf(factory: () => Actor, host: String, port: Int): ActorRef + + /** + * Clean-up all open connections. + */ + def shutdownClientModule: Unit + + /** + * Shuts down a specific client connected to the supplied remote address returns true if successful + */ + def shutdownClientConnection(address: InetSocketAddress): Boolean + + /** + * Restarts a specific client connected to the supplied remote address, but only if the client is not shut down + */ + def restartClientConnection(address: InetSocketAddress): Boolean + /** Methods that needs to be implemented by a transport **/ - protected[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, host: String, port: Int, loader: Option[ClassLoader]): T + protected[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, host: String, port: Int, loader: Option[ClassLoader]): T - protected[akka] def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]): ActorRef + protected[akka] def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]): ActorRef - protected[akka] def send[T](message: Any, - senderOption: Option[ActorRef], - senderFuture: Option[CompletableFuture[T]], - remoteAddress: InetSocketAddress, - timeout: Long, - isOneWay: Boolean, - actorRef: ActorRef, - typedActorInfo: Option[Tuple2[String, String]], - actorType: ActorType, - loader: Option[ClassLoader]): Option[CompletableFuture[T]] + protected[akka] def send[T](message: Any, + senderOption: Option[ActorRef], + senderFuture: Option[CompletableFuture[T]], + remoteAddress: InetSocketAddress, + timeout: Long, + isOneWay: Boolean, + actorRef: ActorRef, + typedActorInfo: Option[Tuple2[String, String]], + actorType: ActorType, + loader: Option[ClassLoader]): Option[CompletableFuture[T]] - private[akka] def registerSupervisorForActor(actorRef: ActorRef): ActorRef + private[akka] def registerSupervisorForActor(actorRef: ActorRef): ActorRef - private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef - - /** - * Clean-up all open connections. - */ - def shutdownClientModule: Unit + private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef private[akka] def registerClientManagedActor(hostname: String, port: Int, uuid: Uuid): Unit diff --git a/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala index cd81c337c0..3b98836f5f 100644 --- a/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala @@ -63,28 +63,37 @@ trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagem clientFor(remoteAddress, loader).send[T](message, senderOption, senderFuture, remoteAddress, timeout, isOneWay, actorRef, typedActorInfo, actorType) private[akka] def clientFor( - address: InetSocketAddress, loader: Option[ClassLoader]): RemoteClient = synchronized { //TODO: REVIST: synchronized here seems bottlenecky - val hostname = address.getHostName - val port = address.getPort - val hash = hostname + ':' + port + address: InetSocketAddress, loader: Option[ClassLoader]): RemoteClient = synchronized { //TODO: REVISIT: synchronized here seems bottlenecky + val key = makeKey(address) loader.foreach(MessageSerializer.setClassLoader(_))//TODO: REVISIT: THIS SMELLS FUNNY - if (remoteClients.contains(hash)) remoteClients(hash) - else { - val client = new RemoteClient(this, new InetSocketAddress(hostname, port), loader, self.notifyListeners _) - client.connect - remoteClients += hash -> client - client + remoteClients.get(key) match { + case Some(client) => client + case None => + val client = new RemoteClient(this, address, loader, self.notifyListeners _) + client.connect() + remoteClients += key -> client + client } } - def shutdownClientFor(address: InetSocketAddress) = synchronized { - val hostname = address.getHostName - val port = address.getPort - val hash = hostname + ':' + port - if (remoteClients.contains(hash)) { - val client = remoteClients(hash) - client.shutdown - remoteClients -= hash + private def makeKey(a: InetSocketAddress): String = a match { + case null => null + case address => address.getHostName + ':' + address.getPort + } + + def shutdownClientConnection(address: InetSocketAddress): Boolean = synchronized { + remoteClients.remove(makeKey(address)) match { + case Some(client) => + client.shutdown + true + case None => false + } + } + + def restartClientConnection(address: InetSocketAddress): Boolean = synchronized { + remoteClients.get(makeKey(address)) match { + case Some(client) => client.connect(reconnectIfAlreadyConnected = true) + case None => false } } @@ -109,7 +118,7 @@ trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagem private[akka] def unregisterClientManagedActor(hostname: String, port: Int, uuid: Uuid) = synchronized { val set = actorsFor(Address(hostname, port)) set -= uuid - if (set.isEmpty) shutdownClientFor(new InetSocketAddress(hostname, port)) + if (set.isEmpty) shutdownClientConnection(new InetSocketAddress(hostname, port)) } private[akka] def actorsFor(remoteServerAddress: Address): HashSet[Uuid] = { @@ -168,28 +177,48 @@ class RemoteClient private[akka] ( @volatile private var reconnectionTimeWindowStart = 0L - def connect = runSwitch switchOn { - openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName) - timer = new HashedWheelTimer + def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = { + runSwitch switchOn { + openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName) + timer = new HashedWheelTimer - bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)) - bootstrap.setPipelineFactory(new RemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this)) - bootstrap.setOption("tcpNoDelay", true) - bootstrap.setOption("keepAlive", true) + bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)) + bootstrap.setPipelineFactory(new RemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this)) + bootstrap.setOption("tcpNoDelay", true) + bootstrap.setOption("keepAlive", true) - log.slf4j.info("Starting remote client connection to [{}]", remoteAddress) + log.slf4j.info("Starting remote client connection to [{}]", remoteAddress) - // Wait until the connection attempt succeeds or fails. - connection = bootstrap.connect(remoteAddress) - val channel = connection.awaitUninterruptibly.getChannel - openChannels.add(channel) + // Wait until the connection attempt succeeds or fails. + connection = bootstrap.connect(remoteAddress) + openChannels.add(connection.awaitUninterruptibly.getChannel) - if (!connection.isSuccess) { - notifyListeners(RemoteClientError(connection.getCause, module, remoteAddress)) - log.slf4j.error("Remote client connection to [{}] has failed", remoteAddress) - log.slf4j.debug("Remote client connection failed", connection.getCause) + if (!connection.isSuccess) { + notifyListeners(RemoteClientError(connection.getCause, module, remoteAddress)) + log.slf4j.error("Remote client connection to [{}] has failed", remoteAddress) + log.slf4j.debug("Remote client connection failed", connection.getCause) + false + } else { + notifyListeners(RemoteClientStarted(module, remoteAddress)) + true + } + } match { + case true => true + case false if reconnectIfAlreadyConnected => + isAuthenticated.set(false) + log.slf4j.debug("Remote client reconnecting to [{}]", remoteAddress) + openChannels.remove(connection.getChannel) + connection.getChannel.close + connection = bootstrap.connect(remoteAddress) + openChannels.add(connection.awaitUninterruptibly.getChannel) // Wait until the connection attempt succeeds or fails. + if (!connection.isSuccess) { + notifyListeners(RemoteClientError(connection.getCause, module, remoteAddress)) + log.slf4j.error("Reconnection to [{}] has failed", remoteAddress) + log.slf4j.debug("Reconnection failed", connection.getCause) + false + } else true + case false => false } - notifyListeners(RemoteClientStarted(module, remoteAddress)) } def shutdown = runSwitch switchOff { @@ -386,15 +415,7 @@ class RemoteClientHandler( timer.newTimeout(new TimerTask() { def run(timeout: Timeout) = { client.openChannels.remove(event.getChannel) - client.isAuthenticated.set(false) - log.slf4j.debug("Remote client reconnecting to [{}]", remoteAddress) - client.connection = bootstrap.connect(remoteAddress) - client.connection.awaitUninterruptibly // Wait until the connection attempt succeeds or fails. - if (!client.connection.isSuccess) { - client.notifyListeners(RemoteClientError(client.connection.getCause, client.module, client.remoteAddress)) - log.slf4j.error("Reconnection to [{}] has failed", remoteAddress) - log.slf4j.debug("Reconnection failed", client.connection.getCause) - } + client.connect(reconnectIfAlreadyConnected = true) } }, RemoteClient.RECONNECT_DELAY.toMillis, TimeUnit.MILLISECONDS) } else spawn { client.shutdown }