diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index 2a92fff853..5ba58567ef 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -458,7 +458,7 @@ class RemoteServerHandler( val AW_PROXY_PREFIX = "$$ProxiedByAW".intern val CHANNEL_INIT = "channel-init".intern - val sessionActors = new ChannelLocal(); + val sessionActors = new ChannelLocal[Map[String, ActorRef]](); applicationLoader.foreach(MessageSerializer.setClassLoader(_)) @@ -470,6 +470,7 @@ class RemoteServerHandler( override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) + sessionActors.set(event.getChannel(), Map[String, ActorRef]()); log.debug("Remote client [%s] connected to [%s]", clientAddress, server.name) if (RemoteServer.SECURE) { val sslHandler: SslHandler = ctx.getPipeline.get(classOf[SslHandler]) @@ -489,6 +490,11 @@ class RemoteServerHandler( override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val clientAddress = getClientAddress(ctx) log.debug("Remote client [%s] disconnected from [%s]", clientAddress, server.name) + // stop all session actors + for ((id, actorRef) <- sessionActors.get(event.getChannel())) { + actorRef.stop() + } + sessionActors.remove(event.getChannel()); server.notifyListeners(RemoteServerClientDisconnected(server, clientAddress)) } @@ -543,7 +549,7 @@ class RemoteServerHandler( val actorRef = try { - createActor(actorInfo).start + createActor(actorInfo, channel).start } catch { case e: SecurityException => channel.write(createErrorReplyMessage(e, request, AkkaActorType.ScalaActor)) @@ -664,6 +670,13 @@ class RemoteServerHandler( server.actorsByUuid.get(uuid) } + private def findActorFactory(id: String) : () => ActorRef = { + server.actorsFactories.get(id) + } + private def findSessionActor(id: String, channel: Channel) : ActorRef = { + sessionActors.get(channel).getOrElse(id, null) + } + private def findTypedActorById(id: String) : AnyRef = { server.typedActors.get(id) } @@ -693,7 +706,7 @@ class RemoteServerHandler( * * Does not start the actor. */ - private def createActor(actorInfo: ActorInfoProtocol): ActorRef = { + private def createActor(actorInfo: ActorInfoProtocol, channel: Channel): ActorRef = { val uuid = actorInfo.getUuid val id = actorInfo.getId @@ -702,28 +715,47 @@ class RemoteServerHandler( val actorRefOrNull = findActorByIdOrUuid(id, uuidFrom(uuid.getHigh,uuid.getLow).toString) - if (actorRefOrNull eq null) { - try { - if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( - "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + if (actorRefOrNull ne null) + return actorRefOrNull - log.info("Creating a new remote actor [%s:%s]", name, uuid) - val clazz = if (applicationLoader.isDefined) applicationLoader.get.loadClass(name) - else Class.forName(name) - val actorRef = Actor.actorOf(clazz.asInstanceOf[Class[_ <: Actor]]) - actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) - actorRef.id = id - actorRef.timeout = timeout - actorRef.remoteAddress = None - server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid - actorRef - } catch { - case e => - log.error(e, "Could not create remote actor instance") - server.notifyListeners(RemoteServerError(e, server)) - throw e - } - } else actorRefOrNull + + // the actor has not been registered globally. See if we have it in the session + + val sessionActorRefOrNull = findSessionActor(id, channel); + if (sessionActorRefOrNull ne null) + return sessionActorRefOrNull + + // we dont have it in the session either, see if we have a factory for it + val actorFactoryOrNull = findActorFactory(id) + if (actorFactoryOrNull ne null) { + val actorRef = actorFactoryOrNull(); + actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) + sessionActors.get(channel).put(id, actorRef); + server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid + return actorRef + } + + // None of the above, so treat it as a client managed remote actor + try { + if (RemoteServer.UNTRUSTED_MODE) throw new SecurityException( + "Remote server is operating is untrusted mode, can not create remote actors on behalf of the remote client") + + log.info("Creating a new remote actor [%s:%s]", name, uuid) + val clazz = if (applicationLoader.isDefined) applicationLoader.get.loadClass(name) + else Class.forName(name) + val actorRef = Actor.actorOf(clazz.asInstanceOf[Class[_ <: Actor]]) + actorRef.uuid = uuidFrom(uuid.getHigh,uuid.getLow) + actorRef.id = id + actorRef.timeout = timeout + actorRef.remoteAddress = None + server.actorsByUuid.put(actorRef.uuid.toString, actorRef) // register by uuid + actorRef + } catch { + case e => + log.error(e, "Could not create remote actor instance") + server.notifyListeners(RemoteServerError(e, server)) + throw e + } } private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = { diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala index ff8793218e..f0e2123310 100644 --- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala @@ -7,12 +7,14 @@ import org.junit.{Test, Before, After} import akka.remote.{RemoteServer, RemoteClient} import akka.actor.Actor._ import akka.actor.{ActorRegistry, ActorRef, Actor} +import scala.collection.mutable.Set object ServerInitiatedRemoteActorSpec { val HOSTNAME = "localhost" val PORT = 9990 var server: RemoteServer = null + case class Send(actor: ActorRef) object RemoteActorSpecActorUnidirectional { @@ -38,16 +40,29 @@ object ServerInitiatedRemoteActorSpec { case class Login(user:String); case class GetUser(); - + case class DoSomethingWeird(); + + val instantiatedSessionActors= Set[ActorRef](); + class RemoteStatefullSessionActorSpec extends Actor { - var user : String= _; + var user : String= "anonymous"; + override def preStart = { + instantiatedSessionActors += self; + } + + override def postStop = { + instantiatedSessionActors -= self; + } + def receive = { case Login(user) => this.user = user; case GetUser() => self.reply(this.user) + case DoSomethingWeird() => + throw new Exception("Bad boy") } } @@ -122,25 +137,82 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { @Test def shouldKeepSessionInformation { + + //RemoteClient.clientFor(HOSTNAME, PORT).connect + val session1 = RemoteClient.actorFor( "statefull-session-actor", 5000L, HOSTNAME, PORT) + + + val default1 = session1 !! GetUser(); + assert("anonymous" === default1.get.asInstanceOf[String]) + + session1 ! Login("session[1]"); + + val result1 = session1 !! GetUser(); + assert("session[1]" === result1.get.asInstanceOf[String]) + + session1.stop() + + RemoteClient.shutdownAll + + //RemoteClient.clientFor(HOSTNAME, PORT).connect + val session2 = RemoteClient.actorFor( "statefull-session-actor", 5000L, HOSTNAME, PORT) - session1 ! Login("session1"); - session2 ! Login("session2"); + // since this is a new session, the server should reset the state + val default2 = session2 !! GetUser(); + assert("anonymous" === default2.get.asInstanceOf[String]) - val result1 = session1 !! GetUser(); - assert("session1" === result1.get.asInstanceOf[String]) - val result2 = session2 !! GetUser(); - assert("session2" === result2.get.asInstanceOf[String]) - - session1.stop() session2.stop() + + RemoteClient.shutdownAll + } + + @Test + def shouldStopActorOnDisconnect{ + + + val session1 = RemoteClient.actorFor( + "statefull-session-actor", + 5000L, + HOSTNAME, PORT) + + + val default1 = session1 !! GetUser(); + assert("anonymous" === default1.get.asInstanceOf[String]) + + assert(instantiatedSessionActors.size == 1); + + RemoteClient.shutdownAll + Thread.sleep(1000) + assert(instantiatedSessionActors.size == 0); + + } + + @Test + def shouldStopActorOnError{ + + + val session1 = RemoteClient.actorFor( + "statefull-session-actor", + 5000L, + HOSTNAME, PORT) + + + session1 ! DoSomethingWeird(); + session1.stop() + + RemoteClient.shutdownAll + Thread.sleep(1000) + + assert(instantiatedSessionActors.size == 0); + } @Test