diff --git a/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala index ca60b39ab3..4831b3fe94 100644 --- a/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/NettyRemoteSupport.scala @@ -31,21 +31,28 @@ import org.jboss.netty.handler.ssl.SslHandler import java.net.{ SocketAddress, InetSocketAddress } import java.util.concurrent.{ TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet } +import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.{ HashSet, HashMap } import scala.reflect.BeanProperty import java.lang.reflect.InvocationTargetException -import akka.actor. {ActorInitializationException, LocalActorRef, newUuid, ActorRegistry, Actor, RemoteActorRef, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage, uuidFrom, Uuid, Exit, LifeCycleMessage, ActorType => AkkaActorType} import java.util.concurrent.atomic. {AtomicReference, AtomicLong, AtomicBoolean} import akka.remoteinterface._ +import akka.actor. {Index, ActorInitializationException, LocalActorRef, newUuid, ActorRegistry, Actor, RemoteActorRef, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage, uuidFrom, Uuid, Exit, LifeCycleMessage, ActorType => AkkaActorType} + +trait NettyRemoteShared { + def registerPassiveClient(channel: Channel): Boolean + def deregisterPassiveClient(channel: Channel): Boolean +} /** * The RemoteClient object manages RemoteClient instances and gives you an API to lookup remote actor handles. * * @author Jonas Bonér */ -trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagement with Logging => +trait NettyRemoteClientModule extends RemoteClientModule with NettyRemoteShared { self: ListenerManagement with Logging => private val remoteClients = new HashMap[String, RemoteClient] - private val remoteActors = new HashMap[Address, HashSet[Uuid]] + private val remoteActors = new Index[Address, Uuid] + private val lock = new ReentrantReadWriteLock protected[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]): T = TypedActor.createProxyForRemoteActorRef(intfClass, RemoteActorRef(serviceId, implClassName, hostname, port, timeout, loader, AkkaActorType.TypedActor)) @@ -63,16 +70,76 @@ trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagem clientFor(remoteAddress, loader).send[T](message, senderOption, senderFuture, remoteAddress, timeout, isOneWay, actorRef, typedActorInfo, actorType) private[akka] def clientFor( - address: InetSocketAddress, loader: Option[ClassLoader]): RemoteClient = synchronized { //TODO: REVISIT: synchronized here seems bottlenecky - val key = makeKey(address) + address: InetSocketAddress, loader: Option[ClassLoader]): RemoteClient = { loader.foreach(MessageSerializer.setClassLoader(_))//TODO: REVISIT: THIS SMELLS FUNNY + + val key = makeKey(address) + lock.readLock.lock remoteClients.get(key) match { - case Some(client) => client + case Some(client) => try { client } finally { lock.readLock.unlock } case None => - val client = new RemoteClient(this, address, loader, self.notifyListeners _) - client.connect() - remoteClients += key -> client - client + lock.readLock.unlock + lock.writeLock.lock //Lock upgrade, not supported natively + try { + remoteClients.get(key) match { //Recheck for addition, race between upgrades + case Some(client) => client //If already populated by other writer + case None => //Populate map + val client = new ActiveRemoteClient(this, address, loader, self.notifyListeners _) + client.connect() + remoteClients += key -> client + client + } + } finally { lock.writeLock.unlock } + } + } + + /** + * This method is called by the server module to register passive clients + */ + def registerPassiveClient(channel: Channel): Boolean = { + val address = channel.getRemoteAddress.asInstanceOf[InetSocketAddress] + val key = makeKey(address) + lock.readLock.lock + remoteClients.get(key) match { + case Some(client) => try { false } finally { lock.readLock.unlock } + case None => + lock.readLock.unlock + lock.writeLock.lock //Lock upgrade, not supported natively + try { + remoteClients.get(key) match { + case Some(client) => false + case None => + val client = new PassiveRemoteClient(this, address, channel, self.notifyListeners _ ) + client.connect() + remoteClients.put(key, client) + true + } + } finally { lock.writeLock.unlock } + } + } + + /** + * This method is called by the server module to deregister passive clients + */ + def deregisterPassiveClient(channel: Channel): Boolean = { + val address = channel.getRemoteAddress.asInstanceOf[InetSocketAddress] + val key = makeKey(address) + lock.readLock.lock + remoteClients.get(key) match { + case Some(client: PassiveRemoteClient) => + lock.readLock.unlock + lock.writeLock.lock //Lock upgrade, not supported natively + try { + remoteClients.get(key) match { + case Some(client: ActiveRemoteClient) => false + case None => false + case Some(client: PassiveRemoteClient) => + remoteClients.remove(key) + true + } + } finally { lock.writeLock.unlock } + //Otherwise, unlock the readlock and return false + case _ => try { false } finally { lock.readLock.unlock } } } @@ -81,52 +148,70 @@ trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagem case address => address.getHostName + ':' + address.getPort } - def shutdownClientConnection(address: InetSocketAddress): Boolean = synchronized { - remoteClients.remove(makeKey(address)) match { - case Some(client) => client.shutdown - case None => false + def shutdownClientConnection(address: InetSocketAddress): Boolean = { + lock.writeLock.lock + try { + remoteClients.remove(makeKey(address)) match { + case Some(client) => client.shutdown + case None => false + } + } finally { + lock.writeLock.unlock } } - def restartClientConnection(address: InetSocketAddress): Boolean = synchronized { - remoteClients.get(makeKey(address)) match { - case Some(client) => client.connect(reconnectIfAlreadyConnected = true) - case None => false + def restartClientConnection(address: InetSocketAddress): Boolean = { + lock.readLock.lock + try { + remoteClients.get(makeKey(address)) match { + case Some(client) => client.connect(reconnectIfAlreadyConnected = true) + case None => false + } + } finally { + lock.readLock.unlock } } private[akka] def registerSupervisorForActor(actorRef: ActorRef): ActorRef = clientFor(actorRef.homeAddress.get, None).registerSupervisorForActor(actorRef) - private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef = - clientFor(actorRef.homeAddress.get, None).deregisterSupervisorForActor(actorRef) + private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef = { + val key = makeKey(actorRef.homeAddress.get) + lock.readLock.lock //TODO: perhaps use writelock here + try { + remoteClients.get(key) match { + case Some(client) => client.deregisterSupervisorForActor(actorRef) + case None => actorRef + } + } finally { + lock.readLock.unlock + } + } /** * Clean-up all open connections. */ - def shutdownClientModule = synchronized { + def shutdownClientModule = { + shutdownRemoteClients + //TODO: Should we empty our remoteActors too? + //remoteActors.clear + } + + def shutdownRemoteClients = try { + lock.writeLock.lock remoteClients.foreach({ case (addr, client) => client.shutdown }) remoteClients.clear + } finally { + lock.writeLock.unlock } - def registerClientManagedActor(hostname: String, port: Int, uuid: Uuid) = synchronized { - actorsFor(Address(hostname, port)) += uuid + def registerClientManagedActor(hostname: String, port: Int, uuid: Uuid) = { + remoteActors.put(Address(hostname, port), uuid) } - private[akka] def unregisterClientManagedActor(hostname: String, port: Int, uuid: Uuid) = synchronized { - val set = actorsFor(Address(hostname, port)) - set -= uuid - if (set.isEmpty) shutdownClientConnection(new InetSocketAddress(hostname, port)) - } - - private[akka] def actorsFor(remoteServerAddress: Address): HashSet[Uuid] = { - val set = remoteActors.get(remoteServerAddress) - if (set.isDefined && (set.get ne null)) set.get - else { - val remoteActorSet = new HashSet[Uuid] - remoteActors.put(remoteServerAddress, remoteActorSet) - remoteActorSet - } + private[akka] def unregisterClientManagedActor(hostname: String, port: Int, uuid: Uuid) = { + remoteActors.remove(Address(hostname,port), uuid) + //TODO: should the connection be closed when the last actor deregisters? } } @@ -135,45 +220,152 @@ object RemoteClient { case "" => None case cookie => Some(cookie) } - + val RECONNECTION_TIME_WINDOW = Duration(config.getInt("akka.remote.client.reconnection-time-window", 600), TIME_UNIT).toMillis val READ_TIMEOUT = Duration(config.getInt("akka.remote.client.read-timeout", 1), TIME_UNIT) val RECONNECT_DELAY = Duration(config.getInt("akka.remote.client.reconnect-delay", 5), TIME_UNIT) val MESSAGE_FRAME_SIZE = config.getInt("akka.remote.client.message-frame-size", 1048576) } -/** - * RemoteClient represents a connection to a RemoteServer. Is used to send messages to remote actors on the RemoteServer. - * - * @author Jonas Bonér - */ -class RemoteClient private[akka] ( + +abstract class RemoteClient private[akka] ( val module: NettyRemoteClientModule, - val remoteAddress: InetSocketAddress, - val loader: Option[ClassLoader] = None, - val notifyListeners: (=> Any) => Unit) extends Logging { - val name = "RemoteClient@" + remoteAddress.getHostName + "::" + remoteAddress.getPort + val remoteAddress: InetSocketAddress) extends Logging { - //FIXME Should these be clear:ed on postStop? - private val futures = new ConcurrentHashMap[Uuid, CompletableFuture[_]] - private val supervisors = new ConcurrentHashMap[Uuid, ActorRef] + val name = this.getClass.getSimpleName + "@" + remoteAddress.getHostName + "::" + remoteAddress.getPort - //FIXME rewrite to a wrapper object (minimize volatile access and maximize encapsulation) - @volatile - private var bootstrap: ClientBootstrap = _ - @volatile - private[remote] var connection: ChannelFuture = _ - @volatile - private[remote] var openChannels: DefaultChannelGroup = _ - @volatile - private var timer: HashedWheelTimer = _ + protected val futures = new ConcurrentHashMap[Uuid, CompletableFuture[_]] + protected val supervisors = new ConcurrentHashMap[Uuid, ActorRef] private[remote] val runSwitch = new Switch() private[remote] val isAuthenticated = new AtomicBoolean(false) private[remote] def isRunning = runSwitch.isOn - private val reconnectionTimeWindow = Duration(config.getInt( - "akka.remote.client.reconnection-time-window", 600), TIME_UNIT).toMillis - @volatile - private var reconnectionTimeWindowStart = 0L + protected def notifyListeners(msg: => Any); Unit + protected def currentChannel: Channel + + def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean + def shutdown: Boolean + + def send[T]( + message: Any, + senderOption: Option[ActorRef], + senderFuture: Option[CompletableFuture[T]], + remoteAddress: InetSocketAddress, + timeout: Long, + isOneWay: Boolean, + actorRef: ActorRef, + typedActorInfo: Option[Tuple2[String, String]], + actorType: AkkaActorType): Option[CompletableFuture[T]] = { + send(createRemoteMessageProtocolBuilder( + Some(actorRef), + Left(actorRef.uuid), + actorRef.id, + actorRef.actorClassName, + actorRef.timeout, + Left(message), + isOneWay, + senderOption, + typedActorInfo, + actorType, + if (isAuthenticated.compareAndSet(false, true)) RemoteClient.SECURE_COOKIE else None + ).build, senderFuture) + } + + def send[T]( + request: RemoteMessageProtocol, + senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = { + log.slf4j.debug("sending message: {} has future {}", request, senderFuture) + if (isRunning) { + if (request.getOneWay) { + currentChannel.write(request).addListener(new ChannelFutureListener { + def operationComplete(future: ChannelFuture) { + if (future.isCancelled) { + //We don't care about that right now + } else if (!future.isSuccess) { + notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) + } + } + }) + None + } else { + val futureResult = if (senderFuture.isDefined) senderFuture.get + else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) + + currentChannel.write(request).addListener(new ChannelFutureListener { + def operationComplete(future: ChannelFuture) { + if (future.isCancelled) { + //We don't care about that right now + } else if (!future.isSuccess) { + notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) + } else { + val futureUuid = uuidFrom(request.getUuid.getHigh, request.getUuid.getLow) + futures.put(futureUuid, futureResult) + } + } + }) + Some(futureResult) + } + } else { + val exception = new RemoteClientException("Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", module, remoteAddress) + notifyListeners(RemoteClientError(exception, module, remoteAddress)) + throw exception + } + } + + private[akka] def registerSupervisorForActor(actorRef: ActorRef): ActorRef = + if (!actorRef.supervisor.isDefined) throw new IllegalActorStateException( + "Can't register supervisor for " + actorRef + " since it is not under supervision") + else supervisors.putIfAbsent(actorRef.supervisor.get.uuid, actorRef) + + private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef = + if (!actorRef.supervisor.isDefined) throw new IllegalActorStateException( + "Can't unregister supervisor for " + actorRef + " since it is not under supervision") + else supervisors.remove(actorRef.supervisor.get.uuid) +} + +/** + * PassiveRemoteClient reuses an incoming connection + */ +class PassiveRemoteClient(module: NettyRemoteClientModule, + remoteAddress: InetSocketAddress, + val currentChannel : Channel, + notifyListenersFun: (=> Any) => Unit) extends RemoteClient(module, remoteAddress) { + def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = { //Cannot reconnect, it's passive. + runSwitch.switchOn { + notifyListeners(RemoteClientStarted(module, remoteAddress)) + } + false + } + + def shutdown = runSwitch switchOff { + log.slf4j.info("Shutting down {}", name) + notifyListeners(RemoteClientShutdown(module, remoteAddress)) + //try { currentChannel.close } catch { case _ => } //TODO: Add failure notification when currentchannel gets shut down? + log.slf4j.info("{} has been shut down", name) + } + + def notifyListeners(msg: => Any) = notifyListenersFun(msg) +} + +/** + * RemoteClient represents a connection to a RemoteServer. Is used to send messages to remote actors on the RemoteServer. + * + * @author Jonas Bonér + */ +class ActiveRemoteClient private[akka] ( + module: NettyRemoteClientModule, + remoteAddress: InetSocketAddress, + val loader: Option[ClassLoader] = None, + notifyListenersFun: (=> Any) => Unit) extends RemoteClient(module, remoteAddress) { + import RemoteClient._ + //FIXME rewrite to a wrapper object (minimize volatile access and maximize encapsulation) + @volatile private var bootstrap: ClientBootstrap = _ + @volatile private[remote] var connection: ChannelFuture = _ + @volatile private[remote] var openChannels: DefaultChannelGroup = _ + @volatile private var timer: HashedWheelTimer = _ + @volatile private var reconnectionTimeWindowStart = 0L + + def notifyListeners(msg: => Any): Unit = notifyListenersFun(msg) + def currentChannel = connection.getChannel def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = { runSwitch switchOn { @@ -181,7 +373,7 @@ class RemoteClient private[akka] ( timer = new HashedWheelTimer bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)) - bootstrap.setPipelineFactory(new RemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this)) + bootstrap.setPipelineFactory(new ActiveRemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this)) bootstrap.setOption("tcpNoDelay", true) bootstrap.setOption("keepAlive", true) @@ -232,88 +424,12 @@ class RemoteClient private[akka] ( log.slf4j.info("{} has been shut down", name) } - def send[T]( - message: Any, - senderOption: Option[ActorRef], - senderFuture: Option[CompletableFuture[T]], - remoteAddress: InetSocketAddress, - timeout: Long, - isOneWay: Boolean, - actorRef: ActorRef, - typedActorInfo: Option[Tuple2[String, String]], - actorType: AkkaActorType): Option[CompletableFuture[T]] = { - send(createRemoteMessageProtocolBuilder( - Some(actorRef), - Left(actorRef.uuid), - actorRef.id, - actorRef.actorClassName, - actorRef.timeout, - Left(message), - isOneWay, - senderOption, - typedActorInfo, - actorType, - if (isAuthenticated.compareAndSet(false, true)) RemoteClient.SECURE_COOKIE else None - ).build, senderFuture) - } - - def send[T]( - request: RemoteMessageProtocol, - senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = { - log.slf4j.debug("sending message: {} has future {}", request, senderFuture) - if (isRunning) { - if (request.getOneWay) { - connection.getChannel.write(request).addListener(new ChannelFutureListener { - def operationComplete(future: ChannelFuture) { - if (future.isCancelled) { - //We don't care about that right now - } else if (!future.isSuccess) { - notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) - } - } - }) - None - } else { - val futureResult = if (senderFuture.isDefined) senderFuture.get - else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) - - connection.getChannel.write(request).addListener(new ChannelFutureListener { - def operationComplete(future: ChannelFuture) { - if (future.isCancelled) { - //We don't care about that right now - } else if (!future.isSuccess) { - notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) - } else { - val futureUuid = uuidFrom(request.getUuid.getHigh, request.getUuid.getLow) - futures.put(futureUuid, futureResult) - } - } - }) - Some(futureResult) - } - } else { - val exception = new RemoteClientException("Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", module, remoteAddress) - notifyListeners(RemoteClientError(exception, module, remoteAddress)) - throw exception - } - } - - private[akka] def registerSupervisorForActor(actorRef: ActorRef): ActorRef = - if (!actorRef.supervisor.isDefined) throw new IllegalActorStateException( - "Can't register supervisor for " + actorRef + " since it is not under supervision") - else supervisors.putIfAbsent(actorRef.supervisor.get.uuid, actorRef) - - private[akka] def deregisterSupervisorForActor(actorRef: ActorRef): ActorRef = - if (!actorRef.supervisor.isDefined) throw new IllegalActorStateException( - "Can't unregister supervisor for " + actorRef + " since it is not under supervision") - else supervisors.remove(actorRef.supervisor.get.uuid) - private[akka] def isWithinReconnectionTimeWindow: Boolean = { if (reconnectionTimeWindowStart == 0L) { reconnectionTimeWindowStart = System.currentTimeMillis true } else { - val timeLeft = reconnectionTimeWindow - (System.currentTimeMillis - reconnectionTimeWindowStart) + val timeLeft = RECONNECTION_TIME_WINDOW - (System.currentTimeMillis - reconnectionTimeWindowStart) if (timeLeft > 0) { log.slf4j.info("Will try to reconnect to remote server for another [{}] milliseconds", timeLeft) true @@ -327,14 +443,14 @@ class RemoteClient private[akka] ( /** * @author Jonas Bonér */ -class RemoteClientPipelineFactory( +class ActiveRemoteClientPipelineFactory( name: String, futures: ConcurrentMap[Uuid, CompletableFuture[_]], supervisors: ConcurrentMap[Uuid, ActorRef], bootstrap: ClientBootstrap, remoteAddress: SocketAddress, timer: HashedWheelTimer, - client: RemoteClient) extends ChannelPipelineFactory { + client: ActiveRemoteClient) extends ChannelPipelineFactory { def getPipeline: ChannelPipeline = { def join(ch: ChannelHandler*) = Array[ChannelHandler](ch: _*) @@ -357,7 +473,7 @@ class RemoteClientPipelineFactory( case _ => (join(), join()) } - val remoteClient = new RemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client) + val remoteClient = new ActiveRemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client) val stages = ssl ++ join(timeout) ++ dec ++ join(lenDec, protobufDec) ++ enc ++ join(lenPrep, protobufEnc, remoteClient) new StaticChannelPipeline(stages: _*) } @@ -367,14 +483,14 @@ class RemoteClientPipelineFactory( * @author Jonas Bonér */ @ChannelHandler.Sharable -class RemoteClientHandler( +class ActiveRemoteClientHandler( val name: String, val futures: ConcurrentMap[Uuid, CompletableFuture[_]], val supervisors: ConcurrentMap[Uuid, ActorRef], val bootstrap: ClientBootstrap, val remoteAddress: SocketAddress, val timer: HashedWheelTimer, - val client: RemoteClient) + val client: ActiveRemoteClient) extends SimpleChannelUpstreamHandler with Logging { override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = { @@ -610,7 +726,7 @@ class NettyRemoteServer(serverModule: NettyRemoteServerModule, val host: String, } } -trait NettyRemoteServerModule extends RemoteServerModule { self: RemoteModule => +trait NettyRemoteServerModule extends RemoteServerModule with NettyRemoteShared { self: RemoteModule => import RemoteServer._ private[akka] val currentServer = new AtomicReference[Option[NettyRemoteServer]](None) @@ -877,7 +993,10 @@ class RemoteServerHandler( } else future.getChannel.close } }) - } else server.notifyListeners(RemoteServerClientConnected(server, clientAddress)) + } else { + server.registerPassiveClient(ctx.getChannel) + server.notifyListeners(RemoteServerClientConnected(server, clientAddress)) + } if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication } @@ -900,6 +1019,7 @@ class RemoteServerHandler( TypedActor.stop(channelTypedActorsIterator.nextElement) } } + server.deregisterPassiveClient(ctx.getChannel) server.notifyListeners(RemoteServerClientDisconnected(server, clientAddress)) }