diff --git a/akka-actor/src/main/scala/actor/ActorRegistry.scala b/akka-actor/src/main/scala/actor/ActorRegistry.scala index 0b946df99f..f8b46fd3c4 100644 --- a/akka-actor/src/main/scala/actor/ActorRegistry.scala +++ b/akka-actor/src/main/scala/actor/ActorRegistry.scala @@ -304,6 +304,7 @@ object ActorRegistry extends ListenerManagement { private[akka] def actorsFactories(address: Address) = actorsFor(address).actorsFactories private[akka] def typedActors(address: Address) = actorsFor(address).typedActors private[akka] def typedActorsByUuid(address: Address) = actorsFor(address).typedActorsByUuid + private[akka] def typedActorsFactories(address: Address) = actorsFor(address).typedActorsFactories private[akka] class RemoteActorSet { private[ActorRegistry] val actors = new ConcurrentHashMap[String, ActorRef] @@ -311,6 +312,7 @@ object ActorRegistry extends ListenerManagement { private[ActorRegistry] val actorsFactories = new ConcurrentHashMap[String, () => ActorRef] private[ActorRegistry] val typedActors = new ConcurrentHashMap[String, AnyRef] private[ActorRegistry] val typedActorsByUuid = new ConcurrentHashMap[String, AnyRef] + private[ActorRegistry] val typedActorsFactories = new ConcurrentHashMap[String, () => AnyRef] } } diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index 5ba58567ef..9afef140b1 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -284,6 +284,21 @@ class RemoteServer extends Logging with ListenerManagement { else registerTypedActor(id, typedActor, typedActors) } + /** + * Register typed actor by interface name. + */ + def registerTypedPerSessionActor(intfClass: Class[_], factory: => AnyRef) : Unit = registerTypedActor(intfClass.getName, factory) + + /** + * Register remote typed actor by a specific id. + * @param id custom actor id + * @param typedActor typed actor to register + */ + def registerTypedPerSessionActor(id: String, factory: => AnyRef): Unit = synchronized { + log.debug("Registering server side typed remote session actor with id [%s]", id) + registerTypedPerSessionActor(id, () => factory, typedActorsFactories) + } + /** * Register Remote Actor by the Actor's 'id' field. It starts the Actor if it is not started already. */ @@ -307,8 +322,7 @@ class RemoteServer extends Logging with ListenerManagement { */ def registerPerSession(id: String, factory: => ActorRef): Unit = synchronized { log.debug("Registering server side remote session actor with id [%s]", id) - if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), factory, actorsByUuid) - else registerPerSession(id, () => factory, actorsFactories) + registerPerSession(id, () => factory, actorsFactories) } private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) { @@ -328,6 +342,12 @@ class RemoteServer extends Logging with ListenerManagement { if (_isRunning && !registry.contains(id)) registry.put(id, typedActor) } + private def registerTypedPerSessionActor[Key](id: Key, factory: () => AnyRef, registry: ConcurrentHashMap[Key,() => AnyRef]) { + if (_isRunning && !registry.contains(id)) { + registry.put(id, factory) + } + } + /** * Unregister Remote Actor that is registered using its 'id' field (not custom ID). */ @@ -381,6 +401,17 @@ class RemoteServer extends Logging with ListenerManagement { } } + /** + * Unregister Remote Typed Actor by specific 'id'. + *
+ * NOTE: You need to call this method if you have registered an actor by a custom ID. + */ + def unregisterTypedPerSessionActor(id: String):Unit = synchronized { + if (_isRunning) { + typedActorsFactories.remove(id) + } + } + protected override def manageLifeCycleOfListeners = false protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message) @@ -391,6 +422,7 @@ class RemoteServer extends Logging with ListenerManagement { private[akka] def actorsFactories = ActorRegistry.actorsFactories(address) private[akka] def typedActors = ActorRegistry.typedActors(address) private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address) + private[akka] def typedActorsFactories = ActorRegistry.typedActorsFactories(address) } object RemoteServerSslContext { @@ -459,6 +491,7 @@ class RemoteServerHandler( val CHANNEL_INIT = "channel-init".intern val sessionActors = new ChannelLocal[Map[String, ActorRef]](); + val typedSessionActors = new ChannelLocal[Map[String, AnyRef]](); applicationLoader.foreach(MessageSerializer.setClassLoader(_)) @@ -471,6 +504,7 @@ class RemoteServerHandler( override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) sessionActors.set(event.getChannel(), Map[String, ActorRef]()); + typedSessionActors.set(event.getChannel(), Map[String, AnyRef]()); log.debug("Remote client [%s] connected to [%s]", clientAddress, server.name) if (RemoteServer.SECURE) { val sslHandler: SslHandler = ctx.getPipeline.get(classOf[SslHandler]) @@ -495,6 +529,11 @@ class RemoteServerHandler( actorRef.stop() } sessionActors.remove(event.getChannel()); + for ((id, actorRef) <- typedSessionActors.get(event.getChannel())) { + TypedActor.stop(actorRef) + } + typedSessionActors.remove(event.getChannel()); + server.notifyListeners(RemoteServerClientDisconnected(server, clientAddress)) } @@ -623,7 +662,7 @@ class RemoteServerHandler( val typedActorInfo = actorInfo.getTypedActorInfo log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface) - val typedActor = createTypedActor(actorInfo) + val typedActor = createTypedActor(actorInfo, channel) val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList val argClasses = args.map(_.getClass) @@ -673,6 +712,7 @@ class RemoteServerHandler( private def findActorFactory(id: String) : () => ActorRef = { server.actorsFactories.get(id) } + private def findSessionActor(id: String, channel: Channel) : ActorRef = { sessionActors.get(channel).getOrElse(id, null) } @@ -681,6 +721,14 @@ class RemoteServerHandler( server.typedActors.get(id) } + private def findTypedActorFactory(id: String) : () => AnyRef = { + server.typedActorsFactories.get(id) + } + + private def findTypedSessionActor(id: String, channel: Channel) : AnyRef = { + typedSessionActors.get(channel).getOrElse(id, null) + } + private def findTypedActorByUuid(uuid: String) : AnyRef = { server.typedActorsByUuid.get(uuid) } @@ -731,7 +779,6 @@ class RemoteServerHandler( val actorRef = actorFactoryOrNull(); actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) sessionActors.get(channel).put(id, actorRef); - server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid return actorRef } @@ -758,39 +805,55 @@ class RemoteServerHandler( } } - private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = { + private def createTypedActor(actorInfo: ActorInfoProtocol, channel: Channel): AnyRef = { val uuid = actorInfo.getUuid val id = actorInfo.getId val typedActorOrNull = findTypedActorByIdOrUuid(id, uuidFrom(uuid.getHigh,uuid.getLow).toString) + if (typedActorOrNull ne null) + return typedActorOrNull; - if (typedActorOrNull eq null) { - val typedActorInfo = actorInfo.getTypedActorInfo - val interfaceClassname = typedActorInfo.getInterface - val targetClassname = actorInfo.getTarget + // the actor has not been registered globally. See if we have it in the session - try { - if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( - "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + val sessionActorRefOrNull = findTypedSessionActor(id, channel); + if (sessionActorRefOrNull ne null) + return sessionActorRefOrNull - log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) + // we dont have it in the session either, see if we have a factory for it + val actorFactoryOrNull = findTypedActorFactory(id) + if (actorFactoryOrNull ne null) { + val newInstance = actorFactoryOrNull(); + typedSessionActors.get(channel).put(id, newInstance); + return newInstance + } - val (interfaceClass, targetClass) = - if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), - applicationLoader.get.loadClass(targetClassname)) - else (Class.forName(interfaceClassname), Class.forName(targetClassname)) + // None of the above, so treat it as a client managed remote actor - val newInstance = TypedActor.newInstance( - interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] - server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid - newInstance - } catch { - case e => - log.error(e, "Could not create remote typed actor instance") - server.notifyListeners(RemoteServerError(e, server)) - throw e - } - } else typedActorOrNull + val typedActorInfo = actorInfo.getTypedActorInfo + val interfaceClassname = typedActorInfo.getInterface + val targetClassname = actorInfo.getTarget + + try { + if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( + "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + + log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) + + val (interfaceClass, targetClass) = + if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), + applicationLoader.get.loadClass(targetClassname)) + else (Class.forName(interfaceClassname), Class.forName(targetClassname)) + + val newInstance = TypedActor.newInstance( + interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] + server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid + newInstance + } catch { + case e => + log.error(e, "Could not create remote typed actor instance") + server.notifyListeners(RemoteServerError(e, server)) + throw e + } } private def createErrorReplyMessage(exception: Throwable, request: RemoteMessageProtocol, actorType: AkkaActorType): RemoteMessageProtocol = { diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java new file mode 100644 index 0000000000..8a6c2e6373 --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java @@ -0,0 +1,8 @@ +package akka.actor; + +public interface RemoteTypedSessionActor { + + public void login(String user); + public String getUser(); + public void doSomethingFunny() throws Exception; +} diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java new file mode 100644 index 0000000000..b4140f74ed --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java @@ -0,0 +1,49 @@ +package akka.actor.remote; + +import akka.actor.*; + +import java.util.Set; +import java.util.HashSet; + +import java.util.concurrent.CountDownLatch; + +public class RemoteTypedSessionActorImpl extends TypedActor implements RemoteTypedSessionActor { + + + private static Set