Completed Erlang-style cookie handshake between RemoteClient and RemoteServer

This commit is contained in:
Jonas Bonér 2010-10-26 15:23:50 +02:00
commit 52f5e5e861
9 changed files with 187 additions and 211 deletions

View file

@ -120,7 +120,7 @@ class Switch(startAsOn: Boolean = false) {
private val switch = new AtomicBoolean(startAsOn)
protected def transcend(from: Boolean,action: => Unit): Boolean = synchronized {
if (switch.compareAndSet(from,!from)) {
if (switch.compareAndSet(from, !from)) {
try {
action
} catch {
@ -133,43 +133,35 @@ class Switch(startAsOn: Boolean = false) {
}
def switchOff(action: => Unit): Boolean = transcend(from = true, action)
def switchOn(action: => Unit): Boolean = transcend(from = false,action)
def switchOn(action: => Unit): Boolean = transcend(from = false, action)
def switchOff: Boolean = synchronized { switch.compareAndSet(true,false) }
def switchOn: Boolean = synchronized { switch.compareAndSet(false,true) }
def switchOff: Boolean = synchronized { switch.compareAndSet(true, false) }
def switchOn: Boolean = synchronized { switch.compareAndSet(false, true) }
def ifOnYield[T](action: => T): Option[T] = {
if (switch.get)
Some(action)
else
None
if (switch.get) Some(action)
else None
}
def ifOffYield[T](action: => T): Option[T] = {
if (switch.get)
Some(action)
else
None
if (switch.get) Some(action)
else None
}
def ifOn(action: => Unit): Boolean = {
if (switch.get) {
action
true
}
else
false
} else false
}
def ifOff(action: => Unit): Boolean = {
if (!switch.get) {
action
true
}
else
false
} else false
}
def isOn = switch.get
def isOff = !isOn
}
}

View file

@ -91,6 +91,13 @@ message TypedActorInfoProtocol {
required string method = 2;
}
/**
* Defines a remote connection handshake.
*/
//message HandshakeProtocol {
// required string cookie = 1;
//}
/**
* Defines a remote message request.
*/

View file

@ -4,33 +4,33 @@
package se.scalablesolutions.akka.remote
import se.scalablesolutions.akka.remote.protocol.RemoteProtocol.{ActorType => ActorTypeProtocol, _}
import se.scalablesolutions.akka.actor.{Exit, Actor, ActorRef, ActorType, RemoteActorRef, IllegalActorStateException}
import se.scalablesolutions.akka.dispatch.{DefaultCompletableFuture, CompletableFuture}
import se.scalablesolutions.akka.actor.{Uuid,newUuid,uuidFrom}
import se.scalablesolutions.akka.remote.protocol.RemoteProtocol.{ ActorType => ActorTypeProtocol, _ }
import se.scalablesolutions.akka.actor._
//import se.scalablesolutions.akka.actor.Uuid.{newUuid, uuidFrom}
import se.scalablesolutions.akka.dispatch.{ DefaultCompletableFuture, CompletableFuture }
import se.scalablesolutions.akka.util._
import se.scalablesolutions.akka.config.Config._
import se.scalablesolutions.akka.serialization.RemoteActorSerialization._
import se.scalablesolutions.akka.AkkaException
import Actor._
import org.jboss.netty.channel._
import group.DefaultChannelGroup
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.bootstrap.ClientBootstrap
import org.jboss.netty.handler.codec.frame.{LengthFieldBasedFrameDecoder, LengthFieldPrepender}
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder}
import org.jboss.netty.handler.codec.protobuf.{ProtobufDecoder, ProtobufEncoder}
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.codec.compression.{ ZlibDecoder, ZlibEncoder }
import org.jboss.netty.handler.codec.protobuf.{ ProtobufDecoder, ProtobufEncoder }
import org.jboss.netty.handler.timeout.ReadTimeoutHandler
import org.jboss.netty.util.{TimerTask, Timeout, HashedWheelTimer}
import org.jboss.netty.util.{ TimerTask, Timeout, HashedWheelTimer }
import org.jboss.netty.handler.ssl.SslHandler
import java.net.{SocketAddress, InetSocketAddress}
import java.util.concurrent.{TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet}
import java.util.concurrent.atomic.AtomicLong
import java.net.{ SocketAddress, InetSocketAddress }
import java.util.concurrent.{ TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet }
import java.util.concurrent.atomic.{ AtomicLong, AtomicBoolean }
import scala.collection.mutable.{HashSet, HashMap}
import scala.collection.mutable.{ HashSet, HashMap }
import scala.reflect.BeanProperty
import se.scalablesolutions.akka.actor._
import se.scalablesolutions.akka.util._
/**
* Life-cycle events for RemoteClient.
@ -51,7 +51,7 @@ case class RemoteClientShutdown(
/**
* Thrown for example when trying to send a message using a RemoteClient that is either not started or shut down.
*/
class RemoteClientException private[akka](message: String, @BeanProperty val client: RemoteClient) extends AkkaException(message)
class RemoteClientException private[akka] (message: String, @BeanProperty val client: RemoteClient) extends AkkaException(message)
/**
* The RemoteClient object manages RemoteClient instances and gives you an API to lookup remote actor handles.
@ -59,17 +59,18 @@ class RemoteClientException private[akka](message: String, @BeanProperty val cli
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
object RemoteClient extends Logging {
val SECURE_COOKIE: Option[String] = {
val SECURE_COOKIE: Option[String] = {
val cookie = config.getString("akka.remote.secure-cookie", "")
if (cookie == "") None
else Some(cookie)
}
val READ_TIMEOUT = Duration(config.getInt("akka.remote.client.read-timeout", 1), TIME_UNIT)
val READ_TIMEOUT = Duration(config.getInt("akka.remote.client.read-timeout", 1), TIME_UNIT)
val RECONNECT_DELAY = Duration(config.getInt("akka.remote.client.reconnect-delay", 5), TIME_UNIT)
private val remoteClients = new HashMap[String, RemoteClient]
private val remoteActors = new HashMap[Address, HashSet[Uuid]]
private val remoteActors = new HashMap[Address, HashSet[Uuid]]
def actorFor(classNameOrServiceId: String, hostname: String, port: Int): ActorRef =
actorFor(classNameOrServiceId, classNameOrServiceId, 5000L, hostname, port, None)
@ -92,23 +93,23 @@ object RemoteClient extends Logging {
def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int): ActorRef =
RemoteActorRef(serviceId, className, hostname, port, timeout, None)
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int) : T = {
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int): T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, 5000L, hostname, port, None)
}
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int) : T = {
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int): T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, None)
}
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, Some(loader))
}
def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): T = {
typedActorFor(intfClass, serviceId, implClassName, timeout, hostname, port, Some(loader))
}
private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]) : T = {
private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]): T = {
val actorRef = RemoteActorRef(serviceId, implClassName, hostname, port, timeout, loader, ActorType.TypedActor)
TypedActor.createProxyForRemoteActorRef(intfClass, actorRef)
}
@ -164,7 +165,7 @@ object RemoteClient extends Logging {
* Clean-up all open connections.
*/
def shutdownAll = synchronized {
remoteClients.foreach({case (addr, client) => client.shutdown})
remoteClients.foreach({ case (addr, client) => client.shutdown })
remoteClients.clear
}
@ -206,34 +207,40 @@ class RemoteClient private[akka] (
private val remoteAddress = new InetSocketAddress(hostname, port)
//FIXME rewrite to a wrapper object (minimize volatile access and maximize encapsulation)
@volatile private var bootstrap: ClientBootstrap = _
@volatile private[remote] var connection: ChannelFuture = _
@volatile private[remote] var openChannels: DefaultChannelGroup = _
@volatile private var timer: HashedWheelTimer = _
@volatile
private var bootstrap: ClientBootstrap = _
@volatile
private[remote] var connection: ChannelFuture = _
@volatile
private[remote] var openChannels: DefaultChannelGroup = _
@volatile
private var timer: HashedWheelTimer = _
private[remote] val runSwitch = new Switch()
private[remote] val isAuthenticated = new AtomicBoolean(false)
private[remote] def isRunning = runSwitch.isOn
private val reconnectionTimeWindow = Duration(config.getInt(
"akka.remote.client.reconnection-time-window", 600), TIME_UNIT).toMillis
@volatile private var reconnectionTimeWindowStart = 0L
@volatile
private var reconnectionTimeWindowStart = 0L
def connect = runSwitch switchOn {
openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName)
timer = new HashedWheelTimer
bootstrap = new ClientBootstrap(
new NioClientSocketChannelFactory(
Executors.newCachedThreadPool,Executors.newCachedThreadPool
)
)
bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool))
bootstrap.setPipelineFactory(new RemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this))
bootstrap.setOption("tcpNoDelay", true)
bootstrap.setOption("keepAlive", true)
connection = bootstrap.connect(remoteAddress)
log.info("Starting remote client connection to [%s:%s]", hostname, port)
// Wait until the connection attempt succeeds or fails.
connection = bootstrap.connect(remoteAddress)
val channel = connection.awaitUninterruptibly.getChannel
openChannels.add(channel)
if (!connection.isSuccess) {
notifyListeners(RemoteClientError(connection.getCause, this))
log.error(connection.getCause, "Remote client connection to [%s:%s] has failed", hostname, port)
@ -274,31 +281,34 @@ class RemoteClient private[akka] (
actorRef: ActorRef,
typedActorInfo: Option[Tuple2[String, String]],
actorType: ActorType): Option[CompletableFuture[T]] = {
val cookie = if (isAuthenticated.compareAndSet(false, true)) RemoteClient.SECURE_COOKIE
else None
send(createRemoteRequestProtocolBuilder(
actorRef, message, isOneWay, senderOption, typedActorInfo, actorType, RemoteClient.SECURE_COOKIE).build, senderFuture)
}
actorRef, message, isOneWay, senderOption, typedActorInfo, actorType, cookie).build, senderFuture)
}
def send[T](
request: RemoteRequestProtocol,
senderFuture: Option[CompletableFuture[T]]):
Option[CompletableFuture[T]] = if (isRunning) {
if (request.getIsOneWay) {
connection.getChannel.write(request)
None
} else {
futures.synchronized {
val futureResult = if (senderFuture.isDefined) senderFuture.get
else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout)
futures.put(uuidFrom(request.getUuid.getHigh,request.getUuid.getLow), futureResult)
senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = {
if (isRunning) {
if (request.getIsOneWay) {
connection.getChannel.write(request)
Some(futureResult)
None
} else {
futures.synchronized {
val futureResult = if (senderFuture.isDefined) senderFuture.get
else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout)
futures.put(uuidFrom(request.getUuid.getHigh, request.getUuid.getLow), futureResult)
connection.getChannel.write(request)
Some(futureResult)
}
}
} else {
val exception = new RemoteClientException(
"Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this)
notifyListeners(RemoteClientError(exception, this))
throw exception
}
} else {
val exception = new RemoteClientException(
"Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this)
notifyListeners(RemoteClientError(exception, this))
throw exception
}
private[akka] def registerSupervisorForActor(actorRef: ActorRef) =
@ -331,13 +341,13 @@ class RemoteClient private[akka] (
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
class RemoteClientPipelineFactory(
name: String,
futures: ConcurrentMap[Uuid, CompletableFuture[_]],
supervisors: ConcurrentMap[Uuid, ActorRef],
bootstrap: ClientBootstrap,
remoteAddress: SocketAddress,
timer: HashedWheelTimer,
client: RemoteClient) extends ChannelPipelineFactory {
name: String,
futures: ConcurrentMap[Uuid, CompletableFuture[_]],
supervisors: ConcurrentMap[Uuid, ActorRef],
bootstrap: ClientBootstrap,
remoteAddress: SocketAddress,
timer: HashedWheelTimer,
client: RemoteClient) extends ChannelPipelineFactory {
def getPipeline: ChannelPipeline = {
def join(ch: ChannelHandler*) = Array[ChannelHandler](ch: _*)
@ -349,15 +359,15 @@ class RemoteClientPipelineFactory(
e
}
val ssl = if (RemoteServer.SECURE) join(new SslHandler(engine)) else join()
val timeout = new ReadTimeoutHandler(timer, RemoteClient.READ_TIMEOUT.toMillis.toInt)
val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4)
val lenPrep = new LengthFieldPrepender(4)
val ssl = if (RemoteServer.SECURE) join(new SslHandler(engine)) else join()
val timeout = new ReadTimeoutHandler(timer, RemoteClient.READ_TIMEOUT.toMillis.toInt)
val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4)
val lenPrep = new LengthFieldPrepender(4)
val protobufDec = new ProtobufDecoder(RemoteReplyProtocol.getDefaultInstance)
val protobufEnc = new ProtobufEncoder
val (enc, dec) = RemoteServer.COMPRESSION_SCHEME match {
val (enc, dec) = RemoteServer.COMPRESSION_SCHEME match {
case "zlib" => (join(new ZlibEncoder(RemoteServer.ZLIB_COMPRESSION_LEVEL)), join(new ZlibDecoder))
case _ => (join(), join())
case _ => (join(), join())
}
val remoteClient = new RemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client)
@ -371,18 +381,18 @@ class RemoteClientPipelineFactory(
*/
@ChannelHandler.Sharable
class RemoteClientHandler(
val name: String,
val futures: ConcurrentMap[Uuid, CompletableFuture[_]],
val supervisors: ConcurrentMap[Uuid, ActorRef],
val bootstrap: ClientBootstrap,
val remoteAddress: SocketAddress,
val timer: HashedWheelTimer,
val client: RemoteClient)
extends SimpleChannelUpstreamHandler with Logging {
val name: String,
val futures: ConcurrentMap[Uuid, CompletableFuture[_]],
val supervisors: ConcurrentMap[Uuid, ActorRef],
val bootstrap: ClientBootstrap,
val remoteAddress: SocketAddress,
val timer: HashedWheelTimer,
val client: RemoteClient)
extends SimpleChannelUpstreamHandler with Logging {
override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = {
if (event.isInstanceOf[ChannelStateEvent] &&
event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) {
event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) {
log.debug(event.toString)
}
super.handleUpstream(ctx, event)
@ -393,7 +403,7 @@ class RemoteClientHandler(
val result = event.getMessage
if (result.isInstanceOf[RemoteReplyProtocol]) {
val reply = result.asInstanceOf[RemoteReplyProtocol]
val replyUuid = uuidFrom(reply.getUuid.getHigh,reply.getUuid.getLow)
val replyUuid = uuidFrom(reply.getUuid.getHigh, reply.getUuid.getLow)
log.debug("Remote client received RemoteReplyProtocol[\n%s]", reply.toString)
val future = futures.get(replyUuid).asInstanceOf[CompletableFuture[Any]]
if (reply.getIsSuccessful) {
@ -401,7 +411,7 @@ class RemoteClientHandler(
future.completeWithResult(message)
} else {
if (reply.hasSupervisorUuid()) {
val supervisorUuid = uuidFrom(reply.getSupervisorUuid.getHigh,reply.getSupervisorUuid.getLow)
val supervisorUuid = uuidFrom(reply.getSupervisorUuid.getHigh, reply.getSupervisorUuid.getLow)
if (!supervisors.containsKey(supervisorUuid)) throw new IllegalActorStateException(
"Expected a registered supervisor for UUID [" + supervisorUuid + "] but none was found")
val supervisedActor = supervisors.get(supervisorUuid)
@ -430,6 +440,7 @@ class RemoteClientHandler(
timer.newTimeout(new TimerTask() {
def run(timeout: Timeout) = {
client.openChannels.remove(event.getChannel)
client.isAuthenticated.set(false)
log.debug("Remote client reconnecting to [%s]", remoteAddress)
client.connection = bootstrap.connect(remoteAddress)
client.connection.awaitUninterruptibly // Wait until the connection attempt succeeds or fails.
@ -475,9 +486,9 @@ class RemoteClientHandler(
val exception = reply.getException
val classname = exception.getClassname
val exceptionClass = if (loader.isDefined) loader.get.loadClass(classname)
else Class.forName(classname)
else Class.forName(classname)
exceptionClass
.getConstructor(Array[Class[_]](classOf[String]): _*)
.newInstance(exception.getMessage).asInstanceOf[Throwable]
.getConstructor(Array[Class[_]](classOf[String]): _*)
.newInstance(exception.getMessage).asInstanceOf[Throwable]
}
}

View file

@ -212,13 +212,14 @@ class RemoteServer extends Logging with ListenerManagement {
address = Address(_hostname,_port)
log.info("Starting remote server at [%s:%s]", hostname, port)
RemoteServer.register(hostname, port, this)
val pipelineFactory = new RemoteServerPipelineFactory(
name, openChannels, loader, this)
val pipelineFactory = new RemoteServerPipelineFactory(name, openChannels, loader, this)
bootstrap.setPipelineFactory(pipelineFactory)
bootstrap.setOption("child.tcpNoDelay", true)
bootstrap.setOption("child.keepAlive", true)
bootstrap.setOption("child.reuseAddress", true)
bootstrap.setOption("child.connectTimeoutMillis", RemoteServer.CONNECTION_TIMEOUT_MILLIS.toMillis)
openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port)))
_isRunning = true
Cluster.registerLocalNode(hostname, port)
@ -260,11 +261,8 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized {
log.debug("Registering server side remote typed actor [%s] with id [%s]", typedActor.getClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid())
} else {
registerTypedActor(id, typedActor, typedActors())
}
if (id.startsWith(UUID_PREFIX)) registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid)
else registerTypedActor(id, typedActor, typedActors)
}
/**
@ -279,28 +277,19 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def register(id: String, actorRef: ActorRef): Unit = synchronized {
log.debug("Registering server side remote actor [%s] with id [%s]", actorRef.actorClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid())
} else {
register(id, actorRef, actors())
}
if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid)
else register(id, actorRef, actors)
}
private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
if (_isRunning && !registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
}
private def registerTypedActor[Key](id: Key, typedActor: AnyRef, registry: ConcurrentHashMap[Key, AnyRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
registry.put(id, typedActor)
}
}
if (_isRunning && !registry.contains(id)) registry.put(id, typedActor)
}
/**
@ -309,8 +298,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(actorRef: ActorRef):Unit = synchronized {
if (_isRunning) {
log.debug("Unregistering server side remote actor [%s] with id [%s:%s]", actorRef.actorClass.getName, actorRef.id, actorRef.uuid)
actors().remove(actorRef.id,actorRef)
actorsByUuid().remove(actorRef.uuid,actorRef)
actors.remove(actorRef.id, actorRef)
actorsByUuid.remove(actorRef.uuid, actorRef)
}
}
@ -322,12 +311,11 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
actorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
val actorRef = actors() get id
actorsByUuid().remove(actorRef.uuid,actorRef)
actors().remove(id,actorRef)
if (id.startsWith(UUID_PREFIX)) actorsByUuid.remove(id.substring(UUID_PREFIX.length))
else {
val actorRef = actors get id
actorsByUuid.remove(actorRef.uuid, actorRef)
actors.remove(id,actorRef)
}
}
}
@ -340,11 +328,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregisterTypedActor(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote typed actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
typedActorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
typedActors().remove(id)
}
if (id.startsWith(UUID_PREFIX)) typedActorsByUuid.remove(id.substring(UUID_PREFIX.length))
else typedActors.remove(id)
}
}
@ -352,10 +337,10 @@ class RemoteServer extends Logging with ListenerManagement {
protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
private[akka] def actors() = ActorRegistry.actors(address)
private[akka] def actorsByUuid() = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors() = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid() = ActorRegistry.typedActorsByUuid(address)
private[akka] def actors = ActorRegistry.actors(address)
private[akka] def actorsByUuid = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address)
}
object RemoteServerSslContext {
@ -419,6 +404,7 @@ class RemoteServerHandler(
val applicationLoader: Option[ClassLoader],
val server: RemoteServer) extends SimpleChannelUpstreamHandler with Logging {
import RemoteServer._
val AW_PROXY_PREFIX = "$$ProxiedByAW".intern
val CHANNEL_INIT = "channel-init".intern
@ -444,10 +430,8 @@ class RemoteServerHandler(
} else future.getChannel.close
}
})
} else {
server.notifyListeners(RemoteServerClientConnected(server))
}
if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT)
} else server.notifyListeners(RemoteServerClientConnected(server))
if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
@ -467,7 +451,7 @@ class RemoteServerHandler(
if (message eq null) throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event)
if (message.isInstanceOf[RemoteRequestProtocol]) {
val requestProtocol = message.asInstanceOf[RemoteRequestProtocol]
authenticateRemoteClient(requestProtocol, ctx)
if (RemoteServer.REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx)
handleRemoteRequestProtocol(requestProtocol, event.getChannel)
}
}
@ -521,8 +505,7 @@ class RemoteServerHandler(
try {
channel.write(replyBuilder.build)
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
@ -530,8 +513,7 @@ class RemoteServerHandler(
try {
channel.write(createErrorReplyMessage(exception, request, true))
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
}
@ -543,8 +525,8 @@ class RemoteServerHandler(
val actorInfo = request.getActorInfo
val typedActorInfo = actorInfo.getTypedActorInfo
log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface)
val typedActor = createTypedActor(actorInfo)
val typedActor = createTypedActor(actorInfo)
val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList
val argClasses = args.map(_.getClass)
@ -566,49 +548,39 @@ class RemoteServerHandler(
case e: InvocationTargetException =>
channel.write(createErrorReplyMessage(e.getCause, request, false))
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable =>
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, false))
server.notifyListeners(RemoteServerError(e, server))
}
}
private def findActorById(id: String) : ActorRef = {
server.actors().get(id)
server.actors.get(id)
}
private def findActorByUuid(uuid: String) : ActorRef = {
server.actorsByUuid().get(uuid)
server.actorsByUuid.get(uuid)
}
private def findTypedActorById(id: String) : AnyRef = {
server.typedActors().get(id)
server.typedActors.get(id)
}
private def findTypedActorByUuid(uuid: String) : AnyRef = {
server.typedActorsByUuid().get(uuid)
server.typedActorsByUuid.get(uuid)
}
private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findActorByUuid(id.substring(UUID_PREFIX.length))
else findActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findActorByUuid(uuid)
actorRefOrNull
}
private def findTypedActorByIdOrUuid(id: String, uuid: String) : AnyRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findTypedActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findTypedActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findTypedActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findTypedActorByUuid(id.substring(UUID_PREFIX.length))
else findTypedActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findTypedActorByUuid(uuid)
actorRefOrNull
}
@ -694,18 +666,17 @@ class RemoteServerHandler(
}
private def authenticateRemoteClient(request: RemoteRequestProtocol, ctx: ChannelHandlerContext) = {
if (RemoteServer.REQUIRE_COOKIE) {
val attachment = ctx.getAttachment
if ((attachment ne null) &&
attachment.isInstanceOf[String] &&
attachment.asInstanceOf[String] == CHANNEL_INIT) {
val clientAddress = ctx.getChannel.getRemoteAddress.toString
if (!request.hasCookie) throw new SecurityException(
"The remote client [" + clientAddress + "] does not have a secure cookie.")
if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException(
"The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie")
log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress)
}
val attachment = ctx.getAttachment
if ((attachment ne null) &&
attachment.isInstanceOf[String] &&
attachment.asInstanceOf[String] == CHANNEL_INIT) { // is first time around, channel initialization
ctx.setAttachment(null)
val clientAddress = ctx.getChannel.getRemoteAddress.toString
if (!request.hasCookie) throw new SecurityException(
"The remote client [" + clientAddress + "] does not have a secure cookie.")
if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException(
"The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie")
log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress)
}
}
}

View file

@ -249,11 +249,11 @@ object RemoteActorSerialization {
ActorRegistry.registerActorByUuid(homeAddress, uuid.toString, ar)
RemoteActorRefProtocol.newBuilder
.setClassOrServiceName(uuid.toString)
.setActorClassname(actorClassName)
.setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build)
.setTimeout(timeout)
.build
.setClassOrServiceName(uuid.toString)
.setActorClassname(actorClassName)
.setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build)
.setTimeout(timeout)
.build
}
def createRemoteRequestProtocolBuilder(
@ -263,8 +263,7 @@ object RemoteActorSerialization {
senderOption: Option[ActorRef],
typedActorInfo: Option[Tuple2[String, String]],
actorType: ActorType,
secureCookie: Option[String]):
RemoteRequestProtocol.Builder = {
secureCookie: Option[String]): RemoteRequestProtocol.Builder = {
import actorRef._
val actorInfoBuilder = ActorInfoProtocol.newBuilder
@ -273,13 +272,12 @@ object RemoteActorSerialization {
.setTarget(actorClassName)
.setTimeout(timeout)
typedActorInfo.foreach {
typedActor =>
actorInfoBuilder.setTypedActorInfo(
TypedActorInfoProtocol.newBuilder
.setInterface(typedActor._1)
.setMethod(typedActor._2)
.build)
typedActorInfo.foreach { typedActor =>
actorInfoBuilder.setTypedActorInfo(
TypedActorInfoProtocol.newBuilder
.setInterface(typedActor._1)
.setMethod(typedActor._2)
.build)
}
actorType match {
@ -310,8 +308,6 @@ object RemoteActorSerialization {
}
requestBuilder
}
}
@ -408,5 +404,4 @@ object RemoteTypedActorSerialization {
.setInterfaceName(init.interfaceClass.getName)
.build
}
}

View file

@ -201,18 +201,18 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite {
def shouldRegisterAndUnregister {
val actor1 = actorOf[RemoteActorSpecActorUnidirectional]
server.register("my-service-1", actor1)
assert(server.actors().get("my-service-1") ne null, "actor registered")
assert(server.actors.get("my-service-1") ne null, "actor registered")
server.unregister("my-service-1")
assert(server.actors().get("my-service-1") eq null, "actor unregistered")
assert(server.actors.get("my-service-1") eq null, "actor unregistered")
}
@Test
def shouldRegisterAndUnregisterByUuid {
val actor1 = actorOf[RemoteActorSpecActorUnidirectional]
server.register("uuid:" + actor1.uuid, actor1)
assert(server.actorsByUuid().get(actor1.uuid.toString) ne null, "actor registered")
assert(server.actorsByUuid.get(actor1.uuid.toString) ne null, "actor registered")
server.unregister("uuid:" + actor1.uuid)
assert(server.actorsByUuid().get(actor1.uuid) eq null, "actor unregistered")
assert(server.actorsByUuid.get(actor1.uuid) eq null, "actor unregistered")
}
}

View file

@ -103,9 +103,9 @@ class ServerInitiatedRemoteTypedActorSpec extends
it("should register and unregister typed actors") {
val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
server.registerTypedActor("my-test-service", typedActor)
assert(server.typedActors().get("my-test-service") ne null, "typed actor registered")
assert(server.typedActors.get("my-test-service") ne null, "typed actor registered")
server.unregisterTypedActor("my-test-service")
assert(server.typedActors().get("my-test-service") eq null, "typed actor unregistered")
assert(server.typedActors.get("my-test-service") eq null, "typed actor unregistered")
}
it("should register and unregister typed actors by uuid") {
@ -113,9 +113,9 @@ class ServerInitiatedRemoteTypedActorSpec extends
val init = AspectInitRegistry.initFor(typedActor)
val uuid = "uuid:" + init.actorRef.uuid
server.registerTypedActor(uuid, typedActor)
assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) ne null, "typed actor registered")
assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) ne null, "typed actor registered")
server.unregisterTypedActor(uuid)
assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) eq null, "typed actor unregistered")
assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) eq null, "typed actor unregistered")
}
it("should find typed actors by uuid") {
@ -123,7 +123,7 @@ class ServerInitiatedRemoteTypedActorSpec extends
val init = AspectInitRegistry.initFor(typedActor)
val uuid = "uuid:" + init.actorRef.uuid
server.registerTypedActor(uuid, typedActor)
assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) ne null, "typed actor registered")
assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) ne null, "typed actor registered")
val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], uuid, HOSTNAME, PORT)
expect("oneway") {

View file

@ -4,7 +4,7 @@ object AkkaRepositories {
val AkkaRepo = MavenRepository("Akka Repository", "http://scalablesolutions.se/akka/repository")
val CodehausRepo = MavenRepository("Codehaus Repo", "http://repository.codehaus.org")
val GuiceyFruitRepo = MavenRepository("GuiceyFruit Repo", "http://guiceyfruit.googlecode.com/svn/repo/releases/")
val JBossRepo = MavenRepository("JBoss Repo", "https://repository.jboss.org/nexus/content/groups/public/")
val JBossRepo = MavenRepository("JBoss Repo", "http://repository.jboss.org/nexus/content/groups/public/")
val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2")
val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases")
val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo")

View file

@ -72,7 +72,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val EmbeddedRepo = MavenRepository("Embedded Repo", (info.projectPath / "embedded-repo").asURL.toString)
lazy val FusesourceSnapshotRepo = MavenRepository("Fusesource Snapshots", "http://repo.fusesource.com/nexus/content/repositories/snapshots")
lazy val GuiceyFruitRepo = MavenRepository("GuiceyFruit Repo", "http://guiceyfruit.googlecode.com/svn/repo/releases/")
lazy val JBossRepo = MavenRepository("JBoss Repo", "https://repository.jboss.org/nexus/content/groups/public/")
lazy val JBossRepo = MavenRepository("JBoss Repo", "http://repository.jboss.org/nexus/content/groups/public/")
lazy val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2")
lazy val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases")
lazy val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo")