Added blocking handshake for outbound SSL connections #2833

- Eliminated "Promise passing style" and using future composition wherever possible
This commit is contained in:
Endre Sándor Varga 2013-01-21 12:40:42 +01:00
parent 50fb495e56
commit 580a2484a9
3 changed files with 51 additions and 51 deletions

View file

@ -23,8 +23,8 @@ import org.jboss.netty.channel.socket.nio.{ NioDatagramChannelFactory, NioServer
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.ssl.SslHandler
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
import scala.concurrent.{ ExecutionContext, Promise, Future }
import scala.util.Try
import scala.concurrent.{ ExecutionContext, Promise, Future, blocking }
import scala.util.{ Failure, Success, Try }
import util.control.{ NoStackTrace, NonFatal }
object NettyTransportSettings {
@ -167,12 +167,12 @@ abstract class ServerHandler(protected final val transport: NettyTransport,
}
abstract class ClientHandler(protected final val transport: NettyTransport,
private final val statusPromise: Promise[AssociationHandle])
abstract class ClientHandler(protected final val transport: NettyTransport)
extends NettyClientHelpers with CommonHandlers {
final protected val statusPromise = Promise[AssociationHandle]()
def statusFuture = statusPromise.future
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
channel.setReadable(false)
init(channel, remoteSocketAddress, msg)(statusPromise.success)
}
@ -256,7 +256,7 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
private def sslHandler(isClient: Boolean): SslHandler = {
val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
if (isClient) handler.setIssueHandshake(true)
handler.setCloseOnSSLException(true)
handler
}
@ -271,12 +271,13 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
}
}
private def clientPipelineFactory(statusPromise: Promise[AssociationHandle]): ChannelPipelineFactory = new ChannelPipelineFactory {
private val clientPipelineFactory: ChannelPipelineFactory =
new ChannelPipelineFactory {
override def getPipeline: ChannelPipeline = {
val pipeline = newPipeline
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, statusPromise)
else new TcpClientHandler(NettyTransport.this, statusPromise)
val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this)
else new TcpClientHandler(NettyTransport.this)
pipeline.addLast("clienthandler", handler)
pipeline
}
@ -302,8 +303,8 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
case Udp setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
}
private def outboundBootstrap(statusPromise: Promise[AssociationHandle]): ClientBootstrap = {
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(statusPromise))
private def outboundBootstrap: ClientBootstrap = {
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory)
bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
bootstrap
}
@ -346,34 +347,36 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
override def associate(remoteAddress: Address): Future[AssociationHandle] = {
if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
else {
val statusPromise = Promise[AssociationHandle]()
(try {
val f = NettyFutureBridge(outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress))) recover {
case c: CancellationException throw new NettyTransportException("Connection was cancelled")
}
val bootstrap: ClientBootstrap = outboundBootstrap
if (isDatagram)
f map { channel
channel.getRemoteAddress match {
(for {
readyChannel NettyFutureBridge(bootstrap.connect(addressToSocketAddress(remoteAddress))) map {
channel
if (EnableSsl)
blocking {
channel.getPipeline.get[SslHandler](classOf[SslHandler]).handshake().awaitUninterruptibly()
}
if (!isDatagram) channel.setReadable(false)
channel
}
handle if (isDatagram)
Future {
readyChannel.getRemoteAddress match {
case addr: InetSocketAddress
val handle = new UdpAssociationHandle(localAddress, remoteAddress, channel, NettyTransport.this)
statusPromise.success(handle)
handle.readHandlerPromise.future.onSuccess { case listener udpConnectionTable.put(addr, listener) }
val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this)
handle.readHandlerPromise.future.onSuccess {
case listener udpConnectionTable.put(addr, listener)
}
handle
case unknown throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}")
}
}
else f
} catch {
case e @ (_: UnknownHostException | _: SecurityException | _: IllegalArgumentException)
Future.failed(InvalidAssociationException("Invalid association ", e))
case NonFatal(e)
Future.failed(e)
}) onFailure {
case t: ConnectException statusPromise failure new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
case t statusPromise failure t
else
readyChannel.getPipeline.get[ClientHandler](classOf[ClientHandler]).statusFuture
} yield handle) recover {
case c: CancellationException throw new NettyTransportException("Connection was cancelled") with NoStackTrace
case NonFatal(t) throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
}
statusPromise.future
}
}

View file

@ -12,6 +12,7 @@ import java.net.InetSocketAddress
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
import org.jboss.netty.channel._
import scala.concurrent.{ Future, Promise }
import scala.util.{ Success, Failure }
private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] {
override def initialValue(channel: Channel): Option[HandleEventListener] = None
@ -30,16 +31,15 @@ private[remote] trait TcpHandlers extends CommonHandlers {
override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle =
new TcpAssociationHandle(localAddress, remoteAddress, channel)
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
notifyListener(e.getChannel, Disassociated)
}
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) {
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array()
if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes)))
}
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent) {
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
notifyListener(e.getChannel, Disassociated)
e.getChannel.close() // No graceful close here
}
@ -48,18 +48,16 @@ private[remote] trait TcpHandlers extends CommonHandlers {
private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener])
extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initInbound(e.getChannel, e.getChannel.getRemoteAddress, null)
}
}
private[remote] class TcpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle])
extends ClientHandler(_transport, _statusPromise) with TcpHandlers {
private[remote] class TcpClientHandler(_transport: NettyTransport)
extends ClientHandler(_transport) with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null)
}
}

View file

@ -53,8 +53,7 @@ private[remote] class UdpServerHandler(_transport: NettyTransport, _associationL
initInbound(channel, remoteSocketAddress, msg)
}
private[remote] class UdpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle])
extends ClientHandler(_transport, _statusPromise) with UdpHandlers {
private[remote] class UdpClientHandler(_transport: NettyTransport) extends ClientHandler(_transport) with UdpHandlers {
override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
initOutbound(channel, remoteSocketAddress, msg)