diff --git a/akka-actor/src/main/scala/util/LockUtil.scala b/akka-actor/src/main/scala/util/LockUtil.scala index 909713194b..d0a318aa02 100644 --- a/akka-actor/src/main/scala/util/LockUtil.scala +++ b/akka-actor/src/main/scala/util/LockUtil.scala @@ -120,7 +120,7 @@ class Switch(startAsOn: Boolean = false) { private val switch = new AtomicBoolean(startAsOn) protected def transcend(from: Boolean,action: => Unit): Boolean = synchronized { - if (switch.compareAndSet(from,!from)) { + if (switch.compareAndSet(from, !from)) { try { action } catch { @@ -133,43 +133,35 @@ class Switch(startAsOn: Boolean = false) { } def switchOff(action: => Unit): Boolean = transcend(from = true, action) - def switchOn(action: => Unit): Boolean = transcend(from = false,action) + def switchOn(action: => Unit): Boolean = transcend(from = false, action) - def switchOff: Boolean = synchronized { switch.compareAndSet(true,false) } - def switchOn: Boolean = synchronized { switch.compareAndSet(false,true) } + def switchOff: Boolean = synchronized { switch.compareAndSet(true, false) } + def switchOn: Boolean = synchronized { switch.compareAndSet(false, true) } def ifOnYield[T](action: => T): Option[T] = { - if (switch.get) - Some(action) - else - None + if (switch.get) Some(action) + else None } def ifOffYield[T](action: => T): Option[T] = { - if (switch.get) - Some(action) - else - None + if (switch.get) Some(action) + else None } def ifOn(action: => Unit): Boolean = { if (switch.get) { action true - } - else - false + } else false } def ifOff(action: => Unit): Boolean = { if (!switch.get) { action true - } - else - false + } else false } def isOn = switch.get def isOff = !isOn -} \ No newline at end of file +} diff --git a/akka-remote/src/main/protocol/RemoteProtocol.proto b/akka-remote/src/main/protocol/RemoteProtocol.proto index ce694141a0..46e38fd158 100644 --- a/akka-remote/src/main/protocol/RemoteProtocol.proto +++ b/akka-remote/src/main/protocol/RemoteProtocol.proto @@ -91,6 +91,13 @@ message TypedActorInfoProtocol { required string method = 2; } +/** + * Defines a remote connection handshake. + */ +//message HandshakeProtocol { +// required string cookie = 1; +//} + /** * Defines a remote message request. */ diff --git a/akka-remote/src/main/scala/remote/RemoteClient.scala b/akka-remote/src/main/scala/remote/RemoteClient.scala index 1ddf57869e..3ecb64f77c 100644 --- a/akka-remote/src/main/scala/remote/RemoteClient.scala +++ b/akka-remote/src/main/scala/remote/RemoteClient.scala @@ -4,33 +4,33 @@ package se.scalablesolutions.akka.remote -import se.scalablesolutions.akka.remote.protocol.RemoteProtocol.{ActorType => ActorTypeProtocol, _} -import se.scalablesolutions.akka.actor.{Exit, Actor, ActorRef, ActorType, RemoteActorRef, IllegalActorStateException} -import se.scalablesolutions.akka.dispatch.{DefaultCompletableFuture, CompletableFuture} -import se.scalablesolutions.akka.actor.{Uuid,newUuid,uuidFrom} +import se.scalablesolutions.akka.remote.protocol.RemoteProtocol.{ ActorType => ActorTypeProtocol, _ } +import se.scalablesolutions.akka.actor._ +//import se.scalablesolutions.akka.actor.Uuid.{newUuid, uuidFrom} +import se.scalablesolutions.akka.dispatch.{ DefaultCompletableFuture, CompletableFuture } +import se.scalablesolutions.akka.util._ import se.scalablesolutions.akka.config.Config._ import se.scalablesolutions.akka.serialization.RemoteActorSerialization._ import se.scalablesolutions.akka.AkkaException import Actor._ + import org.jboss.netty.channel._ import group.DefaultChannelGroup import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.bootstrap.ClientBootstrap -import org.jboss.netty.handler.codec.frame.{LengthFieldBasedFrameDecoder, LengthFieldPrepender} -import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} -import org.jboss.netty.handler.codec.protobuf.{ProtobufDecoder, ProtobufEncoder} +import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender } +import org.jboss.netty.handler.codec.compression.{ ZlibDecoder, ZlibEncoder } +import org.jboss.netty.handler.codec.protobuf.{ ProtobufDecoder, ProtobufEncoder } import org.jboss.netty.handler.timeout.ReadTimeoutHandler -import org.jboss.netty.util.{TimerTask, Timeout, HashedWheelTimer} +import org.jboss.netty.util.{ TimerTask, Timeout, HashedWheelTimer } import org.jboss.netty.handler.ssl.SslHandler -import java.net.{SocketAddress, InetSocketAddress} -import java.util.concurrent.{TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet} -import java.util.concurrent.atomic.AtomicLong +import java.net.{ SocketAddress, InetSocketAddress } +import java.util.concurrent.{ TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet } +import java.util.concurrent.atomic.{ AtomicLong, AtomicBoolean } -import scala.collection.mutable.{HashSet, HashMap} +import scala.collection.mutable.{ HashSet, HashMap } import scala.reflect.BeanProperty -import se.scalablesolutions.akka.actor._ -import se.scalablesolutions.akka.util._ /** * Life-cycle events for RemoteClient. @@ -51,7 +51,7 @@ case class RemoteClientShutdown( /** * Thrown for example when trying to send a message using a RemoteClient that is either not started or shut down. */ -class RemoteClientException private[akka](message: String, @BeanProperty val client: RemoteClient) extends AkkaException(message) +class RemoteClientException private[akka] (message: String, @BeanProperty val client: RemoteClient) extends AkkaException(message) /** * The RemoteClient object manages RemoteClient instances and gives you an API to lookup remote actor handles. @@ -59,17 +59,18 @@ class RemoteClientException private[akka](message: String, @BeanProperty val cli * @author Jonas Bonér */ object RemoteClient extends Logging { - val SECURE_COOKIE: Option[String] = { + + val SECURE_COOKIE: Option[String] = { val cookie = config.getString("akka.remote.secure-cookie", "") if (cookie == "") None else Some(cookie) } - - val READ_TIMEOUT = Duration(config.getInt("akka.remote.client.read-timeout", 1), TIME_UNIT) + + val READ_TIMEOUT = Duration(config.getInt("akka.remote.client.read-timeout", 1), TIME_UNIT) val RECONNECT_DELAY = Duration(config.getInt("akka.remote.client.reconnect-delay", 5), TIME_UNIT) private val remoteClients = new HashMap[String, RemoteClient] - private val remoteActors = new HashMap[Address, HashSet[Uuid]] + private val remoteActors = new HashMap[Address, HashSet[Uuid]] def actorFor(classNameOrServiceId: String, hostname: String, port: Int): ActorRef = actorFor(classNameOrServiceId, classNameOrServiceId, 5000L, hostname, port, None) @@ -92,23 +93,23 @@ object RemoteClient extends Logging { def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int): ActorRef = RemoteActorRef(serviceId, className, hostname, port, timeout, None) - def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int) : T = { + def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int): T = { typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, 5000L, hostname, port, None) } - def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int) : T = { + def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int): T = { typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, None) } - def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = { + def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): T = { typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, Some(loader)) } - def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = { + def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): T = { typedActorFor(intfClass, serviceId, implClassName, timeout, hostname, port, Some(loader)) } - private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]) : T = { + private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]): T = { val actorRef = RemoteActorRef(serviceId, implClassName, hostname, port, timeout, loader, ActorType.TypedActor) TypedActor.createProxyForRemoteActorRef(intfClass, actorRef) } @@ -164,7 +165,7 @@ object RemoteClient extends Logging { * Clean-up all open connections. */ def shutdownAll = synchronized { - remoteClients.foreach({case (addr, client) => client.shutdown}) + remoteClients.foreach({ case (addr, client) => client.shutdown }) remoteClients.clear } @@ -206,34 +207,40 @@ class RemoteClient private[akka] ( private val remoteAddress = new InetSocketAddress(hostname, port) //FIXME rewrite to a wrapper object (minimize volatile access and maximize encapsulation) - @volatile private var bootstrap: ClientBootstrap = _ - @volatile private[remote] var connection: ChannelFuture = _ - @volatile private[remote] var openChannels: DefaultChannelGroup = _ - @volatile private var timer: HashedWheelTimer = _ + @volatile + private var bootstrap: ClientBootstrap = _ + @volatile + private[remote] var connection: ChannelFuture = _ + @volatile + private[remote] var openChannels: DefaultChannelGroup = _ + @volatile + private var timer: HashedWheelTimer = _ private[remote] val runSwitch = new Switch() - + private[remote] val isAuthenticated = new AtomicBoolean(false) + private[remote] def isRunning = runSwitch.isOn private val reconnectionTimeWindow = Duration(config.getInt( "akka.remote.client.reconnection-time-window", 600), TIME_UNIT).toMillis - @volatile private var reconnectionTimeWindowStart = 0L + @volatile + private var reconnectionTimeWindowStart = 0L def connect = runSwitch switchOn { openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName) timer = new HashedWheelTimer - bootstrap = new ClientBootstrap( - new NioClientSocketChannelFactory( - Executors.newCachedThreadPool,Executors.newCachedThreadPool - ) - ) + + bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)) bootstrap.setPipelineFactory(new RemoteClientPipelineFactory(name, futures, supervisors, bootstrap, remoteAddress, timer, this)) bootstrap.setOption("tcpNoDelay", true) bootstrap.setOption("keepAlive", true) - connection = bootstrap.connect(remoteAddress) + log.info("Starting remote client connection to [%s:%s]", hostname, port) + // Wait until the connection attempt succeeds or fails. + connection = bootstrap.connect(remoteAddress) val channel = connection.awaitUninterruptibly.getChannel openChannels.add(channel) + if (!connection.isSuccess) { notifyListeners(RemoteClientError(connection.getCause, this)) log.error(connection.getCause, "Remote client connection to [%s:%s] has failed", hostname, port) @@ -274,31 +281,34 @@ class RemoteClient private[akka] ( actorRef: ActorRef, typedActorInfo: Option[Tuple2[String, String]], actorType: ActorType): Option[CompletableFuture[T]] = { + val cookie = if (isAuthenticated.compareAndSet(false, true)) RemoteClient.SECURE_COOKIE + else None send(createRemoteRequestProtocolBuilder( - actorRef, message, isOneWay, senderOption, typedActorInfo, actorType, RemoteClient.SECURE_COOKIE).build, senderFuture) - } + actorRef, message, isOneWay, senderOption, typedActorInfo, actorType, cookie).build, senderFuture) + } def send[T]( request: RemoteRequestProtocol, - senderFuture: Option[CompletableFuture[T]]): - Option[CompletableFuture[T]] = if (isRunning) { - if (request.getIsOneWay) { - connection.getChannel.write(request) - None - } else { - futures.synchronized { - val futureResult = if (senderFuture.isDefined) senderFuture.get - else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) - futures.put(uuidFrom(request.getUuid.getHigh,request.getUuid.getLow), futureResult) + senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = { + if (isRunning) { + if (request.getIsOneWay) { connection.getChannel.write(request) - Some(futureResult) + None + } else { + futures.synchronized { + val futureResult = if (senderFuture.isDefined) senderFuture.get + else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) + futures.put(uuidFrom(request.getUuid.getHigh, request.getUuid.getLow), futureResult) + connection.getChannel.write(request) + Some(futureResult) + } } + } else { + val exception = new RemoteClientException( + "Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this) + notifyListeners(RemoteClientError(exception, this)) + throw exception } - } else { - val exception = new RemoteClientException( - "Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this) - notifyListeners(RemoteClientError(exception, this)) - throw exception } private[akka] def registerSupervisorForActor(actorRef: ActorRef) = @@ -331,13 +341,13 @@ class RemoteClient private[akka] ( * @author Jonas Bonér */ class RemoteClientPipelineFactory( - name: String, - futures: ConcurrentMap[Uuid, CompletableFuture[_]], - supervisors: ConcurrentMap[Uuid, ActorRef], - bootstrap: ClientBootstrap, - remoteAddress: SocketAddress, - timer: HashedWheelTimer, - client: RemoteClient) extends ChannelPipelineFactory { + name: String, + futures: ConcurrentMap[Uuid, CompletableFuture[_]], + supervisors: ConcurrentMap[Uuid, ActorRef], + bootstrap: ClientBootstrap, + remoteAddress: SocketAddress, + timer: HashedWheelTimer, + client: RemoteClient) extends ChannelPipelineFactory { def getPipeline: ChannelPipeline = { def join(ch: ChannelHandler*) = Array[ChannelHandler](ch: _*) @@ -349,15 +359,15 @@ class RemoteClientPipelineFactory( e } - val ssl = if (RemoteServer.SECURE) join(new SslHandler(engine)) else join() - val timeout = new ReadTimeoutHandler(timer, RemoteClient.READ_TIMEOUT.toMillis.toInt) - val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4) - val lenPrep = new LengthFieldPrepender(4) + val ssl = if (RemoteServer.SECURE) join(new SslHandler(engine)) else join() + val timeout = new ReadTimeoutHandler(timer, RemoteClient.READ_TIMEOUT.toMillis.toInt) + val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4) + val lenPrep = new LengthFieldPrepender(4) val protobufDec = new ProtobufDecoder(RemoteReplyProtocol.getDefaultInstance) val protobufEnc = new ProtobufEncoder - val (enc, dec) = RemoteServer.COMPRESSION_SCHEME match { + val (enc, dec) = RemoteServer.COMPRESSION_SCHEME match { case "zlib" => (join(new ZlibEncoder(RemoteServer.ZLIB_COMPRESSION_LEVEL)), join(new ZlibDecoder)) - case _ => (join(), join()) + case _ => (join(), join()) } val remoteClient = new RemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client) @@ -371,18 +381,18 @@ class RemoteClientPipelineFactory( */ @ChannelHandler.Sharable class RemoteClientHandler( - val name: String, - val futures: ConcurrentMap[Uuid, CompletableFuture[_]], - val supervisors: ConcurrentMap[Uuid, ActorRef], - val bootstrap: ClientBootstrap, - val remoteAddress: SocketAddress, - val timer: HashedWheelTimer, - val client: RemoteClient) - extends SimpleChannelUpstreamHandler with Logging { + val name: String, + val futures: ConcurrentMap[Uuid, CompletableFuture[_]], + val supervisors: ConcurrentMap[Uuid, ActorRef], + val bootstrap: ClientBootstrap, + val remoteAddress: SocketAddress, + val timer: HashedWheelTimer, + val client: RemoteClient) + extends SimpleChannelUpstreamHandler with Logging { override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = { if (event.isInstanceOf[ChannelStateEvent] && - event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) { + event.asInstanceOf[ChannelStateEvent].getState != ChannelState.INTEREST_OPS) { log.debug(event.toString) } super.handleUpstream(ctx, event) @@ -393,7 +403,7 @@ class RemoteClientHandler( val result = event.getMessage if (result.isInstanceOf[RemoteReplyProtocol]) { val reply = result.asInstanceOf[RemoteReplyProtocol] - val replyUuid = uuidFrom(reply.getUuid.getHigh,reply.getUuid.getLow) + val replyUuid = uuidFrom(reply.getUuid.getHigh, reply.getUuid.getLow) log.debug("Remote client received RemoteReplyProtocol[\n%s]", reply.toString) val future = futures.get(replyUuid).asInstanceOf[CompletableFuture[Any]] if (reply.getIsSuccessful) { @@ -401,7 +411,7 @@ class RemoteClientHandler( future.completeWithResult(message) } else { if (reply.hasSupervisorUuid()) { - val supervisorUuid = uuidFrom(reply.getSupervisorUuid.getHigh,reply.getSupervisorUuid.getLow) + val supervisorUuid = uuidFrom(reply.getSupervisorUuid.getHigh, reply.getSupervisorUuid.getLow) if (!supervisors.containsKey(supervisorUuid)) throw new IllegalActorStateException( "Expected a registered supervisor for UUID [" + supervisorUuid + "] but none was found") val supervisedActor = supervisors.get(supervisorUuid) @@ -430,6 +440,7 @@ class RemoteClientHandler( timer.newTimeout(new TimerTask() { def run(timeout: Timeout) = { client.openChannels.remove(event.getChannel) + client.isAuthenticated.set(false) log.debug("Remote client reconnecting to [%s]", remoteAddress) client.connection = bootstrap.connect(remoteAddress) client.connection.awaitUninterruptibly // Wait until the connection attempt succeeds or fails. @@ -475,9 +486,9 @@ class RemoteClientHandler( val exception = reply.getException val classname = exception.getClassname val exceptionClass = if (loader.isDefined) loader.get.loadClass(classname) - else Class.forName(classname) + else Class.forName(classname) exceptionClass - .getConstructor(Array[Class[_]](classOf[String]): _*) - .newInstance(exception.getMessage).asInstanceOf[Throwable] + .getConstructor(Array[Class[_]](classOf[String]): _*) + .newInstance(exception.getMessage).asInstanceOf[Throwable] } } diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index 0d39be263a..bcc1684eed 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -212,13 +212,14 @@ class RemoteServer extends Logging with ListenerManagement { address = Address(_hostname,_port) log.info("Starting remote server at [%s:%s]", hostname, port) RemoteServer.register(hostname, port, this) - val pipelineFactory = new RemoteServerPipelineFactory( - name, openChannels, loader, this) + + val pipelineFactory = new RemoteServerPipelineFactory(name, openChannels, loader, this) bootstrap.setPipelineFactory(pipelineFactory) bootstrap.setOption("child.tcpNoDelay", true) bootstrap.setOption("child.keepAlive", true) bootstrap.setOption("child.reuseAddress", true) bootstrap.setOption("child.connectTimeoutMillis", RemoteServer.CONNECTION_TIMEOUT_MILLIS.toMillis) + openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port))) _isRunning = true Cluster.registerLocalNode(hostname, port) @@ -260,11 +261,8 @@ class RemoteServer extends Logging with ListenerManagement { */ def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized { log.debug("Registering server side remote typed actor [%s] with id [%s]", typedActor.getClass.getName, id) - if (id.startsWith(UUID_PREFIX)) { - registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid()) - } else { - registerTypedActor(id, typedActor, typedActors()) - } + if (id.startsWith(UUID_PREFIX)) registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid) + else registerTypedActor(id, typedActor, typedActors) } /** @@ -279,28 +277,19 @@ class RemoteServer extends Logging with ListenerManagement { */ def register(id: String, actorRef: ActorRef): Unit = synchronized { log.debug("Registering server side remote actor [%s] with id [%s]", actorRef.actorClass.getName, id) - if (id.startsWith(UUID_PREFIX)) { - register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid()) - } else { - register(id, actorRef, actors()) - } + if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid) + else register(id, actorRef, actors) } private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) { - if (_isRunning) { - if (!registry.contains(id)) { - if (!actorRef.isRunning) actorRef.start - registry.put(id, actorRef) - } + if (_isRunning && !registry.contains(id)) { + if (!actorRef.isRunning) actorRef.start + registry.put(id, actorRef) } } private def registerTypedActor[Key](id: Key, typedActor: AnyRef, registry: ConcurrentHashMap[Key, AnyRef]) { - if (_isRunning) { - if (!registry.contains(id)) { - registry.put(id, typedActor) - } - } + if (_isRunning && !registry.contains(id)) registry.put(id, typedActor) } /** @@ -309,8 +298,8 @@ class RemoteServer extends Logging with ListenerManagement { def unregister(actorRef: ActorRef):Unit = synchronized { if (_isRunning) { log.debug("Unregistering server side remote actor [%s] with id [%s:%s]", actorRef.actorClass.getName, actorRef.id, actorRef.uuid) - actors().remove(actorRef.id,actorRef) - actorsByUuid().remove(actorRef.uuid,actorRef) + actors.remove(actorRef.id, actorRef) + actorsByUuid.remove(actorRef.uuid, actorRef) } } @@ -322,12 +311,11 @@ class RemoteServer extends Logging with ListenerManagement { def unregister(id: String):Unit = synchronized { if (_isRunning) { log.info("Unregistering server side remote actor with id [%s]", id) - if (id.startsWith(UUID_PREFIX)) { - actorsByUuid().remove(id.substring(UUID_PREFIX.length)) - } else { - val actorRef = actors() get id - actorsByUuid().remove(actorRef.uuid,actorRef) - actors().remove(id,actorRef) + if (id.startsWith(UUID_PREFIX)) actorsByUuid.remove(id.substring(UUID_PREFIX.length)) + else { + val actorRef = actors get id + actorsByUuid.remove(actorRef.uuid, actorRef) + actors.remove(id,actorRef) } } } @@ -340,11 +328,8 @@ class RemoteServer extends Logging with ListenerManagement { def unregisterTypedActor(id: String):Unit = synchronized { if (_isRunning) { log.info("Unregistering server side remote typed actor with id [%s]", id) - if (id.startsWith(UUID_PREFIX)) { - typedActorsByUuid().remove(id.substring(UUID_PREFIX.length)) - } else { - typedActors().remove(id) - } + if (id.startsWith(UUID_PREFIX)) typedActorsByUuid.remove(id.substring(UUID_PREFIX.length)) + else typedActors.remove(id) } } @@ -352,10 +337,10 @@ class RemoteServer extends Logging with ListenerManagement { protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message) - private[akka] def actors() = ActorRegistry.actors(address) - private[akka] def actorsByUuid() = ActorRegistry.actorsByUuid(address) - private[akka] def typedActors() = ActorRegistry.typedActors(address) - private[akka] def typedActorsByUuid() = ActorRegistry.typedActorsByUuid(address) + private[akka] def actors = ActorRegistry.actors(address) + private[akka] def actorsByUuid = ActorRegistry.actorsByUuid(address) + private[akka] def typedActors = ActorRegistry.typedActors(address) + private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address) } object RemoteServerSslContext { @@ -419,6 +404,7 @@ class RemoteServerHandler( val applicationLoader: Option[ClassLoader], val server: RemoteServer) extends SimpleChannelUpstreamHandler with Logging { import RemoteServer._ + val AW_PROXY_PREFIX = "$$ProxiedByAW".intern val CHANNEL_INIT = "channel-init".intern @@ -444,10 +430,8 @@ class RemoteServerHandler( } else future.getChannel.close } }) - } else { - server.notifyListeners(RemoteServerClientConnected(server)) - } - if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) + } else server.notifyListeners(RemoteServerClientConnected(server)) + if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication } override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { @@ -467,7 +451,7 @@ class RemoteServerHandler( if (message eq null) throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event) if (message.isInstanceOf[RemoteRequestProtocol]) { val requestProtocol = message.asInstanceOf[RemoteRequestProtocol] - authenticateRemoteClient(requestProtocol, ctx) + if (RemoteServer.REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx) handleRemoteRequestProtocol(requestProtocol, event.getChannel) } } @@ -521,8 +505,7 @@ class RemoteServerHandler( try { channel.write(replyBuilder.build) } catch { - case e: Throwable => - server.notifyListeners(RemoteServerError(e, server)) + case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) } } @@ -530,8 +513,7 @@ class RemoteServerHandler( try { channel.write(createErrorReplyMessage(exception, request, true)) } catch { - case e: Throwable => - server.notifyListeners(RemoteServerError(e, server)) + case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) } } } @@ -543,8 +525,8 @@ class RemoteServerHandler( val actorInfo = request.getActorInfo val typedActorInfo = actorInfo.getTypedActorInfo log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface) - val typedActor = createTypedActor(actorInfo) + val typedActor = createTypedActor(actorInfo) val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList val argClasses = args.map(_.getClass) @@ -566,49 +548,39 @@ class RemoteServerHandler( case e: InvocationTargetException => channel.write(createErrorReplyMessage(e.getCause, request, false)) server.notifyListeners(RemoteServerError(e, server)) - case e: Throwable => + case e: Throwable => channel.write(createErrorReplyMessage(e, request, false)) server.notifyListeners(RemoteServerError(e, server)) } } private def findActorById(id: String) : ActorRef = { - server.actors().get(id) + server.actors.get(id) } private def findActorByUuid(uuid: String) : ActorRef = { - server.actorsByUuid().get(uuid) + server.actorsByUuid.get(uuid) } private def findTypedActorById(id: String) : AnyRef = { - server.typedActors().get(id) + server.typedActors.get(id) } private def findTypedActorByUuid(uuid: String) : AnyRef = { - server.typedActorsByUuid().get(uuid) + server.typedActorsByUuid.get(uuid) } private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = { - var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) { - findActorByUuid(id.substring(UUID_PREFIX.length)) - } else { - findActorById(id) - } - if (actorRefOrNull eq null) { - actorRefOrNull = findActorByUuid(uuid) - } + var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findActorByUuid(id.substring(UUID_PREFIX.length)) + else findActorById(id) + if (actorRefOrNull eq null) actorRefOrNull = findActorByUuid(uuid) actorRefOrNull } private def findTypedActorByIdOrUuid(id: String, uuid: String) : AnyRef = { - var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) { - findTypedActorByUuid(id.substring(UUID_PREFIX.length)) - } else { - findTypedActorById(id) - } - if (actorRefOrNull eq null) { - actorRefOrNull = findTypedActorByUuid(uuid) - } + var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findTypedActorByUuid(id.substring(UUID_PREFIX.length)) + else findTypedActorById(id) + if (actorRefOrNull eq null) actorRefOrNull = findTypedActorByUuid(uuid) actorRefOrNull } @@ -694,18 +666,17 @@ class RemoteServerHandler( } private def authenticateRemoteClient(request: RemoteRequestProtocol, ctx: ChannelHandlerContext) = { - if (RemoteServer.REQUIRE_COOKIE) { - val attachment = ctx.getAttachment - if ((attachment ne null) && - attachment.isInstanceOf[String] && - attachment.asInstanceOf[String] == CHANNEL_INIT) { - val clientAddress = ctx.getChannel.getRemoteAddress.toString - if (!request.hasCookie) throw new SecurityException( - "The remote client [" + clientAddress + "] does not have a secure cookie.") - if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException( - "The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie") - log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress) - } + val attachment = ctx.getAttachment + if ((attachment ne null) && + attachment.isInstanceOf[String] && + attachment.asInstanceOf[String] == CHANNEL_INIT) { // is first time around, channel initialization + ctx.setAttachment(null) + val clientAddress = ctx.getChannel.getRemoteAddress.toString + if (!request.hasCookie) throw new SecurityException( + "The remote client [" + clientAddress + "] does not have a secure cookie.") + if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException( + "The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie") + log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress) } } } diff --git a/akka-remote/src/main/scala/serialization/SerializationProtocol.scala b/akka-remote/src/main/scala/serialization/SerializationProtocol.scala index 8383607f8c..9cbf2446aa 100644 --- a/akka-remote/src/main/scala/serialization/SerializationProtocol.scala +++ b/akka-remote/src/main/scala/serialization/SerializationProtocol.scala @@ -249,11 +249,11 @@ object RemoteActorSerialization { ActorRegistry.registerActorByUuid(homeAddress, uuid.toString, ar) RemoteActorRefProtocol.newBuilder - .setClassOrServiceName(uuid.toString) - .setActorClassname(actorClassName) - .setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build) - .setTimeout(timeout) - .build + .setClassOrServiceName(uuid.toString) + .setActorClassname(actorClassName) + .setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build) + .setTimeout(timeout) + .build } def createRemoteRequestProtocolBuilder( @@ -263,8 +263,7 @@ object RemoteActorSerialization { senderOption: Option[ActorRef], typedActorInfo: Option[Tuple2[String, String]], actorType: ActorType, - secureCookie: Option[String]): - RemoteRequestProtocol.Builder = { + secureCookie: Option[String]): RemoteRequestProtocol.Builder = { import actorRef._ val actorInfoBuilder = ActorInfoProtocol.newBuilder @@ -273,13 +272,12 @@ object RemoteActorSerialization { .setTarget(actorClassName) .setTimeout(timeout) - typedActorInfo.foreach { - typedActor => - actorInfoBuilder.setTypedActorInfo( - TypedActorInfoProtocol.newBuilder - .setInterface(typedActor._1) - .setMethod(typedActor._2) - .build) + typedActorInfo.foreach { typedActor => + actorInfoBuilder.setTypedActorInfo( + TypedActorInfoProtocol.newBuilder + .setInterface(typedActor._1) + .setMethod(typedActor._2) + .build) } actorType match { @@ -310,8 +308,6 @@ object RemoteActorSerialization { } requestBuilder } - - } @@ -408,5 +404,4 @@ object RemoteTypedActorSerialization { .setInterfaceName(init.interfaceClass.getName) .build } - } diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala index b0cbc5ec08..070682c794 100644 --- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala @@ -201,18 +201,18 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { def shouldRegisterAndUnregister { val actor1 = actorOf[RemoteActorSpecActorUnidirectional] server.register("my-service-1", actor1) - assert(server.actors().get("my-service-1") ne null, "actor registered") + assert(server.actors.get("my-service-1") ne null, "actor registered") server.unregister("my-service-1") - assert(server.actors().get("my-service-1") eq null, "actor unregistered") + assert(server.actors.get("my-service-1") eq null, "actor unregistered") } @Test def shouldRegisterAndUnregisterByUuid { val actor1 = actorOf[RemoteActorSpecActorUnidirectional] server.register("uuid:" + actor1.uuid, actor1) - assert(server.actorsByUuid().get(actor1.uuid.toString) ne null, "actor registered") + assert(server.actorsByUuid.get(actor1.uuid.toString) ne null, "actor registered") server.unregister("uuid:" + actor1.uuid) - assert(server.actorsByUuid().get(actor1.uuid) eq null, "actor unregistered") + assert(server.actorsByUuid.get(actor1.uuid) eq null, "actor unregistered") } } diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala index 71ece9792e..cdb8cf5cf2 100644 --- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala @@ -103,9 +103,9 @@ class ServerInitiatedRemoteTypedActorSpec extends it("should register and unregister typed actors") { val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000) server.registerTypedActor("my-test-service", typedActor) - assert(server.typedActors().get("my-test-service") ne null, "typed actor registered") + assert(server.typedActors.get("my-test-service") ne null, "typed actor registered") server.unregisterTypedActor("my-test-service") - assert(server.typedActors().get("my-test-service") eq null, "typed actor unregistered") + assert(server.typedActors.get("my-test-service") eq null, "typed actor unregistered") } it("should register and unregister typed actors by uuid") { @@ -113,9 +113,9 @@ class ServerInitiatedRemoteTypedActorSpec extends val init = AspectInitRegistry.initFor(typedActor) val uuid = "uuid:" + init.actorRef.uuid server.registerTypedActor(uuid, typedActor) - assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) ne null, "typed actor registered") + assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) ne null, "typed actor registered") server.unregisterTypedActor(uuid) - assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) eq null, "typed actor unregistered") + assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) eq null, "typed actor unregistered") } it("should find typed actors by uuid") { @@ -123,7 +123,7 @@ class ServerInitiatedRemoteTypedActorSpec extends val init = AspectInitRegistry.initFor(typedActor) val uuid = "uuid:" + init.actorRef.uuid server.registerTypedActor(uuid, typedActor) - assert(server.typedActorsByUuid().get(init.actorRef.uuid.toString) ne null, "typed actor registered") + assert(server.typedActorsByUuid.get(init.actorRef.uuid.toString) ne null, "typed actor registered") val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], uuid, HOSTNAME, PORT) expect("oneway") { diff --git a/akka-sbt-plugin/src/main/scala/AkkaProject.scala b/akka-sbt-plugin/src/main/scala/AkkaProject.scala index 2bde073df8..82ccbe401a 100644 --- a/akka-sbt-plugin/src/main/scala/AkkaProject.scala +++ b/akka-sbt-plugin/src/main/scala/AkkaProject.scala @@ -4,7 +4,7 @@ object AkkaRepositories { val AkkaRepo = MavenRepository("Akka Repository", "http://scalablesolutions.se/akka/repository") val CodehausRepo = MavenRepository("Codehaus Repo", "http://repository.codehaus.org") val GuiceyFruitRepo = MavenRepository("GuiceyFruit Repo", "http://guiceyfruit.googlecode.com/svn/repo/releases/") - val JBossRepo = MavenRepository("JBoss Repo", "https://repository.jboss.org/nexus/content/groups/public/") + val JBossRepo = MavenRepository("JBoss Repo", "http://repository.jboss.org/nexus/content/groups/public/") val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2") val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases") val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo") diff --git a/project/build/AkkaProject.scala b/project/build/AkkaProject.scala index 5b93e30044..6add9542ee 100644 --- a/project/build/AkkaProject.scala +++ b/project/build/AkkaProject.scala @@ -72,7 +72,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) { lazy val EmbeddedRepo = MavenRepository("Embedded Repo", (info.projectPath / "embedded-repo").asURL.toString) lazy val FusesourceSnapshotRepo = MavenRepository("Fusesource Snapshots", "http://repo.fusesource.com/nexus/content/repositories/snapshots") lazy val GuiceyFruitRepo = MavenRepository("GuiceyFruit Repo", "http://guiceyfruit.googlecode.com/svn/repo/releases/") - lazy val JBossRepo = MavenRepository("JBoss Repo", "https://repository.jboss.org/nexus/content/groups/public/") + lazy val JBossRepo = MavenRepository("JBoss Repo", "http://repository.jboss.org/nexus/content/groups/public/") lazy val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2") lazy val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases") lazy val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo")