From 3bd83e5388ed4306d33181956bad6825646d3885 Mon Sep 17 00:00:00 2001 From: Viktor Klang Date: Sat, 5 Mar 2011 14:55:58 +0100 Subject: [PATCH] Adding support for clean exit of remote server --- .../remote/netty/NettyRemoteSupport.scala | 67 +++++++++++++------ 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index db0fbe2937..6da6b56cec 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -13,10 +13,9 @@ import akka.serialization.RemoteActorSerialization._ import akka.japi.Creator import akka.config.Config._ import akka.remoteinterface._ -import akka.actor. {Index, ActorInitializationException, LocalActorRef, newUuid, ActorRegistry, Actor, RemoteActorRef, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage, uuidFrom, Uuid, Exit, LifeCycleMessage, ActorType => AkkaActorType} +import akka.actor.{EventHandler, Index, ActorInitializationException, LocalActorRef, newUuid, ActorRegistry, Actor, RemoteActorRef, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage, uuidFrom, Uuid, Exit, LifeCycleMessage, ActorType => AkkaActorType} import akka.AkkaException import akka.actor.Actor._ -import akka.actor.{EventHandler} import akka.util._ import akka.remote.{MessageSerializer, RemoteClientSettings, RemoteServerSettings} @@ -39,6 +38,20 @@ import scala.reflect.BeanProperty import java.lang.reflect.InvocationTargetException import java.util.concurrent.atomic. {AtomicReference, AtomicLong, AtomicBoolean} +object RemoteEncoder { + def encode(rmp: RemoteMessageProtocol): AkkaRemoteProtocol = { + val arp = AkkaRemoteProtocol.newBuilder + arp.setMessage(rmp) + arp.build + } + + def encode(rcp: RemoteControlProtocol): AkkaRemoteProtocol = { + val arp = AkkaRemoteProtocol.newBuilder + arp.setInstruction(rcp) + arp.build + } +} + trait NettyRemoteClientModule extends RemoteClientModule { self: ListenerManagement => private val remoteClients = new HashMap[Address, RemoteClient] private val remoteActors = new Index[Address, Uuid] @@ -197,7 +210,7 @@ abstract class RemoteClient private[akka] ( senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = { if (isRunning) { if (request.getOneWay) { - currentChannel.write(request).addListener(new ChannelFutureListener { + currentChannel.write(RemoteEncoder.encode(request)).addListener(new ChannelFutureListener { def operationComplete(future: ChannelFuture) { if (future.isCancelled) { //We don't care about that right now @@ -212,7 +225,7 @@ abstract class RemoteClient private[akka] ( else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) val futureUuid = uuidFrom(request.getUuid.getHigh, request.getUuid.getLow) futures.put(futureUuid, futureResult) //Add this prematurely, remove it if write fails - currentChannel.write(request).addListener(new ChannelFutureListener { + currentChannel.write(RemoteEncoder.encode(request)).addListener(new ChannelFutureListener { def operationComplete(future: ChannelFuture) { if (future.isCancelled) { futures.remove(futureUuid) //Clean this up @@ -328,9 +341,8 @@ class ActiveRemoteClient private[akka] ( true } else { val timeLeft = RECONNECTION_TIME_WINDOW - (System.currentTimeMillis - reconnectionTimeWindowStart) - if (timeLeft > 0) { - true - } else false + + timeLeft > 0 } } @@ -363,7 +375,7 @@ class ActiveRemoteClientPipelineFactory( val timeout = new ReadTimeoutHandler(timer, RemoteClientSettings.READ_TIMEOUT.toMillis.toInt) val lenDec = new LengthFieldBasedFrameDecoder(RemoteClientSettings.MESSAGE_FRAME_SIZE, 0, 4, 0, 4) val lenPrep = new LengthFieldPrepender(4) - val protobufDec = new ProtobufDecoder(RemoteMessageProtocol.getDefaultInstance) + val protobufDec = new ProtobufDecoder(AkkaRemoteProtocol.getDefaultInstance) val protobufEnc = new ProtobufEncoder val (enc, dec) = RemoteServerSettings.COMPRESSION_SCHEME match { case "zlib" => (join(new ZlibEncoder(RemoteServerSettings.ZLIB_COMPRESSION_LEVEL)), join(new ZlibDecoder)) @@ -400,7 +412,13 @@ class ActiveRemoteClientHandler( override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) { try { event.getMessage match { - case reply: RemoteMessageProtocol => + case arp: AkkaRemoteProtocol if arp.hasInstruction => + val rcp = arp.getInstruction + rcp.getCommandType match { + case CommandType.SHUTDOWN => spawn { client.shutdown } + } + case arp: AkkaRemoteProtocol if arp.hasMessage => + val reply = arp.getMessage val replyUuid = uuidFrom(reply.getActorInfo.getUuid.getHigh, reply.getActorInfo.getUuid.getLow) val future = futures.remove(replyUuid).asInstanceOf[CompletableFuture[Any]] @@ -423,7 +441,6 @@ class ActiveRemoteClientHandler( future.completeWithException(exception) } - case other => throw new RemoteClientException("Unknown message received in remote client handler: " + other, client.module, client.remoteAddress) } @@ -552,6 +569,14 @@ class NettyRemoteServer(serverModule: NettyRemoteServerModule, val host: String, def shutdown { try { + val shutdownSignal = { + val b = RemoteControlProtocol.newBuilder + if (RemoteClientSettings.SECURE_COOKIE.nonEmpty) + b.setCookie(RemoteClientSettings.SECURE_COOKIE.get) + b.setCommandType(CommandType.SHUTDOWN) + b.build + } + openChannels.write(RemoteEncoder.encode(shutdownSignal)).awaitUninterruptibly openChannels.disconnect openChannels.close.awaitUninterruptibly bootstrap.releaseExternalResources @@ -765,7 +790,7 @@ class RemoteServerPipelineFactory( val ssl = if(SECURE) join(new SslHandler(engine)) else join() val lenDec = new LengthFieldBasedFrameDecoder(MESSAGE_FRAME_SIZE, 0, 4, 0, 4) val lenPrep = new LengthFieldPrepender(4) - val protobufDec = new ProtobufDecoder(RemoteMessageProtocol.getDefaultInstance) + val protobufDec = new ProtobufDecoder(AkkaRemoteProtocol.getDefaultInstance) val protobufEnc = new ProtobufEncoder val (enc, dec) = COMPRESSION_SCHEME match { case "zlib" => (join(new ZlibEncoder(ZLIB_COMPRESSION_LEVEL)), join(new ZlibDecoder)) @@ -796,8 +821,8 @@ class RemoteServerHandler( val typedSessionActors = new ChannelLocal[ConcurrentHashMap[String, AnyRef]]() //Writes the specified message to the specified channel and propagates write errors to listeners - private def write(channel: Channel, message: AnyRef): Unit = { - channel.write(message).addListener( + private def write(channel: Channel, payload: AkkaRemoteProtocol): Unit = { + channel.write(payload).addListener( new ChannelFutureListener { def operationComplete(future: ChannelFuture): Unit = { if (future.isCancelled) { @@ -807,7 +832,7 @@ class RemoteServerHandler( case i: InetSocketAddress => Some(i) case _ => None } - server.notifyListeners(RemoteServerWriteFailed(message, future.getCause, server, socketAddress)) + server.notifyListeners(RemoteServerWriteFailed(payload, future.getCause, server, socketAddress)) } } }) @@ -871,7 +896,9 @@ class RemoteServerHandler( override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = event.getMessage match { case null => throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event) - case requestProtocol: RemoteMessageProtocol => + //case remoteProtocol: AkkaRemoteProtocol if remoteProtocol.hasInstruction => RemoteServer cannot receive control messages (yet) + case remoteProtocol: AkkaRemoteProtocol if remoteProtocol.hasMessage => + val requestProtocol = remoteProtocol.getMessage if (REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx) handleRemoteMessageProtocol(requestProtocol, event.getChannel) case _ => //ignore @@ -952,7 +979,7 @@ class RemoteServerHandler( // FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) - write(channel, messageBuilder.build) + write(channel, RemoteEncoder.encode(messageBuilder.build)) } } ) @@ -988,7 +1015,7 @@ class RemoteServerHandler( None) if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) - write(channel, messageBuilder.build) + write(channel, RemoteEncoder.encode(messageBuilder.build)) } catch { case e: Exception => EventHandler notifyListeners EventHandler.Error(e, this) @@ -1157,7 +1184,7 @@ class RemoteServerHandler( } } - private def createErrorReplyMessage(exception: Throwable, request: RemoteMessageProtocol, actorType: AkkaActorType): RemoteMessageProtocol = { + private def createErrorReplyMessage(exception: Throwable, request: RemoteMessageProtocol, actorType: AkkaActorType): AkkaRemoteProtocol = { val actorInfo = request.getActorInfo val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( None, @@ -1172,7 +1199,7 @@ class RemoteServerHandler( actorType, None) if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) - messageBuilder.build + RemoteEncoder.encode(messageBuilder.build) } private def authenticateRemoteClient(request: RemoteMessageProtocol, ctx: ChannelHandlerContext) = { @@ -1212,4 +1239,4 @@ class DefaultDisposableChannelGroup(name: String) extends DefaultChannelGroup(na throw new IllegalStateException("ChannelGroup already closed, cannot add new channel") } } -} +} \ No newline at end of file