diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index 19b18fe529..f35fbdd5f6 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -20,7 +20,7 @@ import akka.util._ import akka.remote.{MessageSerializer, RemoteClientSettings, RemoteServerSettings} import org.jboss.netty.channel._ -import org.jboss.netty.channel.group.{DefaultChannelGroup,ChannelGroup} +import org.jboss.netty.channel.group.{DefaultChannelGroup,ChannelGroup,ChannelGroupFuture} import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.bootstrap.{ServerBootstrap,ClientBootstrap} @@ -264,7 +264,7 @@ class ActiveRemoteClient private[akka] ( def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = { runSwitch switchOn { - openChannels = new DefaultChannelGroup(classOf[RemoteClient].getName) + openChannels = new DefaultDisposableChannelGroup(classOf[RemoteClient].getName) timer = new HashedWheelTimer bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)) @@ -543,7 +543,7 @@ class NettyRemoteServer(serverModule: NettyRemoteServerModule, val host: String, private val bootstrap = new ServerBootstrap(factory) // group of open channels, used for clean-up - private val openChannels: ChannelGroup = new DefaultChannelGroup("akka-remote-server") + private val openChannels: ChannelGroup = new DefaultDisposableChannelGroup("akka-remote-server") val pipelineFactory = new RemoteServerPipelineFactory(name, openChannels, loader, serverModule) bootstrap.setPipelineFactory(pipelineFactory) @@ -1220,3 +1220,25 @@ class RemoteServerHandler( protected def parseUuid(protocol: UuidProtocol): Uuid = uuidFrom(protocol.getHigh,protocol.getLow) } + +class DefaultDisposableChannelGroup(name: String) extends DefaultChannelGroup(name) { + protected val guard = new ReadWriteGuard + protected val open = new AtomicBoolean(true) + + override def add(channel: Channel): Boolean = guard withReadGuard { + if(open.get) { + super.add(channel) + } else { + channel.close + false + } + } + + override def close(): ChannelGroupFuture = guard withWriteGuard { + if (open.getAndSet(false)) { + super.close + } else { + throw new IllegalStateException("ChannelGroup already closed, cannot add new channel") + } + } +} \ No newline at end of file