Merge with master

This commit is contained in:
Viktor Klang 2010-10-28 14:03:53 +02:00
commit 41b5fd2de8
26 changed files with 936 additions and 501 deletions

View file

@ -31,6 +31,7 @@ import org.jboss.netty.handler.ssl.SslHandler
import scala.collection.mutable.Map
import scala.reflect.BeanProperty
import akka.config.ConfigurationException
/**
* Use this object if you need a single remote server on a specific node.
@ -61,21 +62,30 @@ import scala.reflect.BeanProperty
object RemoteNode extends RemoteServer
/**
* For internal use only.
* Holds configuration variables, remote actors, remote typed actors and remote servers.
* For internal use only. Holds configuration variables, remote actors, remote typed actors and remote servers.
*
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
object
RemoteServer {
object RemoteServer {
val UUID_PREFIX = "uuid:"
val HOSTNAME = config.getString("akka.remote.server.hostname", "localhost")
val PORT = config.getInt("akka.remote.server.port", 9999)
val SECURE_COOKIE: Option[String] = {
val cookie = config.getString("akka.remote.secure-cookie", "")
if (cookie == "") None
else Some(cookie)
}
val REQUIRE_COOKIE = {
val requireCookie = config.getBool("akka.remote.server.require-cookie", true)
if (RemoteServer.SECURE_COOKIE.isEmpty) throw new ConfigurationException(
"Configuration option 'akka.remote.server.require-cookie' is turned on but no secure cookie is defined in 'akka.remote.secure-cookie'.")
requireCookie
}
val HOSTNAME = config.getString("akka.remote.server.hostname", "localhost")
val PORT = config.getInt("akka.remote.server.port", 9999)
val CONNECTION_TIMEOUT_MILLIS = Duration(config.getInt("akka.remote.server.connection-timeout", 1), TIME_UNIT)
val COMPRESSION_SCHEME = config.getString("akka.remote.compression-scheme", "zlib")
val ZLIB_COMPRESSION_LEVEL = {
val COMPRESSION_SCHEME = config.getString("akka.remote.compression-scheme", "zlib")
val ZLIB_COMPRESSION_LEVEL = {
val level = config.getInt("akka.remote.zlib-compression-level", 6)
if (level < 1 && level > 9) throw new IllegalArgumentException(
"zlib compression level has to be within 1-9, with 1 being fastest and 9 being the most compressed")
@ -128,7 +138,6 @@ RemoteServer {
private[akka] def unregister(hostname: String, port: Int) = guard.withWriteGuard {
remoteServers.remove(Address(hostname, port))
}
}
/**
@ -203,13 +212,14 @@ class RemoteServer extends Logging with ListenerManagement {
address = Address(_hostname,_port)
log.info("Starting remote server at [%s:%s]", hostname, port)
RemoteServer.register(hostname, port, this)
val pipelineFactory = new RemoteServerPipelineFactory(
name, openChannels, loader, this)
val pipelineFactory = new RemoteServerPipelineFactory(name, openChannels, loader, this)
bootstrap.setPipelineFactory(pipelineFactory)
bootstrap.setOption("child.tcpNoDelay", true)
bootstrap.setOption("child.keepAlive", true)
bootstrap.setOption("child.reuseAddress", true)
bootstrap.setOption("child.connectTimeoutMillis", RemoteServer.CONNECTION_TIMEOUT_MILLIS.toMillis)
openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port)))
_isRunning = true
Cluster.registerLocalNode(hostname, port)
@ -251,11 +261,8 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized {
log.debug("Registering server side remote typed actor [%s] with id [%s]", typedActor.getClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid())
} else {
registerTypedActor(id, typedActor, typedActors())
}
if (id.startsWith(UUID_PREFIX)) registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid)
else registerTypedActor(id, typedActor, typedActors)
}
/**
@ -270,28 +277,19 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def register(id: String, actorRef: ActorRef): Unit = synchronized {
log.debug("Registering server side remote actor [%s] with id [%s]", actorRef.actorClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid())
} else {
register(id, actorRef, actors())
}
if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid)
else register(id, actorRef, actors)
}
private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
if (_isRunning && !registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
}
private def registerTypedActor[Key](id: Key, typedActor: AnyRef, registry: ConcurrentHashMap[Key, AnyRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
registry.put(id, typedActor)
}
}
if (_isRunning && !registry.contains(id)) registry.put(id, typedActor)
}
/**
@ -300,8 +298,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(actorRef: ActorRef):Unit = synchronized {
if (_isRunning) {
log.debug("Unregistering server side remote actor [%s] with id [%s:%s]", actorRef.actorClass.getName, actorRef.id, actorRef.uuid)
actors().remove(actorRef.id,actorRef)
actorsByUuid().remove(actorRef.uuid,actorRef)
actors.remove(actorRef.id, actorRef)
actorsByUuid.remove(actorRef.uuid, actorRef)
}
}
@ -313,12 +311,11 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
actorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
val actorRef = actors() get id
actorsByUuid().remove(actorRef.uuid,actorRef)
actors().remove(id,actorRef)
if (id.startsWith(UUID_PREFIX)) actorsByUuid.remove(id.substring(UUID_PREFIX.length))
else {
val actorRef = actors get id
actorsByUuid.remove(actorRef.uuid, actorRef)
actors.remove(id,actorRef)
}
}
}
@ -331,11 +328,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregisterTypedActor(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote typed actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
typedActorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
typedActors().remove(id)
}
if (id.startsWith(UUID_PREFIX)) typedActorsByUuid.remove(id.substring(UUID_PREFIX.length))
else typedActors.remove(id)
}
}
@ -343,10 +337,10 @@ class RemoteServer extends Logging with ListenerManagement {
protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
private[akka] def actors() = ActorRegistry.actors(address)
private[akka] def actorsByUuid() = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors() = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid() = ActorRegistry.typedActorsByUuid(address)
private[akka] def actors = ActorRegistry.actors(address)
private[akka] def actorsByUuid = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address)
}
object RemoteServerSslContext {
@ -389,7 +383,7 @@ class RemoteServerPipelineFactory(
val lenPrep = new LengthFieldPrepender(4)
val protobufDec = new ProtobufDecoder(RemoteRequestProtocol.getDefaultInstance)
val protobufEnc = new ProtobufEncoder
val (enc,dec) = RemoteServer.COMPRESSION_SCHEME match {
val (enc, dec) = RemoteServer.COMPRESSION_SCHEME match {
case "zlib" => (join(new ZlibEncoder(RemoteServer.ZLIB_COMPRESSION_LEVEL)), join(new ZlibDecoder))
case _ => (join(), join())
}
@ -410,7 +404,9 @@ class RemoteServerHandler(
val applicationLoader: Option[ClassLoader],
val server: RemoteServer) extends SimpleChannelUpstreamHandler with Logging {
import RemoteServer._
val AW_PROXY_PREFIX = "$$ProxiedByAW".intern
val CHANNEL_INIT = "channel-init".intern
applicationLoader.foreach(MessageSerializer.setClassLoader(_))
@ -434,9 +430,8 @@ class RemoteServerHandler(
} else future.getChannel.close
}
})
} else {
server.notifyListeners(RemoteServerClientConnected(server))
}
} else server.notifyListeners(RemoteServerClientConnected(server))
if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
@ -445,8 +440,7 @@ class RemoteServerHandler(
}
override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = {
if (event.isInstanceOf[ChannelStateEvent] &&
event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) {
if (event.isInstanceOf[ChannelStateEvent] && event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) {
log.debug(event.toString)
}
super.handleUpstream(ctx, event)
@ -456,7 +450,9 @@ class RemoteServerHandler(
val message = event.getMessage
if (message eq null) throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event)
if (message.isInstanceOf[RemoteRequestProtocol]) {
handleRemoteRequestProtocol(message.asInstanceOf[RemoteRequestProtocol], event.getChannel)
val requestProtocol = message.asInstanceOf[RemoteRequestProtocol]
if (RemoteServer.REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx)
handleRemoteRequestProtocol(requestProtocol, event.getChannel)
}
}
@ -491,8 +487,11 @@ class RemoteServerHandler(
case RemoteActorSystemMessage.Stop => actorRef.stop
case _ => // then match on user defined messages
if (request.getIsOneWay) actorRef.!(message)(sender)
else actorRef.postMessageToMailboxAndCreateFutureResultWithTimeout(message,request.getActorInfo.getTimeout,None,Some(
new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout){
else actorRef.postMessageToMailboxAndCreateFutureResultWithTimeout(
message,
request.getActorInfo.getTimeout,
None,
Some(new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout){
override def onComplete(result: AnyRef) {
log.debug("Returning result from actor invocation [%s]", result)
val replyBuilder = RemoteReplyProtocol.newBuilder
@ -506,8 +505,7 @@ class RemoteServerHandler(
try {
channel.write(replyBuilder.build)
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
@ -515,8 +513,7 @@ class RemoteServerHandler(
try {
channel.write(createErrorReplyMessage(exception, request, true))
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
}
@ -528,8 +525,8 @@ class RemoteServerHandler(
val actorInfo = request.getActorInfo
val typedActorInfo = actorInfo.getTypedActorInfo
log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface)
val typedActor = createTypedActor(actorInfo)
val typedActor = createTypedActor(actorInfo)
val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList
val argClasses = args.map(_.getClass)
@ -551,49 +548,39 @@ class RemoteServerHandler(
case e: InvocationTargetException =>
channel.write(createErrorReplyMessage(e.getCause, request, false))
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable =>
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, false))
server.notifyListeners(RemoteServerError(e, server))
}
}
private def findActorById(id: String) : ActorRef = {
server.actors().get(id)
server.actors.get(id)
}
private def findActorByUuid(uuid: String) : ActorRef = {
server.actorsByUuid().get(uuid)
server.actorsByUuid.get(uuid)
}
private def findTypedActorById(id: String) : AnyRef = {
server.typedActors().get(id)
server.typedActors.get(id)
}
private def findTypedActorByUuid(uuid: String) : AnyRef = {
server.typedActorsByUuid().get(uuid)
server.typedActorsByUuid.get(uuid)
}
private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findActorByUuid(id.substring(UUID_PREFIX.length))
else findActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findActorByUuid(uuid)
actorRefOrNull
}
private def findTypedActorByIdOrUuid(id: String, uuid: String) : AnyRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findTypedActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findTypedActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findTypedActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findTypedActorByUuid(id.substring(UUID_PREFIX.length))
else findTypedActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findTypedActorByUuid(uuid)
actorRefOrNull
}
@ -677,4 +664,19 @@ class RemoteServerHandler(
if (request.hasSupervisorUuid) replyBuilder.setSupervisorUuid(request.getSupervisorUuid)
replyBuilder.build
}
private def authenticateRemoteClient(request: RemoteRequestProtocol, ctx: ChannelHandlerContext) = {
val attachment = ctx.getAttachment
if ((attachment ne null) &&
attachment.isInstanceOf[String] &&
attachment.asInstanceOf[String] == CHANNEL_INIT) { // is first time around, channel initialization
ctx.setAttachment(null)
val clientAddress = ctx.getChannel.getRemoteAddress.toString
if (!request.hasCookie) throw new SecurityException(
"The remote client [" + clientAddress + "] does not have a secure cookie.")
if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException(
"The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie")
log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress)
}
}
}