Added blocking handshake for outbound SSL connections #2833
- Eliminated "Promise passing style" and using future composition wherever possible
This commit is contained in:
parent
50fb495e56
commit
580a2484a9
3 changed files with 51 additions and 51 deletions
|
|
@ -23,8 +23,8 @@ import org.jboss.netty.channel.socket.nio.{ NioDatagramChannelFactory, NioServer
|
||||||
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
|
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
|
||||||
import org.jboss.netty.handler.ssl.SslHandler
|
import org.jboss.netty.handler.ssl.SslHandler
|
||||||
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
|
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
|
||||||
import scala.concurrent.{ ExecutionContext, Promise, Future }
|
import scala.concurrent.{ ExecutionContext, Promise, Future, blocking }
|
||||||
import scala.util.Try
|
import scala.util.{ Failure, Success, Try }
|
||||||
import util.control.{ NoStackTrace, NonFatal }
|
import util.control.{ NoStackTrace, NonFatal }
|
||||||
|
|
||||||
object NettyTransportSettings {
|
object NettyTransportSettings {
|
||||||
|
|
@ -167,12 +167,12 @@ abstract class ServerHandler(protected final val transport: NettyTransport,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class ClientHandler(protected final val transport: NettyTransport,
|
abstract class ClientHandler(protected final val transport: NettyTransport)
|
||||||
private final val statusPromise: Promise[AssociationHandle])
|
|
||||||
extends NettyClientHelpers with CommonHandlers {
|
extends NettyClientHelpers with CommonHandlers {
|
||||||
|
final protected val statusPromise = Promise[AssociationHandle]()
|
||||||
|
def statusFuture = statusPromise.future
|
||||||
|
|
||||||
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
|
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
|
||||||
channel.setReadable(false)
|
|
||||||
init(channel, remoteSocketAddress, msg)(statusPromise.success)
|
init(channel, remoteSocketAddress, msg)(statusPromise.success)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -256,7 +256,7 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
|
||||||
|
|
||||||
private def sslHandler(isClient: Boolean): SslHandler = {
|
private def sslHandler(isClient: Boolean): SslHandler = {
|
||||||
val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
|
val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
|
||||||
if (isClient) handler.setIssueHandshake(true)
|
handler.setCloseOnSSLException(true)
|
||||||
handler
|
handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -271,16 +271,17 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def clientPipelineFactory(statusPromise: Promise[AssociationHandle]): ChannelPipelineFactory = new ChannelPipelineFactory {
|
private val clientPipelineFactory: ChannelPipelineFactory =
|
||||||
override def getPipeline: ChannelPipeline = {
|
new ChannelPipelineFactory {
|
||||||
val pipeline = newPipeline
|
override def getPipeline: ChannelPipeline = {
|
||||||
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
|
val pipeline = newPipeline
|
||||||
val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, statusPromise)
|
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
|
||||||
else new TcpClientHandler(NettyTransport.this, statusPromise)
|
val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this)
|
||||||
pipeline.addLast("clienthandler", handler)
|
else new TcpClientHandler(NettyTransport.this)
|
||||||
pipeline
|
pipeline.addLast("clienthandler", handler)
|
||||||
|
pipeline
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
|
private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
|
||||||
// FIXME: Expose these settings in configuration
|
// FIXME: Expose these settings in configuration
|
||||||
|
|
@ -302,8 +303,8 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
|
||||||
case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
|
case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def outboundBootstrap(statusPromise: Promise[AssociationHandle]): ClientBootstrap = {
|
private def outboundBootstrap: ClientBootstrap = {
|
||||||
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(statusPromise))
|
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory)
|
||||||
bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
|
bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
|
||||||
bootstrap
|
bootstrap
|
||||||
}
|
}
|
||||||
|
|
@ -346,34 +347,36 @@ class NettyTransport(private val settings: NettyTransportSettings, private val s
|
||||||
override def associate(remoteAddress: Address): Future[AssociationHandle] = {
|
override def associate(remoteAddress: Address): Future[AssociationHandle] = {
|
||||||
if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
|
if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
|
||||||
else {
|
else {
|
||||||
val statusPromise = Promise[AssociationHandle]()
|
val bootstrap: ClientBootstrap = outboundBootstrap
|
||||||
(try {
|
|
||||||
val f = NettyFutureBridge(outboundBootstrap(statusPromise).connect(addressToSocketAddress(remoteAddress))) recover {
|
|
||||||
case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled")
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isDatagram)
|
(for {
|
||||||
f map { channel ⇒
|
readyChannel ← NettyFutureBridge(bootstrap.connect(addressToSocketAddress(remoteAddress))) map {
|
||||||
channel.getRemoteAddress match {
|
channel ⇒
|
||||||
|
if (EnableSsl)
|
||||||
|
blocking {
|
||||||
|
channel.getPipeline.get[SslHandler](classOf[SslHandler]).handshake().awaitUninterruptibly()
|
||||||
|
}
|
||||||
|
if (!isDatagram) channel.setReadable(false)
|
||||||
|
channel
|
||||||
|
}
|
||||||
|
handle ← if (isDatagram)
|
||||||
|
Future {
|
||||||
|
readyChannel.getRemoteAddress match {
|
||||||
case addr: InetSocketAddress ⇒
|
case addr: InetSocketAddress ⇒
|
||||||
val handle = new UdpAssociationHandle(localAddress, remoteAddress, channel, NettyTransport.this)
|
val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this)
|
||||||
statusPromise.success(handle)
|
handle.readHandlerPromise.future.onSuccess {
|
||||||
handle.readHandlerPromise.future.onSuccess { case listener ⇒ udpConnectionTable.put(addr, listener) }
|
case listener ⇒ udpConnectionTable.put(addr, listener)
|
||||||
|
}
|
||||||
|
handle
|
||||||
case unknown ⇒ throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}")
|
case unknown ⇒ throw new NettyTransportException(s"Unknown remote address type ${unknown.getClass}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else f
|
else
|
||||||
} catch {
|
readyChannel.getPipeline.get[ClientHandler](classOf[ClientHandler]).statusFuture
|
||||||
case e @ (_: UnknownHostException | _: SecurityException | _: IllegalArgumentException) ⇒
|
} yield handle) recover {
|
||||||
Future.failed(InvalidAssociationException("Invalid association ", e))
|
case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled") with NoStackTrace
|
||||||
case NonFatal(e) ⇒
|
case NonFatal(t) ⇒ throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
|
||||||
Future.failed(e)
|
|
||||||
}) onFailure {
|
|
||||||
case t: ConnectException ⇒ statusPromise failure new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
|
|
||||||
case t ⇒ statusPromise failure t
|
|
||||||
}
|
}
|
||||||
|
|
||||||
statusPromise.future
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ import java.net.InetSocketAddress
|
||||||
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
|
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
|
||||||
import org.jboss.netty.channel._
|
import org.jboss.netty.channel._
|
||||||
import scala.concurrent.{ Future, Promise }
|
import scala.concurrent.{ Future, Promise }
|
||||||
|
import scala.util.{ Success, Failure }
|
||||||
|
|
||||||
private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] {
|
private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] {
|
||||||
override def initialValue(channel: Channel): Option[HandleEventListener] = None
|
override def initialValue(channel: Channel): Option[HandleEventListener] = None
|
||||||
|
|
@ -30,16 +31,15 @@ private[remote] trait TcpHandlers extends CommonHandlers {
|
||||||
override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle =
|
override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle =
|
||||||
new TcpAssociationHandle(localAddress, remoteAddress, channel)
|
new TcpAssociationHandle(localAddress, remoteAddress, channel)
|
||||||
|
|
||||||
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
|
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
|
||||||
notifyListener(e.getChannel, Disassociated)
|
notifyListener(e.getChannel, Disassociated)
|
||||||
}
|
|
||||||
|
|
||||||
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent) {
|
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
|
||||||
val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array()
|
val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array()
|
||||||
if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes)))
|
if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent) {
|
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
|
||||||
notifyListener(e.getChannel, Disassociated)
|
notifyListener(e.getChannel, Disassociated)
|
||||||
e.getChannel.close() // No graceful close here
|
e.getChannel.close() // No graceful close here
|
||||||
}
|
}
|
||||||
|
|
@ -48,18 +48,16 @@ private[remote] trait TcpHandlers extends CommonHandlers {
|
||||||
private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener])
|
private[remote] class TcpServerHandler(_transport: NettyTransport, _associationListenerFuture: Future[AssociationEventListener])
|
||||||
extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers {
|
extends ServerHandler(_transport, _associationListenerFuture) with TcpHandlers {
|
||||||
|
|
||||||
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
|
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
|
||||||
initInbound(e.getChannel, e.getChannel.getRemoteAddress, null)
|
initInbound(e.getChannel, e.getChannel.getRemoteAddress, null)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private[remote] class TcpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle])
|
private[remote] class TcpClientHandler(_transport: NettyTransport)
|
||||||
extends ClientHandler(_transport, _statusPromise) with TcpHandlers {
|
extends ClientHandler(_transport) with TcpHandlers {
|
||||||
|
|
||||||
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
|
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
|
||||||
initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null)
|
initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,8 +53,7 @@ private[remote] class UdpServerHandler(_transport: NettyTransport, _associationL
|
||||||
initInbound(channel, remoteSocketAddress, msg)
|
initInbound(channel, remoteSocketAddress, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[remote] class UdpClientHandler(_transport: NettyTransport, _statusPromise: Promise[AssociationHandle])
|
private[remote] class UdpClientHandler(_transport: NettyTransport) extends ClientHandler(_transport) with UdpHandlers {
|
||||||
extends ClientHandler(_transport, _statusPromise) with UdpHandlers {
|
|
||||||
|
|
||||||
override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
|
override def initUdp(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
|
||||||
initOutbound(channel, remoteSocketAddress, msg)
|
initOutbound(channel, remoteSocketAddress, msg)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue