diff --git a/akka-actor/src/main/scala/akka/actor/ActorRegistry.scala b/akka-actor/src/main/scala/akka/actor/ActorRegistry.scala index aebfe70d85..3476304a67 100644 --- a/akka-actor/src/main/scala/akka/actor/ActorRegistry.scala +++ b/akka-actor/src/main/scala/akka/actor/ActorRegistry.scala @@ -286,14 +286,18 @@ object ActorRegistry extends ListenerManagement { private[akka] def actors(address: Address) = actorsFor(address).actors private[akka] def actorsByUuid(address: Address) = actorsFor(address).actorsByUuid + private[akka] def actorsFactories(address: Address) = actorsFor(address).actorsFactories private[akka] def typedActors(address: Address) = actorsFor(address).typedActors private[akka] def typedActorsByUuid(address: Address) = actorsFor(address).typedActorsByUuid + private[akka] def typedActorsFactories(address: Address) = actorsFor(address).typedActorsFactories private[akka] class RemoteActorSet { private[ActorRegistry] val actors = new ConcurrentHashMap[String, ActorRef] private[ActorRegistry] val actorsByUuid = new ConcurrentHashMap[String, ActorRef] + private[ActorRegistry] val actorsFactories = new ConcurrentHashMap[String, () => ActorRef] private[ActorRegistry] val typedActors = new ConcurrentHashMap[String, AnyRef] private[ActorRegistry] val typedActorsByUuid = new ConcurrentHashMap[String, AnyRef] + private[ActorRegistry] val typedActorsFactories = new ConcurrentHashMap[String, () => AnyRef] } } diff --git a/akka-remote/src/main/scala/akka/remote/RemoteClient.scala b/akka-remote/src/main/scala/akka/remote/RemoteClient.scala index e80c7936c2..77052c8547 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteClient.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteClient.scala @@ -417,6 +417,7 @@ class RemoteClientHandler( if (result.isInstanceOf[RemoteMessageProtocol]) { val reply = result.asInstanceOf[RemoteMessageProtocol] val replyUuid = uuidFrom(reply.getUuid.getHigh, reply.getUuid.getLow) +// log.debug("Remote client received RemoteMessageProtocol[\n%s]".format(request.toString)) log.debug("Remote client received RemoteMessageProtocol[\n%s]", reply.toString) val future = futures.get(replyUuid).asInstanceOf[CompletableFuture[Any]] if (reply.hasMessage) { diff --git a/akka-remote/src/main/scala/akka/remote/RemoteServer.scala b/akka-remote/src/main/scala/akka/remote/RemoteServer.scala index 33b71f1425..a7c23a8d06 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteServer.scala @@ -286,6 +286,21 @@ class RemoteServer extends Logging with ListenerManagement { else registerTypedActor(id, typedActor, typedActors) } + /** + * Register typed actor by interface name. + */ + def registerTypedPerSessionActor(intfClass: Class[_], factory: => AnyRef) : Unit = registerTypedActor(intfClass.getName, factory) + + /** + * Register remote typed actor by a specific id. + * @param id custom actor id + * @param typedActor typed actor to register + */ + def registerTypedPerSessionActor(id: String, factory: => AnyRef): Unit = synchronized { + log.debug("Registering server side typed remote session actor with id [%s]", id) + registerTypedPerSessionActor(id, () => factory, typedActorsFactories) + } + /** * Register Remote Actor by the Actor's 'id' field. It starts the Actor if it is not started already. */ @@ -302,15 +317,36 @@ class RemoteServer extends Logging with ListenerManagement { else register(id, actorRef, actors) } + /** + * Register Remote Session Actor by a specific 'id' passed as argument. + *
+ * NOTE: If you use this method to register your remote actor then you must unregister the actor by this ID yourself. + */ + def registerPerSession(id: String, factory: => ActorRef): Unit = synchronized { + log.debug("Registering server side remote session actor with id [%s]", id) + registerPerSession(id, () => factory, actorsFactories) + } + private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) { - if (_isRunning && !registry.contains(id)) { + if (_isRunning) { + registry.put(id, actorRef) //TODO change to putIfAbsent if (!actorRef.isRunning) actorRef.start - registry.put(id, actorRef) } } + private def registerPerSession[Key](id: Key, factory: () => ActorRef, registry: ConcurrentHashMap[Key,() => ActorRef]) { + if (_isRunning) + registry.put(id, factory) //TODO change to putIfAbsent + } + private def registerTypedActor[Key](id: Key, typedActor: AnyRef, registry: ConcurrentHashMap[Key, AnyRef]) { - if (_isRunning && !registry.contains(id)) registry.put(id, typedActor) + if (_isRunning) + registry.put(id, typedActor) //TODO change to putIfAbsent + } + + private def registerTypedPerSessionActor[Key](id: Key, factory: () => AnyRef, registry: ConcurrentHashMap[Key,() => AnyRef]) { + if (_isRunning) + registry.put(id, factory) //TODO change to putIfAbsent } /** @@ -341,6 +377,18 @@ class RemoteServer extends Logging with ListenerManagement { } } + /** + * Unregister Remote Actor by specific 'id'. + * + * NOTE: You need to call this method if you have registered an actor by a custom ID. + */ + def unregisterPerSession(id: String):Unit = { + if (_isRunning) { + log.info("Unregistering server side remote session actor with id [%s]", id) + actorsFactories.remove(id) + } + } + /** * Unregister Remote Typed Actor by specific 'id'. * @@ -354,14 +402,28 @@ class RemoteServer extends Logging with ListenerManagement { } } + /** + * Unregister Remote Typed Actor by specific 'id'. + * + * NOTE: You need to call this method if you have registered an actor by a custom ID. + */ + def unregisterTypedPerSessionActor(id: String):Unit = { + if (_isRunning) { + typedActorsFactories.remove(id) + } + } + protected override def manageLifeCycleOfListeners = false protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message) + private[akka] def actors = ActorRegistry.actors(address) private[akka] def actorsByUuid = ActorRegistry.actorsByUuid(address) + private[akka] def actorsFactories = ActorRegistry.actorsFactories(address) private[akka] def typedActors = ActorRegistry.typedActors(address) private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address) + private[akka] def typedActorsFactories = ActorRegistry.typedActorsFactories(address) } object RemoteServerSslContext { @@ -429,6 +491,9 @@ class RemoteServerHandler( val AW_PROXY_PREFIX = "$$ProxiedByAW".intern val CHANNEL_INIT = "channel-init".intern + val sessionActors = new ChannelLocal[ConcurrentHashMap[String, ActorRef]]() + val typedSessionActors = new ChannelLocal[ConcurrentHashMap[String, AnyRef]]() + applicationLoader.foreach(MessageSerializer.setClassLoader(_)) /** @@ -439,6 +504,8 @@ class RemoteServerHandler( override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) + sessionActors.set(event.getChannel(), new ConcurrentHashMap[String, ActorRef]()) + typedSessionActors.set(event.getChannel(), new ConcurrentHashMap[String, AnyRef]()) log.debug("Remote client [%s] connected to [%s]", clientAddress, server.name) if (RemoteServer.SECURE) { val sslHandler: SslHandler = ctx.getPipeline.get(classOf[SslHandler]) @@ -458,6 +525,22 @@ class RemoteServerHandler( override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) log.debug("Remote client [%s] disconnected from [%s]", clientAddress, server.name) + // stop all session actors + val channelActors = sessionActors.remove(event.getChannel) + if (channelActors ne null) { + val channelActorsIterator = channelActors.elements + while (channelActorsIterator.hasMoreElements) { + channelActorsIterator.nextElement.stop + } + } + + val channelTypedActors = typedSessionActors.remove(event.getChannel) + if (channelTypedActors ne null) { + val channelTypedActorsIterator = channelTypedActors.elements + while (channelTypedActorsIterator.hasMoreElements) { + TypedActor.stop(channelTypedActorsIterator.nextElement) + } + } server.notifyListeners(RemoteServerClientDisconnected(server, clientAddress)) } @@ -512,7 +595,7 @@ class RemoteServerHandler( val actorRef = try { - createActor(actorInfo).start + createActor(actorInfo, channel).start } catch { case e: SecurityException => channel.write(createErrorReplyMessage(e, request, AkkaActorType.ScalaActor)) @@ -544,7 +627,7 @@ class RemoteServerHandler( val exception = f.exception if (exception.isDefined) { - log.debug("Returning exception from actor invocation [%s]".format(exception.get)) + log.debug("Returning exception from actor invocation [%s]",exception.get) try { channel.write(createErrorReplyMessage(exception.get, request, AkkaActorType.ScalaActor)) } catch { @@ -586,7 +669,7 @@ class RemoteServerHandler( val typedActorInfo = actorInfo.getTypedActorInfo log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface) - val typedActor = createTypedActor(actorInfo) + val typedActor = createTypedActor(actorInfo, channel) val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList val argClasses = args.map(_.getClass) @@ -594,26 +677,35 @@ class RemoteServerHandler( val messageReceiver = typedActor.getClass.getDeclaredMethod(typedActorInfo.getMethod, argClasses: _*) if (request.getOneWay) messageReceiver.invoke(typedActor, args: _*) else { - val result = messageReceiver.invoke(typedActor, args: _*) match { - case f: Future[_] => f.await.result.get //TODO replace this with a Listener on the Future to avoid blocking - case other => other + //Sends the response + def sendResponse(result: Either[Any,Throwable]): Unit = try { + val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( + None, + Right(request.getUuid), + actorInfo.getId, + actorInfo.getTarget, + actorInfo.getTimeout, + result, + true, + None, + None, + AkkaActorType.TypedActor, + None) + if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) + channel.write(messageBuilder.build) + log.debug("Returning result from remote typed actor invocation [%s]", result) + } catch { + case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) } - log.debug("Returning result from remote typed actor invocation [%s]", result) - val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( - None, - Right(request.getUuid), - actorInfo.getId, - actorInfo.getTarget, - actorInfo.getTimeout, - Left(result), - true, - None, - None, - AkkaActorType.TypedActor, - None) - if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) - channel.write(messageBuilder.build) + messageReceiver.invoke(typedActor, args: _*) match { + case f: Future[_] => //If it's a future, we can lift on that to defer the send to when the future is completed + f.onComplete( future => { + val result: Either[Any,Throwable] = if (future.exception.isDefined) Right(future.exception.get) else Left(future.result.get) + sendResponse(result) + }) + case other => sendResponse(Left(other)) + } } } catch { case e: InvocationTargetException => @@ -633,10 +725,26 @@ class RemoteServerHandler( server.actorsByUuid.get(uuid) } + private def findActorFactory(id: String) : () => ActorRef = { + server.actorsFactories.get(id) + } + + private def findSessionActor(id: String, channel: Channel) : ActorRef = { + sessionActors.get(channel).get(id) + } + private def findTypedActorById(id: String) : AnyRef = { server.typedActors.get(id) } + private def findTypedActorFactory(id: String) : () => AnyRef = { + server.typedActorsFactories.get(id) + } + + private def findTypedSessionActor(id: String, channel: Channel) : AnyRef = { + typedSessionActors.get(channel).get(id) + } + private def findTypedActorByUuid(uuid: String) : AnyRef = { server.typedActorsByUuid.get(uuid) } @@ -655,6 +763,60 @@ class RemoteServerHandler( actorRefOrNull } + /** + * gets the actor from the session, or creates one if there is a factory for it + */ + private def createSessionActor(actorInfo: ActorInfoProtocol, channel: Channel): ActorRef = { + val uuid = actorInfo.getUuid + val id = actorInfo.getId + val sessionActorRefOrNull = findSessionActor(id, channel) + if (sessionActorRefOrNull ne null) + sessionActorRefOrNull + else + { + // we dont have it in the session either, see if we have a factory for it + val actorFactoryOrNull = findActorFactory(id) + if (actorFactoryOrNull ne null) { + val actorRef = actorFactoryOrNull() + actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) + sessionActors.get(channel).put(id, actorRef) + actorRef + } + else + null + } + } + + + private def createClientManagedActor(actorInfo: ActorInfoProtocol): ActorRef = { + val uuid = actorInfo.getUuid + val id = actorInfo.getId + val timeout = actorInfo.getTimeout + val name = actorInfo.getTarget + + try { + if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( + "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + + log.info("Creating a new remote actor [%s:%s]", name, uuid) + val clazz = if (applicationLoader.isDefined) applicationLoader.get.loadClass(name) + else Class.forName(name) + val actorRef = Actor.actorOf(clazz.asInstanceOf[Class[_ <: Actor]]) + actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) + actorRef.id = id + actorRef.timeout = timeout + actorRef.remoteAddress = None + server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid + actorRef + } catch { + case e => + log.error(e, "Could not create remote actor instance") + server.notifyListeners(RemoteServerError(e, server)) + throw e + } + + } + /** * Creates a new instance of the actor with name, uuid and timeout specified as arguments. * @@ -662,72 +824,91 @@ class RemoteServerHandler( * * Does not start the actor. */ - private def createActor(actorInfo: ActorInfoProtocol): ActorRef = { + private def createActor(actorInfo: ActorInfoProtocol, channel: Channel): ActorRef = { val uuid = actorInfo.getUuid val id = actorInfo.getId - val name = actorInfo.getTarget - val timeout = actorInfo.getTimeout - val actorRefOrNull = findActorByIdOrUuid(id, uuidFrom(uuid.getHigh,uuid.getLow).toString) - if (actorRefOrNull eq null) { - try { - if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( - "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") - - log.info("Creating a new remote actor [%s:%s]", name, uuid) - val clazz = if (applicationLoader.isDefined) applicationLoader.get.loadClass(name) - else Class.forName(name) - val actorRef = Actor.actorOf(clazz.asInstanceOf[Class[_ <: Actor]]) - actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) - actorRef.id = id - actorRef.timeout = timeout - actorRef.remoteAddress = None - server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid - actorRef - } catch { - case e => - log.error(e, "Could not create remote actor instance") - server.notifyListeners(RemoteServerError(e, server)) - throw e - } - } else actorRefOrNull + if (actorRefOrNull ne null) + actorRefOrNull + else + { + // the actor has not been registered globally. See if we have it in the session + val sessionActorRefOrNull = createSessionActor(actorInfo, channel) + if (sessionActorRefOrNull ne null) + sessionActorRefOrNull + else // maybe it is a client managed actor + createClientManagedActor(actorInfo) + } } - private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = { + /** + * gets the actor from the session, or creates one if there is a factory for it + */ + private def createTypedSessionActor(actorInfo: ActorInfoProtocol, channel: Channel):AnyRef ={ + val id = actorInfo.getId + val sessionActorRefOrNull = findTypedSessionActor(id, channel) + if (sessionActorRefOrNull ne null) + sessionActorRefOrNull + else { + val actorFactoryOrNull = findTypedActorFactory(id) + if (actorFactoryOrNull ne null) { + val newInstance = actorFactoryOrNull() + typedSessionActors.get(channel).put(id, newInstance) + newInstance + } + else + null + } + + } + + private def createClientManagedTypedActor(actorInfo: ActorInfoProtocol) = { + val typedActorInfo = actorInfo.getTypedActorInfo + val interfaceClassname = typedActorInfo.getInterface + val targetClassname = actorInfo.getTarget + val uuid = actorInfo.getUuid + + try { + if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( + "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + + log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) + + val (interfaceClass, targetClass) = + if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), + applicationLoader.get.loadClass(targetClassname)) + else (Class.forName(interfaceClassname), Class.forName(targetClassname)) + + val newInstance = TypedActor.newInstance( + interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] + server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid + newInstance + } catch { + case e => + log.error(e, "Could not create remote typed actor instance") + server.notifyListeners(RemoteServerError(e, server)) + throw e + } + } + + private def createTypedActor(actorInfo: ActorInfoProtocol, channel: Channel): AnyRef = { val uuid = actorInfo.getUuid val id = actorInfo.getId val typedActorOrNull = findTypedActorByIdOrUuid(id, uuidFrom(uuid.getHigh,uuid.getLow).toString) - - if (typedActorOrNull eq null) { - val typedActorInfo = actorInfo.getTypedActorInfo - val interfaceClassname = typedActorInfo.getInterface - val targetClassname = actorInfo.getTarget - - try { - if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( - "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") - - log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) - - val (interfaceClass, targetClass) = - if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), - applicationLoader.get.loadClass(targetClassname)) - else (Class.forName(interfaceClassname), Class.forName(targetClassname)) - - val newInstance = TypedActor.newInstance( - interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] - server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid - newInstance - } catch { - case e => - log.error(e, "Could not create remote typed actor instance") - server.notifyListeners(RemoteServerError(e, server)) - throw e - } - } else typedActorOrNull + if (typedActorOrNull ne null) + typedActorOrNull + else + { + // the actor has not been registered globally. See if we have it in the session + val sessionActorRefOrNull = createTypedSessionActor(actorInfo, channel) + if (sessionActorRefOrNull ne null) + sessionActorRefOrNull + else // maybe it is a client managed actor + createClientManagedTypedActor(actorInfo) + } } private def createErrorReplyMessage(exception: Throwable, request: RemoteMessageProtocol, actorType: AkkaActorType): RemoteMessageProtocol = { diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java new file mode 100644 index 0000000000..8a6c2e6373 --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java @@ -0,0 +1,8 @@ +package akka.actor; + +public interface RemoteTypedSessionActor { + + public void login(String user); + public String getUser(); + public void doSomethingFunny() throws Exception; +} diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java new file mode 100644 index 0000000000..b4140f74ed --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java @@ -0,0 +1,49 @@ +package akka.actor.remote; + +import akka.actor.*; + +import java.util.Set; +import java.util.HashSet; + +import java.util.concurrent.CountDownLatch; + +public class RemoteTypedSessionActorImpl extends TypedActor implements RemoteTypedSessionActor { + + + private static Set