diff --git a/akka-actor/src/main/scala/actor/ActorRegistry.scala b/akka-actor/src/main/scala/actor/ActorRegistry.scala index 0b946df99f..f8b46fd3c4 100644 --- a/akka-actor/src/main/scala/actor/ActorRegistry.scala +++ b/akka-actor/src/main/scala/actor/ActorRegistry.scala @@ -304,6 +304,7 @@ object ActorRegistry extends ListenerManagement { private[akka] def actorsFactories(address: Address) = actorsFor(address).actorsFactories private[akka] def typedActors(address: Address) = actorsFor(address).typedActors private[akka] def typedActorsByUuid(address: Address) = actorsFor(address).typedActorsByUuid + private[akka] def typedActorsFactories(address: Address) = actorsFor(address).typedActorsFactories private[akka] class RemoteActorSet { private[ActorRegistry] val actors = new ConcurrentHashMap[String, ActorRef] @@ -311,6 +312,7 @@ object ActorRegistry extends ListenerManagement { private[ActorRegistry] val actorsFactories = new ConcurrentHashMap[String, () => ActorRef] private[ActorRegistry] val typedActors = new ConcurrentHashMap[String, AnyRef] private[ActorRegistry] val typedActorsByUuid = new ConcurrentHashMap[String, AnyRef] + private[ActorRegistry] val typedActorsFactories = new ConcurrentHashMap[String, () => AnyRef] } } diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index 5ba58567ef..9afef140b1 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -284,6 +284,21 @@ class RemoteServer extends Logging with ListenerManagement { else registerTypedActor(id, typedActor, typedActors) } + /** + * Register typed actor by interface name. + */ + def registerTypedPerSessionActor(intfClass: Class[_], factory: => AnyRef) : Unit = registerTypedActor(intfClass.getName, factory) + + /** + * Register remote typed actor by a specific id. + * @param id custom actor id + * @param typedActor typed actor to register + */ + def registerTypedPerSessionActor(id: String, factory: => AnyRef): Unit = synchronized { + log.debug("Registering server side typed remote session actor with id [%s]", id) + registerTypedPerSessionActor(id, () => factory, typedActorsFactories) + } + /** * Register Remote Actor by the Actor's 'id' field. It starts the Actor if it is not started already. */ @@ -307,8 +322,7 @@ class RemoteServer extends Logging with ListenerManagement { */ def registerPerSession(id: String, factory: => ActorRef): Unit = synchronized { log.debug("Registering server side remote session actor with id [%s]", id) - if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), factory, actorsByUuid) - else registerPerSession(id, () => factory, actorsFactories) + registerPerSession(id, () => factory, actorsFactories) } private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) { @@ -328,6 +342,12 @@ class RemoteServer extends Logging with ListenerManagement { if (_isRunning && !registry.contains(id)) registry.put(id, typedActor) } + private def registerTypedPerSessionActor[Key](id: Key, factory: () => AnyRef, registry: ConcurrentHashMap[Key,() => AnyRef]) { + if (_isRunning && !registry.contains(id)) { + registry.put(id, factory) + } + } + /** * Unregister Remote Actor that is registered using its 'id' field (not custom ID). */ @@ -381,6 +401,17 @@ class RemoteServer extends Logging with ListenerManagement { } } + /** + * Unregister Remote Typed Actor by specific 'id'. + *

+ * NOTE: You need to call this method if you have registered an actor by a custom ID. + */ + def unregisterTypedPerSessionActor(id: String):Unit = synchronized { + if (_isRunning) { + typedActorsFactories.remove(id) + } + } + protected override def manageLifeCycleOfListeners = false protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message) @@ -391,6 +422,7 @@ class RemoteServer extends Logging with ListenerManagement { private[akka] def actorsFactories = ActorRegistry.actorsFactories(address) private[akka] def typedActors = ActorRegistry.typedActors(address) private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address) + private[akka] def typedActorsFactories = ActorRegistry.typedActorsFactories(address) } object RemoteServerSslContext { @@ -459,6 +491,7 @@ class RemoteServerHandler( val CHANNEL_INIT = "channel-init".intern val sessionActors = new ChannelLocal[Map[String, ActorRef]](); + val typedSessionActors = new ChannelLocal[Map[String, AnyRef]](); applicationLoader.foreach(MessageSerializer.setClassLoader(_)) @@ -471,6 +504,7 @@ class RemoteServerHandler( override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) sessionActors.set(event.getChannel(), Map[String, ActorRef]()); + typedSessionActors.set(event.getChannel(), Map[String, AnyRef]()); log.debug("Remote client [%s] connected to [%s]", clientAddress, server.name) if (RemoteServer.SECURE) { val sslHandler: SslHandler = ctx.getPipeline.get(classOf[SslHandler]) @@ -495,6 +529,11 @@ class RemoteServerHandler( actorRef.stop() } sessionActors.remove(event.getChannel()); + for ((id, actorRef) <- typedSessionActors.get(event.getChannel())) { + TypedActor.stop(actorRef) + } + typedSessionActors.remove(event.getChannel()); + server.notifyListeners(RemoteServerClientDisconnected(server, clientAddress)) } @@ -623,7 +662,7 @@ class RemoteServerHandler( val typedActorInfo = actorInfo.getTypedActorInfo log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface) - val typedActor = createTypedActor(actorInfo) + val typedActor = createTypedActor(actorInfo, channel) val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList val argClasses = args.map(_.getClass) @@ -673,6 +712,7 @@ class RemoteServerHandler( private def findActorFactory(id: String) : () => ActorRef = { server.actorsFactories.get(id) } + private def findSessionActor(id: String, channel: Channel) : ActorRef = { sessionActors.get(channel).getOrElse(id, null) } @@ -681,6 +721,14 @@ class RemoteServerHandler( server.typedActors.get(id) } + private def findTypedActorFactory(id: String) : () => AnyRef = { + server.typedActorsFactories.get(id) + } + + private def findTypedSessionActor(id: String, channel: Channel) : AnyRef = { + typedSessionActors.get(channel).getOrElse(id, null) + } + private def findTypedActorByUuid(uuid: String) : AnyRef = { server.typedActorsByUuid.get(uuid) } @@ -731,7 +779,6 @@ class RemoteServerHandler( val actorRef = actorFactoryOrNull(); actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) sessionActors.get(channel).put(id, actorRef); - server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid return actorRef } @@ -758,39 +805,55 @@ class RemoteServerHandler( } } - private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = { + private def createTypedActor(actorInfo: ActorInfoProtocol, channel: Channel): AnyRef = { val uuid = actorInfo.getUuid val id = actorInfo.getId val typedActorOrNull = findTypedActorByIdOrUuid(id, uuidFrom(uuid.getHigh,uuid.getLow).toString) + if (typedActorOrNull ne null) + return typedActorOrNull; - if (typedActorOrNull eq null) { - val typedActorInfo = actorInfo.getTypedActorInfo - val interfaceClassname = typedActorInfo.getInterface - val targetClassname = actorInfo.getTarget + // the actor has not been registered globally. See if we have it in the session - try { - if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( - "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + val sessionActorRefOrNull = findTypedSessionActor(id, channel); + if (sessionActorRefOrNull ne null) + return sessionActorRefOrNull - log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) + // we dont have it in the session either, see if we have a factory for it + val actorFactoryOrNull = findTypedActorFactory(id) + if (actorFactoryOrNull ne null) { + val newInstance = actorFactoryOrNull(); + typedSessionActors.get(channel).put(id, newInstance); + return newInstance + } - val (interfaceClass, targetClass) = - if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), - applicationLoader.get.loadClass(targetClassname)) - else (Class.forName(interfaceClassname), Class.forName(targetClassname)) + // None of the above, so treat it as a client managed remote actor - val newInstance = TypedActor.newInstance( - interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] - server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid - newInstance - } catch { - case e => - log.error(e, "Could not create remote typed actor instance") - server.notifyListeners(RemoteServerError(e, server)) - throw e - } - } else typedActorOrNull + val typedActorInfo = actorInfo.getTypedActorInfo + val interfaceClassname = typedActorInfo.getInterface + val targetClassname = actorInfo.getTarget + + try { + if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( + "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + + log.info("Creating a new remote typed actor:\n\t[%s :: %s]", interfaceClassname, targetClassname) + + val (interfaceClass, targetClass) = + if (applicationLoader.isDefined) (applicationLoader.get.loadClass(interfaceClassname), + applicationLoader.get.loadClass(targetClassname)) + else (Class.forName(interfaceClassname), Class.forName(targetClassname)) + + val newInstance = TypedActor.newInstance( + interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] + server.typedActors.put(uuidFrom(uuid.getHigh,uuid.getLow).toString, newInstance) // register by uuid + newInstance + } catch { + case e => + log.error(e, "Could not create remote typed actor instance") + server.notifyListeners(RemoteServerError(e, server)) + throw e + } } private def createErrorReplyMessage(exception: Throwable, request: RemoteMessageProtocol, actorType: AkkaActorType): RemoteMessageProtocol = { diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java new file mode 100644 index 0000000000..8a6c2e6373 --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActor.java @@ -0,0 +1,8 @@ +package akka.actor; + +public interface RemoteTypedSessionActor { + + public void login(String user); + public String getUser(); + public void doSomethingFunny() throws Exception; +} diff --git a/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java new file mode 100644 index 0000000000..b4140f74ed --- /dev/null +++ b/akka-remote/src/test/java/akka/actor/RemoteTypedSessionActorImpl.java @@ -0,0 +1,49 @@ +package akka.actor.remote; + +import akka.actor.*; + +import java.util.Set; +import java.util.HashSet; + +import java.util.concurrent.CountDownLatch; + +public class RemoteTypedSessionActorImpl extends TypedActor implements RemoteTypedSessionActor { + + + private static Set instantiatedSessionActors = new HashSet(); + + public static Set getInstances() { + return instantiatedSessionActors; + } + + @Override + public void preStart() { + instantiatedSessionActors.add(this); + } + + @Override + public void postStop() { + instantiatedSessionActors.remove(this); + } + + + private String user="anonymous"; + + @Override + public void login(String user) { + this.user = user; + } + + @Override + public String getUser() + { + return this.user; + } + + @Override + public void doSomethingFunny() throws Exception + { + throw new Exception("Bad boy"); + } + +} diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala index f0e2123310..5ffbecb612 100644 --- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala @@ -40,7 +40,7 @@ object ServerInitiatedRemoteActorSpec { case class Login(user:String); case class GetUser(); - case class DoSomethingWeird(); + case class DoSomethingFunny(); val instantiatedSessionActors= Set[ActorRef](); @@ -61,7 +61,7 @@ object ServerInitiatedRemoteActorSpec { this.user = user; case GetUser() => self.reply(this.user) - case DoSomethingWeird() => + case DoSomethingFunny() => throw new Exception("Bad boy") } } @@ -94,7 +94,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { server.register(actorOf[RemoteActorSpecActorUnidirectional]) server.register(actorOf[RemoteActorSpecActorBidirectional]) server.register(actorOf[RemoteActorSpecActorAsyncSender]) - server.registerPerSession("statefull-session-actor", actorOf[RemoteStatefullSessionActorSpec]) + server.registerPerSession("untyped-session-actor-service", actorOf[RemoteStatefullSessionActorSpec]) Thread.sleep(1000) } @@ -141,7 +141,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { //RemoteClient.clientFor(HOSTNAME, PORT).connect val session1 = RemoteClient.actorFor( - "statefull-session-actor", + "untyped-session-actor-service", 5000L, HOSTNAME, PORT) @@ -161,7 +161,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { //RemoteClient.clientFor(HOSTNAME, PORT).connect val session2 = RemoteClient.actorFor( - "statefull-session-actor", + "untyped-session-actor-service", 5000L, HOSTNAME, PORT) @@ -179,7 +179,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { val session1 = RemoteClient.actorFor( - "statefull-session-actor", + "untyped-session-actor-service", 5000L, HOSTNAME, PORT) @@ -200,12 +200,12 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { val session1 = RemoteClient.actorFor( - "statefull-session-actor", + "untyped-session-actor-service", 5000L, HOSTNAME, PORT) - session1 ! DoSomethingWeird(); + session1 ! DoSomethingFunny(); session1.stop() RemoteClient.shutdownAll diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedSessionActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedSessionActorSpec.scala new file mode 100644 index 0000000000..e766c24d67 --- /dev/null +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedSessionActorSpec.scala @@ -0,0 +1,110 @@ +/** + * Copyright (C) 2009-2010 Scalable Solutions AB + */ + +package akka.actor.remote + +import org.scalatest._ +import org.scalatest.matchers.ShouldMatchers +import org.scalatest.BeforeAndAfterAll +import org.scalatest.junit.JUnitRunner +import org.junit.runner.RunWith + +import java.util.concurrent.TimeUnit + +import akka.remote.{RemoteServer, RemoteClient} +import akka.actor._ +import RemoteTypedActorLog._ + +object ServerInitiatedRemoteTypedSessionActorSpec { + val HOSTNAME = "localhost" + val PORT = 9990 + var server: RemoteServer = null +} + +@RunWith(classOf[JUnitRunner]) +class ServerInitiatedRemoteTypedSessionActorSpec extends + FlatSpec with + ShouldMatchers with + BeforeAndAfterAll { + import ServerInitiatedRemoteTypedActorSpec._ + + private val unit = TimeUnit.MILLISECONDS + + + override def beforeAll = { + server = new RemoteServer() + server.start(HOSTNAME, PORT) + + server.registerTypedPerSessionActor("typed-session-actor-service", + TypedActor.newInstance(classOf[RemoteTypedSessionActor], classOf[RemoteTypedSessionActorImpl], 1000)) + + Thread.sleep(1000) + } + + // make sure the servers shutdown cleanly after the test has finished + override def afterAll = { + try { + server.shutdown + RemoteClient.shutdownAll + Thread.sleep(1000) + } catch { + case e => () + } + } + + "A remote session Actor" should "create a new session actor per connection" in { + clearMessageLogs + + val session1 = RemoteClient.typedActorFor(classOf[RemoteTypedSessionActor], "typed-session-actor-service", 5000L, HOSTNAME, PORT) + + session1.getUser() should equal ("anonymous"); + session1.login("session[1]"); + session1.getUser() should equal ("session[1]"); + + RemoteClient.shutdownAll + + val session2 = RemoteClient.typedActorFor(classOf[RemoteTypedSessionActor], "typed-session-actor-service", 5000L, HOSTNAME, PORT) + + session2.getUser() should equal ("anonymous"); + + RemoteClient.shutdownAll + + } + + it should "stop the actor when the client disconnects" in { + + val session1 = RemoteClient.typedActorFor(classOf[RemoteTypedSessionActor], "typed-session-actor-service", 5000L, HOSTNAME, PORT) + + session1.getUser() should equal ("anonymous"); + + RemoteTypedSessionActorImpl.getInstances() should have size (1); + RemoteClient.shutdownAll + Thread.sleep(1000) + RemoteTypedSessionActorImpl.getInstances() should have size (0); + + } + + it should "stop the actor when there is an error" in { + + val session1 = RemoteClient.typedActorFor(classOf[RemoteTypedSessionActor], "typed-session-actor-service", 5000L, HOSTNAME, PORT) + + session1.doSomethingFunny(); + + RemoteClient.shutdownAll + Thread.sleep(1000) + RemoteTypedSessionActorImpl.getInstances() should have size (0); + + } + + + it should "be able to unregister" in { + server.registerTypedPerSessionActor("my-service-1",TypedActor.newInstance(classOf[RemoteTypedSessionActor], classOf[RemoteTypedSessionActorImpl], 1000)) + + server.typedActorsFactories.get("my-service-1") should not be (null) + server.unregisterTypedPerSessionActor("my-service-1") + server.typedActorsFactories.get("my-service-1") should be (null) + } + +} +