From 2478e5d6ad7f747fa5c0c14f93b06e5d2fbd0dd0 Mon Sep 17 00:00:00 2001 From: Michael Kober Date: Mon, 6 Sep 2010 09:35:02 +0200 Subject: [PATCH] fix: server initiated remote actors not found removed fixme comment --- .../src/main/scala/remote/RemoteServer.scala | 30 ++++++++++++------- .../ServerInitiatedRemoteActorSpec.scala | 16 ++++++++-- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index bae6b6c88c..5f24def4f5 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -234,9 +234,8 @@ class RemoteServer extends Logging with ListenerManagement { port = _port log.info("Starting remote server at [%s:%s]", hostname, port) RemoteServer.register(hostname, port, this) - val remoteActorSet = RemoteServer.actorsFor(RemoteServer.Address(hostname, port)) val pipelineFactory = new RemoteServerPipelineFactory( - name, openChannels, loader, remoteActorSet.actors, remoteActorSet.typedActors,this) + name, openChannels, loader, actors, typedActors, this) bootstrap.setPipelineFactory(pipelineFactory) bootstrap.setOption("child.tcpNoDelay", true) bootstrap.setOption("child.keepAlive", true) @@ -324,6 +323,13 @@ class RemoteServer extends Logging with ListenerManagement { protected override def manageLifeCycleOfListeners = false protected[akka] override def foreachListener(f: (ActorRef) => Unit): Unit = super.foreachListener(f) + + private def actors() : ConcurrentHashMap[String, ActorRef] = { + RemoteServer.actorsFor(RemoteServer.Address(hostname, port)).actors + } + private def typedActors() : ConcurrentHashMap[String, AnyRef] = { + RemoteServer.actorsFor(RemoteServer.Address(hostname, port)).typedActors + } } object RemoteServerSslContext { @@ -348,8 +354,8 @@ class RemoteServerPipelineFactory( val name: String, val openChannels: ChannelGroup, val loader: Option[ClassLoader], - val actors: JMap[String, ActorRef], - val typedActors: JMap[String, AnyRef], + val actors: (() => ConcurrentHashMap[String, ActorRef]), + val typedActors: (() => ConcurrentHashMap[String, AnyRef]), val server: RemoteServer) extends ChannelPipelineFactory { import RemoteServer._ @@ -373,7 +379,7 @@ class RemoteServerPipelineFactory( case _ => (join(), join()) } - val remoteServer = new RemoteServerHandler(name, openChannels, loader, actors, typedActors,server) + val remoteServer = new RemoteServerHandler(name, openChannels, loader, actors, typedActors, server) val stages = ssl ++ dec ++ join(lenDec, protobufDec) ++ enc ++ join(lenPrep, protobufEnc, remoteServer) new StaticChannelPipeline(stages: _*) } @@ -387,8 +393,8 @@ class RemoteServerHandler( val name: String, val openChannels: ChannelGroup, val applicationLoader: Option[ClassLoader], - val actors: JMap[String, ActorRef], - val typedActors: JMap[String, AnyRef], + val actors: (() => ConcurrentHashMap[String, ActorRef]), + val typedActors: (() => ConcurrentHashMap[String, AnyRef]), val server: RemoteServer) extends SimpleChannelUpstreamHandler with Logging { val AW_PROXY_PREFIX = "$$ProxiedByAW".intern @@ -539,7 +545,8 @@ class RemoteServerHandler( val name = actorInfo.getTarget val timeout = actorInfo.getTimeout - val actorRefOrNull = actors get uuid + val registeredActors = actors() + val actorRefOrNull = registeredActors get uuid if (actorRefOrNull eq null) { try { @@ -550,7 +557,7 @@ class RemoteServerHandler( actorRef.uuid = uuid actorRef.timeout = timeout actorRef.remoteAddress = None - actors.put(uuid, actorRef) + registeredActors.put(uuid, actorRef) actorRef } catch { case e => @@ -563,7 +570,8 @@ class RemoteServerHandler( private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = { val uuid = actorInfo.getUuid - val typedActorOrNull = typedActors get uuid + val registeredTypedActors = typedActors() + val typedActorOrNull = registeredTypedActors get uuid if (typedActorOrNull eq null) { val typedActorInfo = actorInfo.getTypedActorInfo @@ -580,7 +588,7 @@ class RemoteServerHandler( val newInstance = TypedActor.newInstance( interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef] - typedActors.put(uuid, newInstance) + registeredTypedActors.put(uuid, newInstance) newInstance } catch { case e => diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala index 59cfe3778d..8b1e0ef765 100644 --- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala +++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala @@ -5,8 +5,8 @@ import org.scalatest.junit.JUnitSuite import org.junit.{Test, Before, After} import se.scalablesolutions.akka.remote.{RemoteServer, RemoteClient} -import se.scalablesolutions.akka.actor.{ActorRef, Actor} -import Actor._ +import se.scalablesolutions.akka.actor.Actor._ +import se.scalablesolutions.akka.actor.{ActorRegistry, ActorRef, Actor} object ServerInitiatedRemoteActorSpec { val HOSTNAME = "localhost" @@ -132,5 +132,17 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite { } actor.stop } + + + @Test + def shouldNotRecreateRegisteredActor { + server.register(actorOf[RemoteActorSpecActorUnidirectional]) + val actor = RemoteClient.actorFor("se.scalablesolutions.akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorUnidirectional", HOSTNAME, PORT) + val numberOfActorsInRegistry = ActorRegistry.actors.length + val result = actor ! "OneWay" + assert(RemoteActorSpecActorUnidirectional.latch.await(1, TimeUnit.SECONDS)) + assert(numberOfActorsInRegistry === ActorRegistry.actors.length) + actor.stop + } }