diff --git a/akka-remote/src/main/scala/akka/remote/netty/Server.scala b/akka-remote/src/main/scala/akka/remote/netty/Server.scala index 0bafb1c712..1f18b27c8c 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Server.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Server.scala @@ -10,7 +10,6 @@ import org.jboss.netty.bootstrap.ServerBootstrap import org.jboss.netty.channel.ChannelHandler.Sharable import org.jboss.netty.channel.group.ChannelGroup import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory -import org.jboss.netty.channel.{ StaticChannelPipeline, SimpleChannelUpstreamHandler, MessageEvent, ExceptionEvent, ChannelStateEvent, ChannelPipelineFactory, ChannelPipeline, ChannelHandlerContext, ChannelHandler, Channel } import org.jboss.netty.handler.codec.frame.{ LengthFieldPrepender, LengthFieldBasedFrameDecoder } import org.jboss.netty.handler.execution.ExecutionHandler import akka.event.Logging @@ -19,8 +18,7 @@ import akka.remote.{ RemoteServerShutdown, RemoteServerError, RemoteServerClient import akka.actor.Address import java.net.InetAddress import akka.actor.ActorSystemImpl -import org.jboss.netty.channel.ChannelLocal -import org.jboss.netty.channel.ChannelEvent +import org.jboss.netty.channel._ class NettyRemoteServer(val netty: NettyRemoteTransport) { @@ -135,6 +133,10 @@ class RemoteServerHandler( val openChannels: ChannelGroup, val netty: NettyRemoteTransport) extends SimpleChannelUpstreamHandler { + val channelAddress = new ChannelLocal[Option[Address]](false) { + override def initialValue(channel: Channel) = None + } + import netty.settings private var addressToSet = true @@ -154,23 +156,20 @@ class RemoteServerHandler( */ override def channelOpen(ctx: ChannelHandlerContext, event: ChannelStateEvent) = openChannels.add(ctx.getChannel) - override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - val clientAddress = getClientAddress(ctx.getChannel) - netty.notifyListeners(RemoteServerClientConnected(netty, clientAddress)) - } + // TODO might want to log or otherwise signal that a TCP connection has been established here. + override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = () override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - val clientAddress = getClientAddress(ctx.getChannel) - netty.notifyListeners(RemoteServerClientDisconnected(netty, clientAddress)) + netty.notifyListeners(RemoteServerClientDisconnected(netty, channelAddress.get(ctx.getChannel))) } - override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = getClientAddress(ctx.getChannel) match { - case s @ Some(address) ⇒ - if (settings.UsePassiveConnections) - netty.unbindClient(address) - netty.notifyListeners(RemoteServerClientClosed(netty, s)) - case None ⇒ - netty.notifyListeners(RemoteServerClientClosed(netty, None)) + override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { + val address = channelAddress.get(ctx.getChannel) + if (address.isDefined && settings.UsePassiveConnections) + netty.unbindClient(address.get) + + netty.notifyListeners(RemoteServerClientClosed(netty, address)) + channelAddress.remove(ctx.getChannel) } override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = try { @@ -181,11 +180,16 @@ class RemoteServerHandler( case remote: AkkaRemoteProtocol if remote.hasInstruction ⇒ val instruction = remote.getInstruction instruction.getCommandType match { - case CommandType.CONNECT if settings.UsePassiveConnections ⇒ + case CommandType.CONNECT ⇒ val origin = instruction.getOrigin val inbound = Address("akka", origin.getSystem, origin.getHostname, origin.getPort) - val client = new PassiveRemoteClient(event.getChannel, netty, inbound) - netty.bindClient(inbound, client) + channelAddress.set(event.getChannel, Option(inbound)) + + //If we want to reuse the inbound connections as outbound we need to get busy + if (settings.UsePassiveConnections) + netty.bindClient(inbound, new PassiveRemoteClient(event.getChannel, netty, inbound)) + + netty.notifyListeners(RemoteServerClientConnected(netty, Option(inbound))) case CommandType.SHUTDOWN ⇒ //Will be unbound in channelClosed case CommandType.HEARTBEAT ⇒ //Other guy is still alive case _ ⇒ //Unknown command @@ -200,11 +204,5 @@ class RemoteServerHandler( netty.notifyListeners(RemoteServerError(event.getCause, netty)) event.getChannel.close() } - - private def getClientAddress(c: Channel): Option[Address] = - c.getRemoteAddress match { - case inet: InetSocketAddress ⇒ Some(Address("akka", "unknown(yet)", inet.getAddress.toString, inet.getPort)) - case _ ⇒ None - } }