Completed Erlang-style cookie handshake between RemoteClient and RemoteServer

This commit is contained in:
Jonas Bonér 2010-10-26 15:23:50 +02:00
commit 52f5e5e861
9 changed files with 187 additions and 211 deletions

View file

@ -212,13 +212,14 @@ class RemoteServer extends Logging with ListenerManagement {
address = Address(_hostname,_port)
log.info("Starting remote server at [%s:%s]", hostname, port)
RemoteServer.register(hostname, port, this)
val pipelineFactory = new RemoteServerPipelineFactory(
name, openChannels, loader, this)
val pipelineFactory = new RemoteServerPipelineFactory(name, openChannels, loader, this)
bootstrap.setPipelineFactory(pipelineFactory)
bootstrap.setOption("child.tcpNoDelay", true)
bootstrap.setOption("child.keepAlive", true)
bootstrap.setOption("child.reuseAddress", true)
bootstrap.setOption("child.connectTimeoutMillis", RemoteServer.CONNECTION_TIMEOUT_MILLIS.toMillis)
openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port)))
_isRunning = true
Cluster.registerLocalNode(hostname, port)
@ -260,11 +261,8 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized {
log.debug("Registering server side remote typed actor [%s] with id [%s]", typedActor.getClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid())
} else {
registerTypedActor(id, typedActor, typedActors())
}
if (id.startsWith(UUID_PREFIX)) registerTypedActor(id.substring(UUID_PREFIX.length), typedActor, typedActorsByUuid)
else registerTypedActor(id, typedActor, typedActors)
}
/**
@ -279,28 +277,19 @@ class RemoteServer extends Logging with ListenerManagement {
*/
def register(id: String, actorRef: ActorRef): Unit = synchronized {
log.debug("Registering server side remote actor [%s] with id [%s]", actorRef.actorClass.getName, id)
if (id.startsWith(UUID_PREFIX)) {
register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid())
} else {
register(id, actorRef, actors())
}
if (id.startsWith(UUID_PREFIX)) register(id.substring(UUID_PREFIX.length), actorRef, actorsByUuid)
else register(id, actorRef, actors)
}
private def register[Key](id: Key, actorRef: ActorRef, registry: ConcurrentHashMap[Key, ActorRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
if (_isRunning && !registry.contains(id)) {
if (!actorRef.isRunning) actorRef.start
registry.put(id, actorRef)
}
}
private def registerTypedActor[Key](id: Key, typedActor: AnyRef, registry: ConcurrentHashMap[Key, AnyRef]) {
if (_isRunning) {
if (!registry.contains(id)) {
registry.put(id, typedActor)
}
}
if (_isRunning && !registry.contains(id)) registry.put(id, typedActor)
}
/**
@ -309,8 +298,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(actorRef: ActorRef):Unit = synchronized {
if (_isRunning) {
log.debug("Unregistering server side remote actor [%s] with id [%s:%s]", actorRef.actorClass.getName, actorRef.id, actorRef.uuid)
actors().remove(actorRef.id,actorRef)
actorsByUuid().remove(actorRef.uuid,actorRef)
actors.remove(actorRef.id, actorRef)
actorsByUuid.remove(actorRef.uuid, actorRef)
}
}
@ -322,12 +311,11 @@ class RemoteServer extends Logging with ListenerManagement {
def unregister(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
actorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
val actorRef = actors() get id
actorsByUuid().remove(actorRef.uuid,actorRef)
actors().remove(id,actorRef)
if (id.startsWith(UUID_PREFIX)) actorsByUuid.remove(id.substring(UUID_PREFIX.length))
else {
val actorRef = actors get id
actorsByUuid.remove(actorRef.uuid, actorRef)
actors.remove(id,actorRef)
}
}
}
@ -340,11 +328,8 @@ class RemoteServer extends Logging with ListenerManagement {
def unregisterTypedActor(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote typed actor with id [%s]", id)
if (id.startsWith(UUID_PREFIX)) {
typedActorsByUuid().remove(id.substring(UUID_PREFIX.length))
} else {
typedActors().remove(id)
}
if (id.startsWith(UUID_PREFIX)) typedActorsByUuid.remove(id.substring(UUID_PREFIX.length))
else typedActors.remove(id)
}
}
@ -352,10 +337,10 @@ class RemoteServer extends Logging with ListenerManagement {
protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
private[akka] def actors() = ActorRegistry.actors(address)
private[akka] def actorsByUuid() = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors() = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid() = ActorRegistry.typedActorsByUuid(address)
private[akka] def actors = ActorRegistry.actors(address)
private[akka] def actorsByUuid = ActorRegistry.actorsByUuid(address)
private[akka] def typedActors = ActorRegistry.typedActors(address)
private[akka] def typedActorsByUuid = ActorRegistry.typedActorsByUuid(address)
}
object RemoteServerSslContext {
@ -419,6 +404,7 @@ class RemoteServerHandler(
val applicationLoader: Option[ClassLoader],
val server: RemoteServer) extends SimpleChannelUpstreamHandler with Logging {
import RemoteServer._
val AW_PROXY_PREFIX = "$$ProxiedByAW".intern
val CHANNEL_INIT = "channel-init".intern
@ -444,10 +430,8 @@ class RemoteServerHandler(
} else future.getChannel.close
}
})
} else {
server.notifyListeners(RemoteServerClientConnected(server))
}
if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT)
} else server.notifyListeners(RemoteServerClientConnected(server))
if (RemoteServer.REQUIRE_COOKIE) ctx.setAttachment(CHANNEL_INIT) // signal that this is channel initialization, which will need authentication
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
@ -467,7 +451,7 @@ class RemoteServerHandler(
if (message eq null) throw new IllegalActorStateException("Message in remote MessageEvent is null: " + event)
if (message.isInstanceOf[RemoteRequestProtocol]) {
val requestProtocol = message.asInstanceOf[RemoteRequestProtocol]
authenticateRemoteClient(requestProtocol, ctx)
if (RemoteServer.REQUIRE_COOKIE) authenticateRemoteClient(requestProtocol, ctx)
handleRemoteRequestProtocol(requestProtocol, event.getChannel)
}
}
@ -521,8 +505,7 @@ class RemoteServerHandler(
try {
channel.write(replyBuilder.build)
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
@ -530,8 +513,7 @@ class RemoteServerHandler(
try {
channel.write(createErrorReplyMessage(exception, request, true))
} catch {
case e: Throwable =>
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
}
@ -543,8 +525,8 @@ class RemoteServerHandler(
val actorInfo = request.getActorInfo
val typedActorInfo = actorInfo.getTypedActorInfo
log.debug("Dispatching to remote typed actor [%s :: %s]", typedActorInfo.getMethod, typedActorInfo.getInterface)
val typedActor = createTypedActor(actorInfo)
val typedActor = createTypedActor(actorInfo)
val args = MessageSerializer.deserialize(request.getMessage).asInstanceOf[Array[AnyRef]].toList
val argClasses = args.map(_.getClass)
@ -566,49 +548,39 @@ class RemoteServerHandler(
case e: InvocationTargetException =>
channel.write(createErrorReplyMessage(e.getCause, request, false))
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable =>
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, false))
server.notifyListeners(RemoteServerError(e, server))
}
}
private def findActorById(id: String) : ActorRef = {
server.actors().get(id)
server.actors.get(id)
}
private def findActorByUuid(uuid: String) : ActorRef = {
server.actorsByUuid().get(uuid)
server.actorsByUuid.get(uuid)
}
private def findTypedActorById(id: String) : AnyRef = {
server.typedActors().get(id)
server.typedActors.get(id)
}
private def findTypedActorByUuid(uuid: String) : AnyRef = {
server.typedActorsByUuid().get(uuid)
server.typedActorsByUuid.get(uuid)
}
private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findActorByUuid(id.substring(UUID_PREFIX.length))
else findActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findActorByUuid(uuid)
actorRefOrNull
}
private def findTypedActorByIdOrUuid(id: String, uuid: String) : AnyRef = {
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) {
findTypedActorByUuid(id.substring(UUID_PREFIX.length))
} else {
findTypedActorById(id)
}
if (actorRefOrNull eq null) {
actorRefOrNull = findTypedActorByUuid(uuid)
}
var actorRefOrNull = if (id.startsWith(UUID_PREFIX)) findTypedActorByUuid(id.substring(UUID_PREFIX.length))
else findTypedActorById(id)
if (actorRefOrNull eq null) actorRefOrNull = findTypedActorByUuid(uuid)
actorRefOrNull
}
@ -694,18 +666,17 @@ class RemoteServerHandler(
}
private def authenticateRemoteClient(request: RemoteRequestProtocol, ctx: ChannelHandlerContext) = {
if (RemoteServer.REQUIRE_COOKIE) {
val attachment = ctx.getAttachment
if ((attachment ne null) &&
attachment.isInstanceOf[String] &&
attachment.asInstanceOf[String] == CHANNEL_INIT) {
val clientAddress = ctx.getChannel.getRemoteAddress.toString
if (!request.hasCookie) throw new SecurityException(
"The remote client [" + clientAddress + "] does not have a secure cookie.")
if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException(
"The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie")
log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress)
}
val attachment = ctx.getAttachment
if ((attachment ne null) &&
attachment.isInstanceOf[String] &&
attachment.asInstanceOf[String] == CHANNEL_INIT) { // is first time around, channel initialization
ctx.setAttachment(null)
val clientAddress = ctx.getChannel.getRemoteAddress.toString
if (!request.hasCookie) throw new SecurityException(
"The remote client [" + clientAddress + "] does not have a secure cookie.")
if (!(request.getCookie == RemoteServer.SECURE_COOKIE.get)) throw new SecurityException(
"The remote client [" + clientAddress + "] secure cookie is not the same as remote server secure cookie")
log.info("Remote client [%s] successfully authenticated using secure cookie", clientAddress)
}
}
}