diff --git a/akka-remote/src/main/java/akka/remote/protocol/RemoteProtocol.java b/akka-remote/src/main/java/akka/remote/protocol/RemoteProtocol.java index 4d2bdfdce1..f83430b478 100644 --- a/akka-remote/src/main/java/akka/remote/protocol/RemoteProtocol.java +++ b/akka-remote/src/main/java/akka/remote/protocol/RemoteProtocol.java @@ -10,7 +10,8 @@ public final class RemoteProtocol { } public enum CommandType implements com.google.protobuf.ProtocolMessageEnum { - SHUTDOWN(0, 1), + CONNECT(0, 1), + SHUTDOWN(1, 2), ; @@ -18,7 +19,8 @@ public final class RemoteProtocol { public static CommandType valueOf(int value) { switch (value) { - case 1: return SHUTDOWN; + case 1: return CONNECT; + case 2: return SHUTDOWN; default: return null; } } @@ -49,7 +51,7 @@ public final class RemoteProtocol { } private static final CommandType[] VALUES = { - SHUTDOWN, + CONNECT, SHUTDOWN, }; public static CommandType valueOf( com.google.protobuf.Descriptors.EnumValueDescriptor desc) { @@ -680,13 +682,6 @@ public final class RemoteProtocol { return metadata_.get(index); } - // optional string cookie = 9; - public static final int COOKIE_FIELD_NUMBER = 9; - private boolean hasCookie; - private java.lang.String cookie_ = ""; - public boolean hasCookie() { return hasCookie; } - public java.lang.String getCookie() { return cookie_; } - private void initFields() { uuid_ = akka.remote.protocol.RemoteProtocol.UuidProtocol.getDefaultInstance(); actorInfo_ = akka.remote.protocol.RemoteProtocol.ActorInfoProtocol.getDefaultInstance(); @@ -746,9 +741,6 @@ public final class RemoteProtocol { for (akka.remote.protocol.RemoteProtocol.MetadataEntryProtocol element : getMetadataList()) { output.writeMessage(8, element); } - if (hasCookie()) { - output.writeString(9, getCookie()); - } getUnknownFields().writeTo(output); } @@ -790,10 +782,6 @@ public final class RemoteProtocol { size += com.google.protobuf.CodedOutputStream .computeMessageSize(8, element); } - if (hasCookie()) { - size += com.google.protobuf.CodedOutputStream - .computeStringSize(9, getCookie()); - } size += getUnknownFields().getSerializedSize(); memoizedSerializedSize = size; return size; @@ -983,9 +971,6 @@ public final class RemoteProtocol { } result.metadata_.addAll(other.metadata_); } - if (other.hasCookie()) { - setCookie(other.getCookie()); - } this.mergeUnknownFields(other.getUnknownFields()); return this; } @@ -1075,10 +1060,6 @@ public final class RemoteProtocol { addMetadata(subBuilder.buildPartial()); break; } - case 74: { - setCookie(input.readString()); - break; - } } } } @@ -1375,27 +1356,6 @@ public final class RemoteProtocol { return this; } - // optional string cookie = 9; - public boolean hasCookie() { - return result.hasCookie(); - } - public java.lang.String getCookie() { - return result.getCookie(); - } - public Builder setCookie(java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - result.hasCookie = true; - result.cookie_ = value; - return this; - } - public Builder clearCookie() { - result.hasCookie = false; - result.cookie_ = getDefaultInstance().getCookie(); - return this; - } - // @@protoc_insertion_point(builder_scope:RemoteMessageProtocol) } @@ -1450,7 +1410,7 @@ public final class RemoteProtocol { public akka.remote.protocol.RemoteProtocol.CommandType getCommandType() { return commandType_; } private void initFields() { - commandType_ = akka.remote.protocol.RemoteProtocol.CommandType.SHUTDOWN; + commandType_ = akka.remote.protocol.RemoteProtocol.CommandType.CONNECT; } public final boolean isInitialized() { if (!hasCommandType) return false; @@ -1729,7 +1689,7 @@ public final class RemoteProtocol { } public Builder clearCommandType() { result.hasCommandType = false; - result.commandType_ = akka.remote.protocol.RemoteProtocol.CommandType.SHUTDOWN; + result.commandType_ = akka.remote.protocol.RemoteProtocol.CommandType.CONNECT; return this; } @@ -5710,46 +5670,46 @@ public final class RemoteProtocol { "\n\024RemoteProtocol.proto\"j\n\022AkkaRemoteProt" + "ocol\022\'\n\007message\030\001 \001(\0132\026.RemoteMessagePro" + "tocol\022+\n\013instruction\030\002 \001(\0132\026.RemoteContr" + - "olProtocol\"\277\002\n\025RemoteMessageProtocol\022\033\n\004" + + "olProtocol\"\257\002\n\025RemoteMessageProtocol\022\033\n\004" + "uuid\030\001 \002(\0132\r.UuidProtocol\022%\n\tactorInfo\030\002" + " \002(\0132\022.ActorInfoProtocol\022\016\n\006oneWay\030\003 \002(\010" + "\022!\n\007message\030\004 \001(\0132\020.MessageProtocol\022%\n\te" + "xception\030\005 \001(\0132\022.ExceptionProtocol\022%\n\016su" + "pervisorUuid\030\006 \001(\0132\r.UuidProtocol\022\'\n\006sen" + "der\030\007 \001(\0132\027.RemoteActorRefProtocol\022(\n\010me", - "tadata\030\010 \003(\0132\026.MetadataEntryProtocol\022\016\n\006" + - "cookie\030\t \001(\t\"J\n\025RemoteControlProtocol\022\016\n" + - "\006cookie\030\001 \001(\t\022!\n\013commandType\030\002 \002(\0162\014.Com" + - "mandType\":\n\026RemoteActorRefProtocol\022\017\n\007ad" + - "dress\030\001 \002(\t\022\017\n\007timeout\030\002 \001(\004\"\323\002\n\032Seriali" + - "zedActorRefProtocol\022\033\n\004uuid\030\001 \002(\0132\r.Uuid" + - "Protocol\022\017\n\007address\030\002 \002(\t\022\026\n\016actorClassn" + - "ame\030\003 \002(\t\022\025\n\ractorInstance\030\004 \001(\014\022\033\n\023seri" + - "alizerClassname\030\005 \001(\t\022\017\n\007timeout\030\006 \001(\004\022\026" + - "\n\016receiveTimeout\030\007 \001(\004\022%\n\tlifeCycle\030\010 \001(", - "\0132\022.LifeCycleProtocol\022+\n\nsupervisor\030\t \001(" + - "\0132\027.RemoteActorRefProtocol\022\024\n\014hotswapSta" + - "ck\030\n \001(\014\022(\n\010messages\030\013 \003(\0132\026.RemoteMessa" + - "geProtocol\"g\n\037SerializedTypedActorRefPro" + - "tocol\022-\n\010actorRef\030\001 \002(\0132\033.SerializedActo" + - "rRefProtocol\022\025\n\rinterfaceName\030\002 \002(\t\"r\n\017M" + - "essageProtocol\0225\n\023serializationScheme\030\001 " + - "\002(\0162\030.SerializationSchemeType\022\017\n\007message" + - "\030\002 \002(\014\022\027\n\017messageManifest\030\003 \001(\014\"R\n\021Actor" + - "InfoProtocol\022\033\n\004uuid\030\001 \002(\0132\r.UuidProtoco", - "l\022\017\n\007timeout\030\002 \002(\004\022\017\n\007address\030\003 \001(\t\")\n\014U" + - "uidProtocol\022\014\n\004high\030\001 \002(\004\022\013\n\003low\030\002 \002(\004\"3" + - "\n\025MetadataEntryProtocol\022\013\n\003key\030\001 \002(\t\022\r\n\005" + - "value\030\002 \002(\014\"6\n\021LifeCycleProtocol\022!\n\tlife" + - "Cycle\030\001 \002(\0162\016.LifeCycleType\"1\n\017AddressPr" + - "otocol\022\020\n\010hostname\030\001 \002(\t\022\014\n\004port\030\002 \002(\r\"7" + - "\n\021ExceptionProtocol\022\021\n\tclassname\030\001 \002(\t\022\017" + - "\n\007message\030\002 \002(\t*\033\n\013CommandType\022\014\n\010SHUTDO" + - "WN\020\001*]\n\027SerializationSchemeType\022\010\n\004JAVA\020" + - "\001\022\013\n\007SBINARY\020\002\022\016\n\nSCALA_JSON\020\003\022\r\n\tJAVA_J", - "SON\020\004\022\014\n\010PROTOBUF\020\005*-\n\rLifeCycleType\022\r\n\t" + - "PERMANENT\020\001\022\r\n\tTEMPORARY\020\002B\030\n\024akka.remot" + - "e.protocolH\001" + "tadata\030\010 \003(\0132\026.MetadataEntryProtocol\"J\n\025" + + "RemoteControlProtocol\022\016\n\006cookie\030\001 \001(\t\022!\n" + + "\013commandType\030\002 \002(\0162\014.CommandType\":\n\026Remo" + + "teActorRefProtocol\022\017\n\007address\030\001 \002(\t\022\017\n\007t" + + "imeout\030\002 \001(\004\"\323\002\n\032SerializedActorRefProto" + + "col\022\033\n\004uuid\030\001 \002(\0132\r.UuidProtocol\022\017\n\007addr" + + "ess\030\002 \002(\t\022\026\n\016actorClassname\030\003 \002(\t\022\025\n\ract" + + "orInstance\030\004 \001(\014\022\033\n\023serializerClassname\030" + + "\005 \001(\t\022\017\n\007timeout\030\006 \001(\004\022\026\n\016receiveTimeout" + + "\030\007 \001(\004\022%\n\tlifeCycle\030\010 \001(\0132\022.LifeCyclePro", + "tocol\022+\n\nsupervisor\030\t \001(\0132\027.RemoteActorR" + + "efProtocol\022\024\n\014hotswapStack\030\n \001(\014\022(\n\010mess" + + "ages\030\013 \003(\0132\026.RemoteMessageProtocol\"g\n\037Se" + + "rializedTypedActorRefProtocol\022-\n\010actorRe" + + "f\030\001 \002(\0132\033.SerializedActorRefProtocol\022\025\n\r" + + "interfaceName\030\002 \002(\t\"r\n\017MessageProtocol\0225" + + "\n\023serializationScheme\030\001 \002(\0162\030.Serializat" + + "ionSchemeType\022\017\n\007message\030\002 \002(\014\022\027\n\017messag" + + "eManifest\030\003 \001(\014\"R\n\021ActorInfoProtocol\022\033\n\004" + + "uuid\030\001 \002(\0132\r.UuidProtocol\022\017\n\007timeout\030\002 \002", + "(\004\022\017\n\007address\030\003 \001(\t\")\n\014UuidProtocol\022\014\n\004h" + + "igh\030\001 \002(\004\022\013\n\003low\030\002 \002(\004\"3\n\025MetadataEntryP" + + "rotocol\022\013\n\003key\030\001 \002(\t\022\r\n\005value\030\002 \002(\014\"6\n\021L" + + "ifeCycleProtocol\022!\n\tlifeCycle\030\001 \002(\0162\016.Li" + + "feCycleType\"1\n\017AddressProtocol\022\020\n\010hostna" + + "me\030\001 \002(\t\022\014\n\004port\030\002 \002(\r\"7\n\021ExceptionProto" + + "col\022\021\n\tclassname\030\001 \002(\t\022\017\n\007message\030\002 \002(\t*" + + "(\n\013CommandType\022\013\n\007CONNECT\020\001\022\014\n\010SHUTDOWN\020" + + "\002*]\n\027SerializationSchemeType\022\010\n\004JAVA\020\001\022\013" + + "\n\007SBINARY\020\002\022\016\n\nSCALA_JSON\020\003\022\r\n\tJAVA_JSON", + "\020\004\022\014\n\010PROTOBUF\020\005*-\n\rLifeCycleType\022\r\n\tPER" + + "MANENT\020\001\022\r\n\tTEMPORARY\020\002B\030\n\024akka.remote.p" + + "rotocolH\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { @@ -5769,7 +5729,7 @@ public final class RemoteProtocol { internal_static_RemoteMessageProtocol_fieldAccessorTable = new com.google.protobuf.GeneratedMessage.FieldAccessorTable( internal_static_RemoteMessageProtocol_descriptor, - new java.lang.String[] { "Uuid", "ActorInfo", "OneWay", "Message", "Exception", "SupervisorUuid", "Sender", "Metadata", "Cookie", }, + new java.lang.String[] { "Uuid", "ActorInfo", "OneWay", "Message", "Exception", "SupervisorUuid", "Sender", "Metadata", }, akka.remote.protocol.RemoteProtocol.RemoteMessageProtocol.class, akka.remote.protocol.RemoteProtocol.RemoteMessageProtocol.Builder.class); internal_static_RemoteControlProtocol_descriptor = diff --git a/akka-remote/src/main/protocol/RemoteProtocol.proto b/akka-remote/src/main/protocol/RemoteProtocol.proto index 37dce1914e..3b51a6e6d8 100644 --- a/akka-remote/src/main/protocol/RemoteProtocol.proto +++ b/akka-remote/src/main/protocol/RemoteProtocol.proto @@ -28,7 +28,6 @@ message RemoteMessageProtocol { optional UuidProtocol supervisorUuid = 6; optional RemoteActorRefProtocol sender = 7; repeated MetadataEntryProtocol metadata = 8; - optional string cookie = 9; } /** @@ -43,7 +42,8 @@ message RemoteControlProtocol { * Defines the type of the RemoteControlProtocol command type */ enum CommandType { - SHUTDOWN = 1; + CONNECT = 1; + SHUTDOWN = 2; } /** diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index 01a80fbc9b..de04a44a0f 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -161,7 +161,6 @@ abstract class RemoteClient private[akka] ( } private[remote] val runSwitch = new Switch() - private[remote] val isAuthenticated = new AtomicBoolean(false) private[remote] def isRunning = runSwitch.isOn @@ -196,18 +195,10 @@ abstract class RemoteClient private[akka] ( remoteAddress: InetSocketAddress, timeout: Long, isOneWay: Boolean, - actorRef: ActorRef): Option[CompletableFuture[T]] = synchronized { // FIXME: find better strategy to prevent race - + actorRef: ActorRef): Option[CompletableFuture[T]] = send(createRemoteMessageProtocolBuilder( - Some(actorRef), - Left(actorRef.uuid), - actorRef.address, - timeout, - Right(message), - isOneWay, - senderOption, - if (isAuthenticated.compareAndSet(false, true)) RemoteClientSettings.SECURE_COOKIE else None).build, senderFuture) - } + Some(actorRef), Left(actorRef.uuid), actorRef.address, timeout, Right(message), isOneWay, senderOption).build, + senderFuture) /** * Sends the message across the wire @@ -342,6 +333,14 @@ class ActiveRemoteClient private[akka] ( notifyListeners(RemoteClientError(connection.getCause, module, remoteAddress)) false } else { + + //Send cookie + val handshake = RemoteControlProtocol.newBuilder.setCommandType(CommandType.CONNECT) + if (SECURE_COOKIE.nonEmpty) + handshake.setCookie(SECURE_COOKIE.get) + + connection.getChannel.write(RemoteEncoder.encode(handshake.build)) + //Add a task that does GCing of expired Futures timer.newTimeout(new TimerTask() { def run(timeout: Timeout) = { @@ -361,7 +360,6 @@ class ActiveRemoteClient private[akka] ( } match { case true ⇒ true case false if reconnectIfAlreadyConnected ⇒ - isAuthenticated.set(false) openChannels.remove(connection.getChannel) connection.getChannel.close connection = bootstrap.connect(remoteAddress) @@ -369,7 +367,15 @@ class ActiveRemoteClient private[akka] ( if (!connection.isSuccess) { notifyListeners(RemoteClientError(connection.getCause, module, remoteAddress)) false - } else true + } else { + //Send cookie + val handshake = RemoteControlProtocol.newBuilder.setCommandType(CommandType.CONNECT) + if (SECURE_COOKIE.nonEmpty) + handshake.setCookie(SECURE_COOKIE.get) + + connection.getChannel.write(RemoteEncoder.encode(handshake.build)) + true + } case false ⇒ false } } @@ -577,10 +583,10 @@ class NettyRemoteServer(serverModule: NettyRemoteServerModule, val host: String, def shutdown() { try { val shutdownSignal = { - val b = RemoteControlProtocol.newBuilder + val b = RemoteControlProtocol.newBuilder.setCommandType(CommandType.SHUTDOWN) if (RemoteClientSettings.SECURE_COOKIE.nonEmpty) b.setCookie(RemoteClientSettings.SECURE_COOKIE.get) - b.setCommandType(CommandType.SHUTDOWN) + b.build } openChannels.write(RemoteEncoder.encode(shutdownSignal)).awaitUninterruptibly @@ -736,12 +742,39 @@ class RemoteServerPipelineFactory( MAX_TOTAL_MEMORY_SIZE, EXECUTION_POOL_KEEPALIVE.length, EXECUTION_POOL_KEEPALIVE.unit)) + val authenticator = if (REQUIRE_COOKIE) new RemoteServerAuthenticationHandler(SECURE_COOKIE) :: Nil else Nil val remoteServer = new RemoteServerHandler(name, openChannels, loader, server) - val stages: List[ChannelHandler] = dec ::: lenDec :: protobufDec :: enc ::: lenPrep :: protobufEnc :: execution :: remoteServer :: Nil + val stages: List[ChannelHandler] = dec ::: lenDec :: protobufDec :: enc ::: lenPrep :: protobufEnc :: execution :: authenticator ::: remoteServer :: Nil new StaticChannelPipeline(stages: _*) } } +@ChannelHandler.Sharable +class RemoteServerAuthenticationHandler(secureCookie: Option[String]) extends SimpleChannelUpstreamHandler { + val authenticated = new AnyRef + + override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = secureCookie match { + case None ⇒ ctx.sendUpstream(event) + case Some(cookie) ⇒ + ctx.getAttachment match { + case `authenticated` ⇒ ctx.sendUpstream(event) + case null ⇒ event.getMessage match { + case remoteProtocol: AkkaRemoteProtocol if remoteProtocol.hasInstruction ⇒ + remoteProtocol.getInstruction.getCookie match { + case `cookie` ⇒ + ctx.setAttachment(authenticated) + ctx.sendUpstream(event) + case _ ⇒ + throw new SecurityException( + "The remote client [" + ctx.getChannel.getRemoteAddress + "] secure cookie is not the same as remote server secure cookie") + } + case _ ⇒ + throw new SecurityException("The remote client [" + ctx.getChannel.getRemoteAddress + "] is not Authorized!") + } + } + } +} + /** * @author Jonas Bonér */ @@ -752,7 +785,6 @@ class RemoteServerHandler( val applicationLoader: Option[ClassLoader], val server: NettyRemoteServerModule) extends SimpleChannelUpstreamHandler { import RemoteServerSettings._ - val CHANNEL_INIT = "channel-init".intern applicationLoader.foreach(MessageSerializer.setClassLoader(_)) //TODO: REVISIT: THIS FEELS A BIT DODGY @@ -786,7 +818,6 @@ class RemoteServerHandler( val clientAddress = getClientAddress(ctx) sessionActors.set(event.getChannel(), new ConcurrentHashMap[String, ActorRef]()) server.notifyListeners(RemoteServerClientConnected(server, clientAddress)) - if (REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication } override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { @@ -810,11 +841,8 @@ class RemoteServerHandler( override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = event.getMessage match { case null ⇒ throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event) - //case remoteProtocol: AkkaRemoteProtocol if remoteProtocol.hasInstruction => RemoteServer cannot receive control messages (yet) - case remoteProtocol: AkkaRemoteProtocol if remoteProtocol.hasMessage ⇒ - val requestProtocol = remoteProtocol.getMessage - if (REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx) - handleRemoteMessageProtocol(requestProtocol, event.getChannel) + case remote: AkkaRemoteProtocol if remote.hasMessage ⇒ handleRemoteMessageProtocol(remote.getMessage, event.getChannel) + //case remote: AkkaRemoteProtocol if remote.hasInstruction => RemoteServer cannot receive control messages (yet) case _ ⇒ //ignore } @@ -874,8 +902,7 @@ class RemoteServerHandler( actorInfo.getTimeout, r, true, - Some(actorRef), - None) + Some(actorRef)) // FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) @@ -939,26 +966,11 @@ class RemoteServerHandler( actorInfo.getTimeout, Left(exception), true, - None, None) if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) RemoteEncoder.encode(messageBuilder.build) } - private def authenticateRemoteClient(request: RemoteMessageProtocol, ctx: ChannelHandlerContext) = { - val attachment = ctx.getAttachment - if ((attachment ne null) && - attachment.isInstanceOf[String] && - attachment.asInstanceOf[String] == CHANNEL_INIT) { // is first time around, channel initialization - ctx.setAttachment(null) - val clientAddress = ctx.getChannel.getRemoteAddress.toString - if (!request.hasCookie) throw new SecurityException( - "The remote client [" + clientAddress + "] does not have a secure cookie.") - if (!(request.getCookie == SECURE_COOKIE.get)) throw new SecurityException( - "The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie") - } - } - protected def parseUuid(protocol: UuidProtocol): Uuid = uuidFrom(protocol.getHigh, protocol.getLow) } diff --git a/akka-remote/src/main/scala/akka/serialization/SerializationProtocol.scala b/akka-remote/src/main/scala/akka/serialization/SerializationProtocol.scala index 4a19cdd436..a0705d7a52 100644 --- a/akka-remote/src/main/scala/akka/serialization/SerializationProtocol.scala +++ b/akka-remote/src/main/scala/akka/serialization/SerializationProtocol.scala @@ -78,10 +78,9 @@ object ActorSerialization { actorRef.timeout, Right(m.message), false, - actorRef.getSender, - RemoteClientSettings.SECURE_COOKIE).build) + actorRef.getSender)) - requestProtocols.foreach(rp ⇒ builder.addMessages(rp)) + requestProtocols.foreach(builder.addMessages(_)) } actorRef.receiveTimeout.foreach(builder.setReceiveTimeout(_)) @@ -201,8 +200,7 @@ object RemoteActorSerialization { timeout: Long, message: Either[Throwable, Any], isOneWay: Boolean, - senderOption: Option[ActorRef], - secureCookie: Option[String]): RemoteMessageProtocol.Builder = { + senderOption: Option[ActorRef]): RemoteMessageProtocol.Builder = { val uuidProtocol = replyUuid match { case Left(uid) ⇒ UuidProtocol.newBuilder.setHigh(uid.getTime).setLow(uid.getClockSeqAndNode).build @@ -238,8 +236,6 @@ object RemoteActorSerialization { case s ⇒ s } - secureCookie.foreach(messageBuilder.setCookie(_)) - /* TODO invent new supervision strategy actorRef.foreach { ref => ref.registerSupervisorAsRemoteActor.foreach { id =>