Added blocking handshake for outbound SSL connections #2833

- Eliminated "Promise passing style" and using future composition wherever possible
This commit is contained in:
Endre Sándor Varga 2013-01-21 12:40:42 +01:00
parent 50fb495e56
commit 580a2484a9
3 changed files with 51 additions and 51 deletions

View file

@ -23,8 +23,8 @@ import org.jboss.netty.channel.socket.nio.{ NioDatagramChannelFactory, NioServer
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender } import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.ssl.SslHandler import org.jboss.netty.handler.ssl.SslHandler
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS } import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
import scala.concurrent.{ ExecutionContext, Promise, Future } import scala.concurrent.{ ExecutionContext, Promise, Future, blocking }
import scala.util.Try import scala.util.{ Failure, Success, Try }
import util.control.{ NoStackTrace, NonFatal } import util.control.{ NoStackTrace, NonFatal }
object NettyTransportSettings { object NettyTransportSettings {
@ -167,12 +167,12 @@ abstract class ServerHandler(protected final val transport: NettyTransport,
} }
abstract class ClientHandler(protected final val transport: NettyTransport, abstract class ClientHandler(protected final val transport: NettyTransport)
private final val statusPromise: Promise[AssociationHandle])
extends NettyClientHelpers with CommonHandlers { extends NettyClientHelpers with CommonHandlers {
final protected val statusPromise = Promise[AssociationHandle]()
def statusFuture = statusPromise.future
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = { final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
channel.setReadable(false)
init(channel, remoteSocketAddress, msg)(statusPromise.success) init(channel, remoteSocketAddress, msg)(statusPromise.success)
} }
@ -256,7 +256,7 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
private def sslHandler(isClient: Boolean): SslHandler = { private def sslHandler(isClient: Boolean): SslHandler = {
val handler = NettySSLSupport(settings.SslSettings.get, log, isClient) val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
if (isClient) handler.setIssueHandshake(true) handler.setCloseOnSSLException(true)
handler handler
} }
@ -271,16 +271,17 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
} }
} }
private def clientPipelineFactory(statusPromise: Promise[AssociationHandle]): ChannelPipelineFactory = new ChannelPipelineFactory { private val clientPipelineFactory: ChannelPipelineFactory =
override def getPipeline: ChannelPipeline = { new ChannelPipelineFactory {
val pipeline = newPipeline override def getPipeline: ChannelPipeline = {
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true)) val pipeline = newPipeline
val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, statusPromise) if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
else new TcpClientHandler(NettyTransport.this, statusPromise) val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this)
pipeline.addLast("clienthandler", handler) else new TcpClientHandler(NettyTransport.this)
pipeline pipeline.addLast("clienthandler", handler)
pipeline
}
} }
}
private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = { private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
// FIXME: Expose these settings in configuration // FIXME: Expose these settings in configuration
@ -302,8 +303,8 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
case Udp setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory) case Udp setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
} }
private def outboundBootstrap(statusPromise: Promise[AssociationHandle]): ClientBootstrap = { private def outboundBootstrap: ClientBootstrap = {
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(statusPromise)) val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory)
bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis) bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
bootstrap bootstrap
} }
@ -346,34 +347,36 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
override def associate(remoteAddress: Address): Future[AssociationHandle] = { override def associate(remoteAddress: Address): Future[AssociationHandle] = {
if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound")) if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
else { else {
val statusPromise = Promise[AssociationHandle]() val bootstrap: ClientBootstrap = outboundBootstrap
(try {
val f = NettyFutureBridge(outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress))) recover {
case c: CancellationException throw new NettyTransportException("Connection was cancelled")
}
if (isDatagram) (for {
f map { channel readyChannel NettyFutureBridge(bootstrap.connect(addressToSocketAddress(remoteAddress))) map {
channel.getRemoteAddress match { channel
if (EnableSsl)
blocking {
channel.getPipeline.get[SslHandler](classOf[SslHandler]).handshake().awaitUninterruptibly()
}
if (!isDatagram) channel.setReadable(false)
channel
}
handle if (isDatagram)
Future {
readyChannel.getRemoteAddress match {
case addr: InetSocketAddress case addr: InetSocketAddress
val handle = new UdpAssociationHandle(localAddress, remoteAddress, channel, NettyTransport.this) val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this)
statusPromise.success(handle) handle.readHandlerPromise.future.onSuccess {
handle.readHandlerPromise.future.onSuccess { case listener udpConnectionTable.put(addr, listener) } case listener udpConnectionTable.put(addr, listener)
}
handle
case unknown throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}") case unknown throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}")
} }
} }
else f else
} catch { readyChannel.getPipeline.get[ClientHandler](classOf[ClientHandler]).statusFuture
case e @ (_: UnknownHostException | _: SecurityException | _: IllegalArgumentException) } yield handle) recover {
Future.failed(InvalidAssociationException("Invalid association ", e)) case c: CancellationException throw new NettyTransportException("Connection was cancelled") with NoStackTrace
case NonFatal(e) case NonFatal(t) throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
Future.failed(e)
}) onFailure {
case t: ConnectException statusPromise failure new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
case t statusPromise failure t
} }
statusPromise.future
} }
} }

View file

@ -12,6 +12,7 @@ import java.net.InetSocketAddress
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer } import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
import org.jboss.netty.channel._ import org.jboss.netty.channel._
import scala.concurrent.{ Future, Promise } import scala.concurrent.{ Future, Promise }
import scala.util.{ Success, Failure }
private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] { private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] {
override def initialValue(channel: Channel): Option[HandleEventListener] = None override def initialValue(channel: Channel): Option[HandleEventListener] = None
@ -30,16 +31,15 @@ private[remote] trait TcpHandlers extends CommonHandlers {
override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle = override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle =
new TcpAssociationHandle(localAddress, remoteAddress, channel) new TcpAssociationHandle(localAddress, remoteAddress, channel)
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
notifyListener(e.getChannel, Disassociated) notifyListener(e.getChannel, Disassociated)
}
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) { override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array() val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array()
if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes))) if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes)))
} }
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent) { override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
notifyListener(e.getChannel, Disassociated) notifyListener(e.getChannel, Disassociated)
e.getChannel.close() // No graceful close here e.getChannel.close() // No graceful close here
} }
@ -48,18 +48,16 @@ private[remote] trait TcpHandlers extends CommonHandlers {
private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener]) private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener])
extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers { extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initInbound(e.getChannel, e.getChannel.getRemoteAddress, null) initInbound(e.getChannel, e.getChannel.getRemoteAddress, null)
}
} }
private[remote] class TcpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle]) private[remote] class TcpClientHandler(_transport: NettyTransport)
extends ClientHandler(_transport, _statusPromise) with TcpHandlers { extends ClientHandler(_transport) with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) { override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null) initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null)
}
} }

View file

@ -53,8 +53,7 @@ private[remote] class UdpServerHandler(_transport: NettyTransport, _associationL
initInbound(channel, remoteSocketAddress, msg) initInbound(channel, remoteSocketAddress, msg)
} }
private[remote] class UdpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle]) private[remote] class UdpClientHandler(_transport: NettyTransport) extends ClientHandler(_transport) with UdpHandlers {
extends ClientHandler(_transport, _statusPromise) with UdpHandlers {
override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
initOutbound(channel, remoteSocketAddress, msg) initOutbound(channel, remoteSocketAddress, msg)