Added clean automatic shutdown of RemoteClient, based on reference counting + fixed bug in shutdown of RemoteClient
This commit is contained in:
parent
1355dd2411
commit
675101cc82
7 changed files with 108 additions and 34 deletions
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
package se.scalablesolutions.akka.remote
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
|
||||
import se.scalablesolutions.akka.remote.protobuf.RemoteProtocol.{RemoteRequest, RemoteReply}
|
||||
import se.scalablesolutions.akka.actor.{Exit, Actor}
|
||||
import se.scalablesolutions.akka.dispatch.{DefaultCompletableFutureResult, CompletableFutureResult}
|
||||
|
|
@ -13,6 +11,7 @@ import se.scalablesolutions.akka.util.{UUID, Logging}
|
|||
import se.scalablesolutions.akka.Config.config
|
||||
|
||||
import org.jboss.netty.channel._
|
||||
import group.DefaultChannelGroup
|
||||
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
|
||||
import org.jboss.netty.bootstrap.ClientBootstrap
|
||||
import org.jboss.netty.handler.codec.frame.{LengthFieldBasedFrameDecoder, LengthFieldPrepender}
|
||||
|
|
@ -25,6 +24,8 @@ import java.net.{SocketAddress, InetSocketAddress}
|
|||
import java.util.concurrent.{TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap}
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
import scala. collection.mutable.{HashSet, HashMap}
|
||||
|
||||
/**
|
||||
* @author <a href="http://jonasboner.com">Jonas Bonér</a>
|
||||
*/
|
||||
|
|
@ -41,27 +42,62 @@ object RemoteClient extends Logging {
|
|||
val READ_TIMEOUT = config.getInt("akka.remote.client.read-timeout", 10000)
|
||||
val RECONNECT_DELAY = config.getInt("akka.remote.client.reconnect-delay", 5000)
|
||||
|
||||
private val clients = new HashMap[String, RemoteClient]
|
||||
private val remoteClients = new HashMap[String, RemoteClient]
|
||||
private val remoteActors = new HashMap[RemoteServer.Address, HashSet[String]]
|
||||
|
||||
def clientFor(hostname: String, port: Int): RemoteClient = clientFor(new InetSocketAddress(hostname, port))
|
||||
|
||||
def clientFor(address: InetSocketAddress): RemoteClient = synchronized {
|
||||
val hostname = address.getHostName
|
||||
val port = address.getPort
|
||||
val hash = hostname + ':' + port
|
||||
if (clients.contains(hash)) clients(hash)
|
||||
if (remoteClients.contains(hash)) remoteClients(hash)
|
||||
else {
|
||||
val client = new RemoteClient(hostname, port)
|
||||
client.connect
|
||||
clients += hash -> client
|
||||
remoteClients += hash -> client
|
||||
client
|
||||
}
|
||||
}
|
||||
|
||||
def shutdownClientFor(address: InetSocketAddress) = synchronized {
|
||||
val hostname = address.getHostName
|
||||
val port = address.getPort
|
||||
val hash = hostname + ':' + port
|
||||
if (remoteClients.contains(hash)) {
|
||||
val client = remoteClients(hash)
|
||||
client.shutdown
|
||||
remoteClients - hash
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean-up all open connections.
|
||||
*/
|
||||
def shutdownAll() = synchronized {
|
||||
clients.foreach({case (addr, client) => client.shutdown})
|
||||
clients.clear
|
||||
def shutdownAll = synchronized {
|
||||
remoteClients.foreach({case (addr, client) => client.shutdown})
|
||||
remoteClients.clear
|
||||
}
|
||||
|
||||
private[akka] def register(hostname: String, port: Int, uuid: String) = synchronized {
|
||||
actorsFor(RemoteServer.Address(hostname, port)) + uuid
|
||||
}
|
||||
|
||||
// TODO: add RemoteClient.unregister for ActiveObject, but first need a @shutdown callback
|
||||
private[akka] def unregister(hostname: String, port: Int, uuid: String) = synchronized {
|
||||
val set = actorsFor(RemoteServer.Address(hostname, port))
|
||||
set - uuid
|
||||
if (set.isEmpty) shutdownClientFor(new InetSocketAddress(hostname, port))
|
||||
}
|
||||
|
||||
private[akka] def actorsFor(remoteServerAddress: RemoteServer.Address): HashSet[String] = {
|
||||
val set = remoteActors.get(remoteServerAddress)
|
||||
if (set.isDefined && (set.get ne null)) set.get
|
||||
else {
|
||||
val remoteActorSet = new HashSet[String]
|
||||
remoteActors.put(remoteServerAddress, remoteActorSet)
|
||||
remoteActorSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -69,9 +105,9 @@ object RemoteClient extends Logging {
|
|||
* @author <a href="http://jonasboner.com">Jonas Bonér</a>
|
||||
*/
|
||||
class RemoteClient(hostname: String, port: Int) extends Logging {
|
||||
val name = "RemoteClient@" + hostname
|
||||
|
||||
@volatile private var isRunning = false
|
||||
val name = "RemoteClient@" + hostname + "::" + port
|
||||
|
||||
@volatile private[remote] var isRunning = false
|
||||
private val futures = new ConcurrentHashMap[Long, CompletableFutureResult]
|
||||
private val supervisors = new ConcurrentHashMap[String, Actor]
|
||||
|
||||
|
|
@ -80,6 +116,7 @@ class RemoteClient(hostname: String, port: Int) extends Logging {
|
|||
Executors.newCachedThreadPool)
|
||||
|
||||
private val bootstrap = new ClientBootstrap(channelFactory)
|
||||
private val openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName);
|
||||
|
||||
private val timer = new HashedWheelTimer
|
||||
private val remoteAddress = new InetSocketAddress(hostname, port)
|
||||
|
|
@ -93,20 +130,22 @@ class RemoteClient(hostname: String, port: Int) extends Logging {
|
|||
if (!isRunning) {
|
||||
connection = bootstrap.connect(remoteAddress)
|
||||
log.info("Starting remote client connection to [%s:%s]", hostname, port)
|
||||
|
||||
// Wait until the connection attempt succeeds or fails.
|
||||
connection.awaitUninterruptibly
|
||||
if (!connection.isSuccess) log.error(connection.getCause, "Remote connection to [%s:%s] has failed", hostname, port)
|
||||
val channel = connection.awaitUninterruptibly.getChannel
|
||||
openChannels.add(channel)
|
||||
if (!connection.isSuccess) log.error(connection.getCause, "Remote client connection to [%s:%s] has failed", hostname, port)
|
||||
isRunning = true
|
||||
}
|
||||
}
|
||||
|
||||
def shutdown = synchronized {
|
||||
if (!isRunning) {
|
||||
connection.getChannel.getCloseFuture.awaitUninterruptibly
|
||||
channelFactory.releaseExternalResources
|
||||
if (isRunning) {
|
||||
isRunning = false
|
||||
openChannels.close.awaitUninterruptibly
|
||||
bootstrap.releaseExternalResources
|
||||
timer.stop
|
||||
log.info("%s has been shut down", name)
|
||||
}
|
||||
timer.stop
|
||||
}
|
||||
|
||||
def send(request: RemoteRequest, senderFuture: Option[CompletableFutureResult]): Option[CompletableFutureResult] = if (isRunning) {
|
||||
|
|
@ -120,7 +159,7 @@ class RemoteClient(hostname: String, port: Int) extends Logging {
|
|||
futures.put(request.getId, futureResult)
|
||||
connection.getChannel.write(request)
|
||||
Some(futureResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else throw new IllegalStateException("Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.")
|
||||
|
||||
|
|
@ -131,14 +170,14 @@ class RemoteClient(hostname: String, port: Int) extends Logging {
|
|||
def deregisterSupervisorForActor(actor: Actor) =
|
||||
if (!actor._supervisor.isDefined) throw new IllegalStateException("Can't unregister supervisor for " + actor + " since it is not under supervision")
|
||||
else supervisors.remove(actor._supervisor.get.uuid)
|
||||
|
||||
|
||||
def deregisterSupervisorWithUuid(uuid: String) = supervisors.remove(uuid)
|
||||
}
|
||||
|
||||
/**
|
||||
* @author <a href="http://jonasboner.com">Jonas Bonér</a>
|
||||
*/
|
||||
class RemoteClientPipelineFactory(name: String,
|
||||
class RemoteClientPipelineFactory(name: String,
|
||||
futures: ConcurrentMap[Long, CompletableFutureResult],
|
||||
supervisors: ConcurrentMap[String, Actor],
|
||||
bootstrap: ClientBootstrap,
|
||||
|
|
@ -158,7 +197,7 @@ class RemoteClientPipelineFactory(name: String,
|
|||
}
|
||||
val remoteClient = new RemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client)
|
||||
|
||||
val stages: Array[ChannelHandler] =
|
||||
val stages: Array[ChannelHandler] =
|
||||
zipCodec.map(codec => Array(timeout, codec.decoder, lenDec, protobufDec, codec.encoder, lenPrep, protobufEnc, remoteClient))
|
||||
.getOrElse(Array(timeout, lenDec, protobufDec, lenPrep, protobufEnc, remoteClient))
|
||||
new StaticChannelPipeline(stages: _*)
|
||||
|
|
@ -214,9 +253,9 @@ class RemoteClientHandler(val name: String,
|
|||
log.error("Unexpected exception in remote client handler: %s", e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
|
||||
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = if (client.isRunning) {
|
||||
timer.newTimeout(new TimerTask() {
|
||||
def run(timeout: Timeout) = {
|
||||
log.debug("Remote client reconnecting to [%s]", remoteAddress)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue