diff --git a/akka-core/src/main/scala/remote/RemoteClient.scala b/akka-core/src/main/scala/remote/RemoteClient.scala index 4c18dcc6c8..60cbfbdd36 100644 --- a/akka-core/src/main/scala/remote/RemoteClient.scala +++ b/akka-core/src/main/scala/remote/RemoteClient.scala @@ -19,6 +19,7 @@ import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} import org.jboss.netty.handler.codec.protobuf.{ProtobufDecoder, ProtobufEncoder} import org.jboss.netty.handler.timeout.ReadTimeoutHandler import org.jboss.netty.util.{TimerTask, Timeout, HashedWheelTimer} +import org.jboss.netty.handler.ssl.SslHandler import java.net.{SocketAddress, InetSocketAddress} import java.util.concurrent.{TimeUnit, Executors, ConcurrentMap, ConcurrentHashMap, ConcurrentSkipListSet} @@ -249,6 +250,10 @@ class RemoteClientPipelineFactory(name: String, timer: HashedWheelTimer, client: RemoteClient) extends ChannelPipelineFactory { def getPipeline: ChannelPipeline = { + val engine = RemoteServerSslContext.client.createSSLEngine() + engine.setUseClientMode(true) + + val ssl = new SslHandler(engine) val timeout = new ReadTimeoutHandler(timer, RemoteClient.READ_TIMEOUT) val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4) val lenPrep = new LengthFieldPrepender(4) @@ -262,8 +267,8 @@ class RemoteClientPipelineFactory(name: String, val remoteClient = new RemoteClientHandler(name, futures, supervisors, bootstrap, remoteAddress, timer, client) val stages: Array[ChannelHandler] = - zipCodec.map(codec => Array(timeout, codec.decoder, lenDec, protobufDec, codec.encoder, lenPrep, protobufEnc, remoteClient)) - .getOrElse(Array(timeout, lenDec, protobufDec, lenPrep, protobufEnc, remoteClient)) + zipCodec.map(codec => Array(ssl, timeout, codec.decoder, lenDec, protobufDec, codec.encoder, lenPrep, protobufEnc, remoteClient)) + .getOrElse(Array(ssl, timeout, lenDec, protobufDec, lenPrep, protobufEnc, remoteClient)) new StaticChannelPipeline(stages: _*) } } @@ -342,9 +347,20 @@ class RemoteClientHandler(val name: String, } override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - client.listeners.toArray.foreach(l => - l.asInstanceOf[ActorRef] ! RemoteClientConnected(client.hostname, client.port)) - log.debug("Remote client connected to [%s]", ctx.getChannel.getRemoteAddress) + +// client.listeners.toArray.foreach(l => +// l.asInstanceOf[ActorRef] ! RemoteClientConnected(client.hostname, client.port)) +// log.debug("Remote client connected to [%s]", ctx.getChannel.getRemoteAddress) + + val sslHandler : SslHandler = ctx.getPipeline.get(classOf[SslHandler]) + sslHandler.handshake().addListener( new ChannelFutureListener { + def operationComplete(future : ChannelFuture) : Unit = { + if(future.isSuccess) { + client.listeners.toArray.foreach(l => l.asInstanceOf[ActorRef] ! RemoteClientConnected(client.hostname, client.port)) + log.debug("Remote client connected to [%s]", ctx.getChannel.getRemoteAddress) + } + } + }) } override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { diff --git a/akka-core/src/main/scala/remote/RemoteServer.scala b/akka-core/src/main/scala/remote/RemoteServer.scala index 28d087b3e1..08c402f15a 100644 --- a/akka-core/src/main/scala/remote/RemoteServer.scala +++ b/akka-core/src/main/scala/remote/RemoteServer.scala @@ -22,6 +22,8 @@ import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.frame.{LengthFieldBasedFrameDecoder, LengthFieldPrepender} import org.jboss.netty.handler.codec.protobuf.{ProtobufDecoder, ProtobufEncoder} import org.jboss.netty.handler.codec.compression.{ZlibEncoder, ZlibDecoder} +import org.jboss.netty.handler.ssl.SslHandler + import scala.collection.mutable.Map @@ -279,6 +281,32 @@ class RemoteServer extends Logging { case class Codec(encoder: ChannelHandler, decoder: ChannelHandler) +object RemoteServerSslContext { + import java.security.{KeyStore,Security} + import javax.net.ssl.{KeyManager,KeyManagerFactory,SSLContext,TrustManagerFactory} + + val (client,server) = { + val protocol = "TLS" + val algorithm = Option(Security.getProperty("ssl.KeyManagerFactory.algorithm")).getOrElse("SunX509") + val store = KeyStore.getInstance("JKS") + store.load(getClass.getResourceAsStream("keystore"),"keystorepassword".toCharArray) + + val keyMan = KeyManagerFactory.getInstance(algorithm) + keyMan.init(store, "certificatepassword".toCharArray) + + val trustMan = TrustManagerFactory.getInstance("SunX509"); + trustMan.init(store) //TODO safe to use same keystore? Or should use it's own keystore? + + val s = SSLContext.getInstance(protocol) + s.init(keyMan.getKeyManagers, null, null) + + val c = SSLContext.getInstance(protocol) + c.init(null, trustMan.getTrustManagers, null) + + (c,s) + } +} + /** * @author Jonas Bonér */ @@ -291,6 +319,10 @@ class RemoteServerPipelineFactory( import RemoteServer._ def getPipeline: ChannelPipeline = { + val engine = RemoteServerSslContext.server.createSSLEngine() + engine.setUseClientMode(false) + + val ssl = new SslHandler(engine) val lenDec = new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4) val lenPrep = new LengthFieldPrepender(4) val protobufDec = new ProtobufDecoder(RemoteRequestProtocol.getDefaultInstance) @@ -303,8 +335,8 @@ class RemoteServerPipelineFactory( val remoteServer = new RemoteServerHandler(name, openChannels, loader, actors, activeObjects) val stages: Array[ChannelHandler] = - zipCodec.map(codec => Array(codec.decoder, lenDec, protobufDec, codec.encoder, lenPrep, protobufEnc, remoteServer)) - .getOrElse(Array(lenDec, protobufDec, lenPrep, protobufEnc, remoteServer)) + zipCodec.map(codec => Array(ssl,codec.decoder, lenDec, protobufDec, codec.encoder, lenPrep, protobufEnc, remoteServer)) + .getOrElse(Array(ssl,lenDec, protobufDec, lenPrep, protobufEnc, remoteServer)) new StaticChannelPipeline(stages: _*) } } @@ -330,6 +362,14 @@ class RemoteServerHandler( override def channelOpen(ctx: ChannelHandlerContext, event: ChannelStateEvent) { openChannels.add(ctx.getChannel) } + + override def channelConnected(ctx : ChannelHandlerContext, e : ChannelStateEvent) { + val sslHandler : SslHandler = ctx.getPipeline.get(classOf[SslHandler]) + + // Begin handshake. + sslHandler.handshake() + } + override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = { if (event.isInstanceOf[ChannelStateEvent] &&