From 580a2484a96115e8bcef36dd45e9f37fcba41d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Endre=20S=C3=A1ndor=20Varga?= Date: Mon, 21 Jan 2013 12:40:42 +0100 Subject: [PATCH] Added blocking handshake for outbound SSL connections #2833 - Eliminated "Promise passing style" and using future composition wherever possible --- .../transport/netty/NettyTransport.scala | 81 ++++++++++--------- .../remote/transport/netty/TcpSupport.scala | 18 ++--- .../remote/transport/netty/UdpSupport.scala | 3 +- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala index 3a09a1275e..92e3b48fcd 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/NettyTransport.scala @@ -23,8 +23,8 @@ import org.jboss.netty.channel.socket.nio.{ NioDatagramChannelFactory, NioServer import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender } import org.jboss.netty.handler.ssl.SslHandler import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS } -import scala.concurrent.{ ExecutionContext, Promise, Future } -import scala.util.Try +import scala.concurrent.{ ExecutionContext, Promise, Future, blocking } +import scala.util.{ Failure, Success, Try } import util.control.{ NoStackTrace, NonFatal } object NettyTransportSettings { @@ -167,12 +167,12 @@ abstract class ServerHandler(protected final val transport: NettyTransport, } -abstract class ClientHandler(protected final val transport: NettyTransport, - private final val statusPromise: Promise[AssociationHandle]) +abstract class ClientHandler(protected final val transport: NettyTransport) extends NettyClientHelpers with CommonHandlers { + final protected val statusPromise = Promise[AssociationHandle]() + def statusFuture = statusPromise.future final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = { - channel.setReadable(false) init(channel, remoteSocketAddress, msg)(statusPromise.success) } @@ -256,7 +256,7 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s private def sslHandler(isClient: Boolean): SslHandler = { val handler = NettySSLSupport(settings.SslSettings.get, log, isClient) - if (isClient) handler.setIssueHandshake(true) + handler.setCloseOnSSLException(true) handler } @@ -271,16 +271,17 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s } } - private def clientPipelineFactory(statusPromise: Promise[AssociationHandle]): ChannelPipelineFactory = new ChannelPipelineFactory { - override def getPipeline: ChannelPipeline = { - val pipeline = newPipeline - if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true)) - val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, statusPromise) - else new TcpClientHandler(NettyTransport.this, statusPromise) - pipeline.addLast("clienthandler", handler) - pipeline + private val clientPipelineFactory: ChannelPipelineFactory = + new ChannelPipelineFactory { + override def getPipeline: ChannelPipeline = { + val pipeline = newPipeline + if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true)) + val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this) + else new TcpClientHandler(NettyTransport.this) + pipeline.addLast("clienthandler", handler) + pipeline + } } - } private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = { // FIXME: Expose these settings in configuration @@ -302,8 +303,8 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory) } - private def outboundBootstrap(statusPromise: Promise[AssociationHandle]): ClientBootstrap = { - val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(statusPromise)) + private def outboundBootstrap: ClientBootstrap = { + val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory) bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis) bootstrap } @@ -346,34 +347,36 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s override def associate(remoteAddress: Address): Future[AssociationHandle] = { if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound")) else { - val statusPromise = Promise[AssociationHandle]() - (try { - val f = NettyFutureBridge(outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress))) recover { - case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled") - } + val bootstrap: ClientBootstrap = outboundBootstrap - if (isDatagram) - f map { channel ⇒ - channel.getRemoteAddress match { + (for { + readyChannel ← NettyFutureBridge(bootstrap.connect(addressToSocketAddress(remoteAddress))) map { + channel ⇒ + if (EnableSsl) + blocking { + channel.getPipeline.get[SslHandler](classOf[SslHandler]).handshake().awaitUninterruptibly() + } + if (!isDatagram) channel.setReadable(false) + channel + } + handle ← if (isDatagram) + Future { + readyChannel.getRemoteAddress match { case addr: InetSocketAddress ⇒ - val handle = new UdpAssociationHandle(localAddress, remoteAddress, channel, NettyTransport.this) - statusPromise.success(handle) - handle.readHandlerPromise.future.onSuccess { case listener ⇒ udpConnectionTable.put(addr, listener) } + val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this) + handle.readHandlerPromise.future.onSuccess { + case listener ⇒ udpConnectionTable.put(addr, listener) + } + handle case unknown ⇒ throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}") } } - else f - } catch { - case e @ (_: UnknownHostException | _: SecurityException | _: IllegalArgumentException) ⇒ - Future.failed(InvalidAssociationException("Invalid association ", e)) - case NonFatal(e) ⇒ - Future.failed(e) - }) onFailure { - case t: ConnectException ⇒ statusPromise failure new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace - case t ⇒ statusPromise failure t + else + readyChannel.getPipeline.get[ClientHandler](classOf[ClientHandler]).statusFuture + } yield handle) recover { + case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled") with NoStackTrace + case NonFatal(t) ⇒ throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace } - - statusPromise.future } } diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala index fc02383337..4cded02003 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/TcpSupport.scala @@ -12,6 +12,7 @@ import java.net.InetSocketAddress import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer } import org.jboss.netty.channel._ import scala.concurrent.{ Future, Promise } +import scala.util.{ Success, Failure } private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] { override def initialValue(channel: Channel): Option[HandleEventListener] = None @@ -30,16 +31,15 @@ private[remote] trait TcpHandlers extends CommonHandlers { override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle = new TcpAssociationHandle(localAddress, remoteAddress, channel) - override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = notifyListener(e.getChannel, Disassociated) - } - override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) { + override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = { val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array() if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes))) } - override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent) { + override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = { notifyListener(e.getChannel, Disassociated) e.getChannel.close() // No graceful close here } @@ -48,18 +48,16 @@ private[remote] trait TcpHandlers extends CommonHandlers { private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener]) extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers { - override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = initInbound(e.getChannel, e.getChannel.getRemoteAddress, null) - } } -private[remote] class TcpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle]) - extends ClientHandler(_transport, _statusPromise) with TcpHandlers { +private[remote] class TcpClientHandler(_transport: NettyTransport) + extends ClientHandler(_transport) with TcpHandlers { - override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { + override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null) - } } diff --git a/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala b/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala index 0b20b25fb7..b00724e430 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/netty/UdpSupport.scala @@ -53,8 +53,7 @@ private[remote] class UdpServerHandler(_transport: NettyTransport, _associationL initInbound(channel, remoteSocketAddress, msg) } -private[remote] class UdpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle]) - extends ClientHandler(_transport, _statusPromise) with UdpHandlers { +private[remote] class UdpClientHandler(_transport: NettyTransport) extends ClientHandler(_transport) with UdpHandlers { override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = initOutbound(channel, remoteSocketAddress, msg)