Migrate the classic transport to Netty 4 without CVEs (#643)

* !remoting Update classic transport from Netty 3 to netty4

* use lambda

* =sbt Update Netty4 version to 4.1.97.final

* Reduce allocation in ChannelLocalActor.

* Remove the duplicated code in NettyHelpers.
This commit is contained in:
kerr 2023-09-16 02:24:28 +08:00 committed by GitHub
parent d0b9c43bb5
commit 46edc51a82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 224 additions and 253 deletions

View file

@ -48,4 +48,5 @@ SortImports.blocks = [
"com.sun."
"org.apache.pekko."
"org.reactivestreams."
"io.netty."
]

View file

@ -28,8 +28,7 @@ object Dependencies {
// needs to be inline with the aeron version, check
// https://github.com/real-logic/aeron/blob/1.x.y/build.gradle
val agronaVersion = "1.19.2"
val nettyVersion = "3.10.6.Final"
val netty4Version = "4.1.96.Final"
val nettyVersion = "4.1.97.Final"
val protobufJavaVersion = "3.19.6"
val logbackVersion = "1.2.11"
@ -60,9 +59,8 @@ object Dependencies {
// Compile
val config = "com.typesafe" % "config" % "1.4.2"
val netty = "io.netty" % "netty" % nettyVersion
val `netty-transport` = "io.netty" % "netty-transport" % netty4Version
val `netty-handler` = "io.netty" % "netty-handler" % netty4Version
val `netty-transport` = "io.netty" % "netty-transport" % nettyVersion
val `netty-handler` = "io.netty" % "netty-handler" % nettyVersion
val scalaReflect: ScalaVersionDependentModuleID =
ScalaVersionDependentModuleID.versioned("org.scala-lang" % "scala-reflect" % _)
@ -278,7 +276,7 @@ object Dependencies {
Compile.slf4jApi,
TestDependencies.scalatest.value)
val remoteDependencies = Seq(netty, aeronDriver, aeronClient)
val remoteDependencies = Seq(`netty-transport`, `netty-handler`, aeronDriver, aeronClient)
val remoteOptionalDependencies = remoteDependencies.map(_ % "optional")
val remote = l ++= Seq(

View file

@ -14,14 +14,14 @@
package org.apache.pekko.remote.classic
import com.typesafe.config.ConfigFactory
import org.jboss.netty.channel.ChannelException
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import org.apache.pekko
import pekko.actor.ActorSystem
import pekko.testkit.SocketUtil
import java.net.BindException
class RemotingFailedToBindSpec extends AnyWordSpec with Matchers {
"an ActorSystem" must {
@ -43,10 +43,10 @@ class RemotingFailedToBindSpec extends AnyWordSpec with Matchers {
""".stripMargin)
val as = ActorSystem("RemotingFailedToBindSpec", config)
try {
val ex = intercept[ChannelException] {
val ex = intercept[BindException] {
ActorSystem("BindTest2", config)
}
ex.getMessage should startWith("Failed to bind")
ex.getMessage should startWith("Address already in use")
} finally {
as.terminate()
}

View file

@ -0,0 +1,4 @@
#migrate the classic transport to Netty4
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.pekko.remote.transport.netty.NettyFutureBridge.apply")
ProblemFilters.exclude[MissingClassProblem]("org.apache.pekko.remote.transport.netty.ChannelLocalActor")
ProblemFilters.exclude[MissingClassProblem]("org.apache.pekko.remote.transport.netty.ChannelLocalActor$")

View file

@ -281,7 +281,7 @@ private[pekko] class RemoteActorRefProvider(
private def checkNettyOnClassPath(system: ActorSystemImpl): Unit = {
checkClassOrThrow(
system,
"org.jboss.netty.channel.Channel",
"io.netty.channel.Channel",
"Classic",
"Netty",
"https://pekko.apache.org/docs/pekko/current/remoting.html")

View file

@ -15,34 +15,36 @@ package org.apache.pekko.remote.transport.netty
import java.nio.channels.ClosedChannelException
import scala.annotation.nowarn
import scala.util.control.NonFatal
import org.jboss.netty.channel._
import org.apache.pekko
import pekko.PekkoException
import pekko.util.unused
import io.netty.buffer.ByteBuf
import io.netty.channel.{ ChannelHandlerContext, SimpleChannelInboundHandler }
/**
* INTERNAL API
*/
private[netty] trait NettyHelpers {
protected def onConnect(@unused ctx: ChannelHandlerContext, @unused e: ChannelStateEvent): Unit = ()
protected def onConnect(@unused ctx: ChannelHandlerContext): Unit = ()
protected def onDisconnect(@unused ctx: ChannelHandlerContext, @unused e: ChannelStateEvent): Unit = ()
protected def onDisconnect(@unused ctx: ChannelHandlerContext): Unit = ()
protected def onOpen(@unused ctx: ChannelHandlerContext, @unused e: ChannelStateEvent): Unit = ()
protected def onOpen(@unused ctx: ChannelHandlerContext): Unit = ()
protected def onMessage(@unused ctx: ChannelHandlerContext, @unused e: MessageEvent): Unit = ()
protected def onMessage(@unused ctx: ChannelHandlerContext, @unused msg: ByteBuf): Unit = ()
protected def onException(@unused ctx: ChannelHandlerContext, @unused e: ExceptionEvent): Unit = ()
protected def onException(@unused ctx: ChannelHandlerContext, @unused e: Throwable): Unit = ()
final protected def transformException(ctx: ChannelHandlerContext, ev: ExceptionEvent): Unit = {
val cause = if (ev.getCause ne null) ev.getCause else new PekkoException("Unknown cause")
final protected def transformException(ctx: ChannelHandlerContext, ex: Throwable): Unit = {
val cause = if (ex ne null) ex else new PekkoException("Unknown cause")
cause match {
case _: ClosedChannelException => // Ignore
case null | NonFatal(_) => onException(ctx, ev)
case null | NonFatal(_) => onException(ctx, ex)
case e: Throwable => throw e // Rethrow fatals
}
}
@ -51,54 +53,33 @@ private[netty] trait NettyHelpers {
/**
* INTERNAL API
*/
private[netty] trait NettyServerHelpers extends SimpleChannelUpstreamHandler with NettyHelpers {
final override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
super.messageReceived(ctx, e)
onMessage(ctx, e)
private[netty] abstract class NettyChannelHandlerAdapter extends SimpleChannelInboundHandler[ByteBuf]
with NettyHelpers {
final override def channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf): Unit = {
onMessage(ctx, msg)
}
final override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = transformException(ctx, e)
final override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelConnected(ctx, e)
onConnect(ctx, e)
@nowarn("msg=deprecated")
final override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
transformException(ctx, cause)
}
final override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelOpen(ctx, e)
onOpen(ctx, e)
final override def channelActive(ctx: ChannelHandlerContext): Unit = {
onOpen(ctx)
onConnect(ctx)
}
final override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelDisconnected(ctx, e)
onDisconnect(ctx, e)
final override def channelInactive(ctx: ChannelHandlerContext): Unit = {
onDisconnect(ctx)
}
}
/**
* INTERNAL API
*/
private[netty] trait NettyClientHelpers extends SimpleChannelHandler with NettyHelpers {
final override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
super.messageReceived(ctx, e)
onMessage(ctx, e)
}
private[netty] trait NettyServerHelpers extends NettyChannelHandlerAdapter
final override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = transformException(ctx, e)
final override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelConnected(ctx, e)
onConnect(ctx, e)
}
final override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelOpen(ctx, e)
onOpen(ctx, e)
}
final override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
super.channelDisconnected(ctx, e)
onDisconnect(ctx, e)
}
}
/**
* INTERNAL API
*/
private[netty] trait NettyClientHelpers extends NettyChannelHandlerAdapter

View file

@ -14,13 +14,16 @@
package org.apache.pekko.remote.transport.netty
import scala.annotation.nowarn
import com.typesafe.config.Config
import org.jboss.netty.handler.ssl.SslHandler
import com.typesafe.config.Config
import org.apache.pekko
import pekko.japi.Util._
import pekko.util.ccompat._
import io.netty.channel.Channel
import io.netty.handler.ssl.SslHandler
import io.netty.util.concurrent.Future
/**
* INTERNAL API
*/
@ -64,6 +67,12 @@ private[pekko] object NettySSLSupport {
val sslEngine =
if (isClient) sslEngineProvider.createClientSSLEngine()
else sslEngineProvider.createServerSSLEngine()
new SslHandler(sslEngine)
val handler = new SslHandler(sslEngine)
handler.handshakeFuture().addListener((future: Future[Channel]) => {
if (!future.isSuccess) {
handler.closeOutbound().channel().close()
}
})
handler
}
}

View file

@ -14,7 +14,7 @@
package org.apache.pekko.remote.transport.netty
import java.net.{ InetAddress, InetSocketAddress, SocketAddress }
import java.util.concurrent.{ CancellationException, Executors }
import java.util.concurrent.CancellationException
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.nowarn
@ -24,23 +24,6 @@ import scala.util.Try
import scala.util.control.{ NoStackTrace, NonFatal }
import com.typesafe.config.Config
import org.jboss.netty.bootstrap.{ Bootstrap, ClientBootstrap, ServerBootstrap }
import org.jboss.netty.buffer.{ ChannelBuffer, ChannelBuffers }
import org.jboss.netty.channel._
import org.jboss.netty.channel.group.{
ChannelGroup,
ChannelGroupFuture,
ChannelGroupFutureListener,
DefaultChannelGroup
}
import org.jboss.netty.channel.socket.nio.{
NioClientSocketChannelFactory,
NioServerSocketChannelFactory,
NioWorkerPool
}
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.ssl.SslHandler
import org.jboss.netty.util.HashedWheelTimer
import org.apache.pekko
import pekko.ConfigurationException
import pekko.OnlyCauseStackTrace
@ -49,43 +32,57 @@ import pekko.dispatch.ThreadPoolConfig
import pekko.event.Logging
import pekko.remote.RARP
import pekko.remote.transport.{ AssociationHandle, Transport }
import pekko.remote.transport.AssociationHandle.HandleEventListener
import pekko.remote.transport.Transport._
import pekko.util.Helpers.Requiring
import pekko.util.{ Helpers, OptionVal }
import AssociationHandle.HandleEventListener
import Transport._
import Helpers.Requiring
import io.netty.bootstrap.{ Bootstrap => ClientBootstrap, ServerBootstrap }
import io.netty.buffer.Unpooled
import io.netty.channel.{
Channel,
ChannelFuture,
ChannelHandlerContext,
ChannelInitializer,
ChannelOption,
ChannelPipeline
}
import io.netty.channel.group.{ ChannelGroup, ChannelGroupFuture, ChannelMatchers, DefaultChannelGroup }
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.{ NioServerSocketChannel, NioSocketChannel }
import io.netty.handler.codec.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import io.netty.handler.ssl.SslHandler
import io.netty.util.concurrent.GlobalEventExecutor
@deprecated("Classic remoting is deprecated, use Artery", "Akka 2.6.0")
object NettyFutureBridge {
def apply(nettyFuture: ChannelFuture): Future[Channel] = {
val p = Promise[Channel]()
nettyFuture.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture): Unit =
p.complete(
Try(
if (future.isSuccess) future.getChannel
else if (future.isCancelled) throw new CancellationException
else throw future.getCause))
})
nettyFuture.addListener((future: ChannelFuture) =>
p.complete(
Try(
if (future.isSuccess) future.channel()
else if (future.isCancelled) throw new CancellationException
else throw future.cause())))
p.future
}
def apply(nettyFuture: ChannelGroupFuture): Future[ChannelGroup] = {
import pekko.util.ccompat.JavaConverters._
val p = Promise[ChannelGroup]()
nettyFuture.addListener(new ChannelGroupFutureListener {
def operationComplete(future: ChannelGroupFuture): Unit =
p.complete(
Try(
if (future.isCompleteSuccess) future.getGroup
else
throw future.iterator.asScala
.collectFirst {
case f if f.isCancelled => new CancellationException
case f if !f.isSuccess => f.getCause
}
.getOrElse(new IllegalStateException(
"Error reported in ChannelGroupFuture, but no error found in individual futures."))))
})
nettyFuture.addListener((future: ChannelGroupFuture) =>
p.complete(
Try(
if (future.isSuccess) future.group()
else
throw future.iterator.asScala
.collectFirst {
case f if f.isCancelled => new CancellationException
case f if !f.isSuccess => f.cause()
}
.getOrElse(new IllegalStateException(
"Error reported in ChannelGroupFuture, but no error found in individual futures.")))))
p.future
}
}
@ -182,27 +179,6 @@ class NettyTransportSettings(config: Config) {
config.getInt("pool-size-min"),
config.getDouble("pool-size-factor"),
config.getInt("pool-size-max"))
// Check Netty version >= 3.10.6
{
val nettyVersion = org.jboss.netty.util.Version.ID
def throwInvalidNettyVersion(): Nothing = {
throw new IllegalArgumentException(
"pekko-remote with the Netty transport requires Netty version 3.10.6 or " +
s"later. Version [$nettyVersion] is on the class path. Issue https://github.com/netty/netty/pull/4739 " +
"may cause messages to not be delivered.")
}
try {
val segments: Array[String] = nettyVersion.split("[.-]")
if (segments.length < 3 || segments(0).toInt != 3 || segments(1).toInt != 10 || segments(2).toInt < 6)
throwInvalidNettyVersion()
} catch {
case _: NumberFormatException =>
throwInvalidNettyVersion()
}
}
}
/**
@ -212,25 +188,24 @@ class NettyTransportSettings(config: Config) {
private[netty] trait CommonHandlers extends NettyHelpers {
protected val transport: NettyTransport
final override def onOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
transport.channelGroup.add(e.getChannel)
override protected def onOpen(ctx: ChannelHandlerContext): Unit = {
transport.channelGroup.add(ctx.channel())
}
protected def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle
protected def registerListener(
channel: Channel,
listener: HandleEventListener,
msg: ChannelBuffer,
remoteSocketAddress: InetSocketAddress): Unit
final protected def init(
channel: Channel,
remoteSocketAddress: SocketAddress,
remoteAddress: Address,
msg: ChannelBuffer)(op: AssociationHandle => Any): Unit = {
remoteAddress: Address)(op: AssociationHandle => Any): Unit = {
import transport._
NettyTransport.addressFromSocketAddress(
channel.getLocalAddress,
channel.localAddress(),
schemeIdentifier,
system.name,
Some(settings.Hostname),
@ -238,8 +213,8 @@ private[netty] trait CommonHandlers extends NettyHelpers {
case Some(localAddress) =>
val handle = createHandle(channel, localAddress, remoteAddress)
handle.readHandlerPromise.future.foreach { listener =>
registerListener(channel, listener, msg, remoteSocketAddress.asInstanceOf[InetSocketAddress])
channel.setReadable(true)
registerListener(channel, listener, remoteSocketAddress.asInstanceOf[InetSocketAddress])
channel.config().setAutoRead(true)
}
op(handle)
@ -260,8 +235,8 @@ private[netty] abstract class ServerHandler(
import transport.executionContext
final protected def initInbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
channel.setReadable(false)
final protected def initInbound(channel: Channel, remoteSocketAddress: SocketAddress): Unit = {
channel.config.setAutoRead(false)
associationListenerFuture.foreach { listener =>
val remoteAddress = NettyTransport
.addressFromSocketAddress(
@ -272,7 +247,7 @@ private[netty] abstract class ServerHandler(
port = None)
.getOrElse(throw new NettyTransportException(
s"Unknown inbound remote address type [${remoteSocketAddress.getClass.getName}]"))
init(channel, remoteSocketAddress, remoteAddress, msg) { a =>
init(channel, remoteSocketAddress, remoteAddress) { a =>
listener.notify(InboundAssociation(a))
}
}
@ -288,10 +263,10 @@ private[netty] abstract class ClientHandler(protected final val transport: Netty
extends NettyClientHelpers
with CommonHandlers {
final protected val statusPromise = Promise[AssociationHandle]()
def statusFuture = statusPromise.future
def statusFuture: Future[AssociationHandle] = statusPromise.future
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
init(channel, remoteSocketAddress, remoteAddress, msg)(statusPromise.success)
final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress): Unit = {
init(channel, remoteSocketAddress, remoteAddress)(statusPromise.success)
}
}
@ -304,9 +279,9 @@ private[transport] object NettyTransport {
val FrameLengthFieldLength = 4
def gracefulClose(channel: Channel)(implicit ec: ExecutionContext): Unit = {
@nowarn("msg=deprecated")
def always(c: ChannelFuture) = NettyFutureBridge(c).recover { case _ => c.getChannel }
def always(c: ChannelFuture): Future[Channel] = NettyFutureBridge(c).recover { case _ => c.channel() }
for {
_ <- always { channel.write(ChannelBuffers.buffer(0)) } // Force flush by waiting on a final dummy write
_ <- always { channel.writeAndFlush(Unpooled.EMPTY_BUFFER) } // Force flush by waiting on a final dummy write
_ <- always { channel.disconnect() }
} channel.close()
}
@ -358,8 +333,10 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
private val log = Logging.withMarker(system, classOf[NettyTransport])
private def createExecutorService() =
UseDispatcherForIo.map(system.dispatchers.lookup).getOrElse(Executors.newCachedThreadPool(system.threadFactory))
private def createEventLoopGroup(nThreadCount: Int): NioEventLoopGroup =
UseDispatcherForIo.map(system.dispatchers.lookup)
.map(executor => new NioEventLoopGroup(0, executor))
.getOrElse(new NioEventLoopGroup(nThreadCount, system.threadFactory))
/*
* Be aware, that the close() method of DefaultChannelGroup is racy, because it uses an iterator over a ConcurrentHashMap.
@ -369,25 +346,17 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
*/
val channelGroup = new DefaultChannelGroup(
"pekko-netty-transport-driver-channelgroup-" +
uniqueIdCounter.getAndIncrement)
uniqueIdCounter.getAndIncrement,
GlobalEventExecutor.INSTANCE)
private val clientChannelFactory: ChannelFactory = {
val boss, worker = createExecutorService()
new NioClientSocketChannelFactory(
boss,
1,
new NioWorkerPool(worker, ClientSocketWorkerPoolSize),
new HashedWheelTimer(system.threadFactory))
}
private val clientEventLoopGroup = createEventLoopGroup(ClientSocketWorkerPoolSize + 1)
private val serverChannelFactory: ChannelFactory = {
val boss, worker = createExecutorService()
// This does not create a HashedWheelTimer internally
new NioServerSocketChannelFactory(boss, worker, ServerSocketWorkerPoolSize)
}
private val serverEventLoopParentGroup = createEventLoopGroup(0)
private def newPipeline: DefaultChannelPipeline = {
val pipeline = new DefaultChannelPipeline
private val serverEventLoopChildGroup = createEventLoopGroup(ServerSocketWorkerPoolSize)
private def newPipeline(channel: Channel): ChannelPipeline = {
val pipeline = channel.pipeline()
pipeline.addLast(
"FrameDecoder",
new LengthFieldBasedFrameDecoder(
@ -420,62 +389,74 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
private def sslHandler(isClient: Boolean): SslHandler = {
sslEngineProvider match {
case OptionVal.Some(sslProvider) =>
val handler = NettySSLSupport(sslProvider, isClient)
handler.setCloseOnSSLException(true)
handler
NettySSLSupport(sslProvider, isClient)
case _ =>
throw new IllegalStateException("Expected enable-ssl=on")
}
}
private val serverPipelineFactory: ChannelPipelineFactory = new ChannelPipelineFactory {
override def getPipeline: ChannelPipeline = {
val pipeline = newPipeline
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = false))
val handler = new TcpServerHandler(NettyTransport.this, associationListenerPromise.future, log)
pipeline.addLast("ServerHandler", handler)
pipeline
}
private val serverPipelineInitializer: ChannelInitializer[SocketChannel] = (ch: SocketChannel) => {
val pipeline = newPipeline(ch)
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = false))
val handler = new TcpServerHandler(NettyTransport.this, associationListenerPromise.future, log)
pipeline.addLast("ServerHandler", handler)
}
private def clientPipelineFactory(remoteAddress: Address): ChannelPipelineFactory =
new ChannelPipelineFactory {
override def getPipeline: ChannelPipeline = {
val pipeline = newPipeline
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
val handler = new TcpClientHandler(NettyTransport.this, remoteAddress, log)
pipeline.addLast("clienthandler", handler)
pipeline
}
private def clientPipelineInitializer(remoteAddress: Address): ChannelInitializer[SocketChannel] =
(ch: SocketChannel) => {
val pipeline = newPipeline(ch)
if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
val handler = new TcpClientHandler(NettyTransport.this, remoteAddress, log)
pipeline.addLast("clienthandler", handler)
}
private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
bootstrap.setPipelineFactory(pipelineFactory)
bootstrap.setOption("backlog", settings.Backlog)
bootstrap.setOption("child.tcpNoDelay", settings.TcpNodelay)
bootstrap.setOption("child.keepAlive", settings.TcpKeepalive)
bootstrap.setOption("reuseAddress", settings.TcpReuseAddr)
settings.ReceiveBufferSize.foreach(sz => bootstrap.setOption("receiveBufferSize", sz))
settings.SendBufferSize.foreach(sz => bootstrap.setOption("sendBufferSize", sz))
settings.WriteBufferHighWaterMark.foreach(sz => bootstrap.setOption("writeBufferHighWaterMark", sz))
settings.WriteBufferLowWaterMark.foreach(sz => bootstrap.setOption("writeBufferLowWaterMark", sz))
private val inboundBootstrap: ServerBootstrap = {
val bootstrap = new ServerBootstrap()
bootstrap.group(serverEventLoopParentGroup, serverEventLoopChildGroup)
bootstrap.channel(classOf[NioServerSocketChannel])
bootstrap.childHandler(serverPipelineInitializer)
// DO NOT AUTO READ
bootstrap.option[java.lang.Boolean](ChannelOption.AUTO_READ, false)
bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, settings.Backlog)
bootstrap.option[java.lang.Boolean](ChannelOption.SO_REUSEADDR, settings.TcpReuseAddr)
// DO NOT AUTO READ
bootstrap.childOption[java.lang.Boolean](ChannelOption.AUTO_READ, false)
bootstrap.childOption[java.lang.Boolean](ChannelOption.TCP_NODELAY, settings.TcpNodelay)
bootstrap.childOption[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, settings.TcpKeepalive)
settings.ReceiveBufferSize.foreach(sz => bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, sz))
settings.SendBufferSize.foreach(sz => bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sz))
settings.WriteBufferHighWaterMark.filter(_ > 0).foreach(sz =>
bootstrap.childOption[java.lang.Integer](ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, sz))
settings.WriteBufferLowWaterMark.filter(_ > 0).foreach(sz =>
bootstrap.childOption[java.lang.Integer](ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, sz))
bootstrap
}
private val inboundBootstrap: ServerBootstrap = {
setupBootstrap(new ServerBootstrap(serverChannelFactory), serverPipelineFactory)
}
private def outboundBootstrap(remoteAddress: Address): ClientBootstrap = {
val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(remoteAddress))
bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
bootstrap.setOption("tcpNoDelay", settings.TcpNodelay)
bootstrap.setOption("keepAlive", settings.TcpKeepalive)
settings.ReceiveBufferSize.foreach(sz => bootstrap.setOption("receiveBufferSize", sz))
settings.SendBufferSize.foreach(sz => bootstrap.setOption("sendBufferSize", sz))
settings.WriteBufferHighWaterMark.foreach(sz => bootstrap.setOption("writeBufferHighWaterMark", sz))
settings.WriteBufferLowWaterMark.foreach(sz => bootstrap.setOption("writeBufferLowWaterMark", sz))
val bootstrap = new ClientBootstrap()
bootstrap.group(clientEventLoopGroup)
bootstrap.handler(clientPipelineInitializer(remoteAddress))
bootstrap.channel(classOf[NioSocketChannel])
// DO NOT AUTO READ
bootstrap.option[java.lang.Boolean](ChannelOption.AUTO_READ, false)
bootstrap.option[java.lang.Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.ConnectionTimeout.toMillis.toInt)
bootstrap.option[java.lang.Boolean](ChannelOption.TCP_NODELAY, settings.TcpNodelay)
bootstrap.option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, settings.TcpKeepalive)
settings.ReceiveBufferSize.foreach(sz => bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, sz))
settings.SendBufferSize.foreach(sz => bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sz))
settings.WriteBufferHighWaterMark.filter(_ > 0).foreach(sz =>
bootstrap.option[java.lang.Integer](ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, sz))
settings.WriteBufferLowWaterMark.filter(_ > 0).foreach(sz =>
bootstrap.option[java.lang.Integer](ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, sz))
//
bootstrap
}
@ -496,10 +477,10 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
address <- addressToSocketAddress(Address("", "", settings.BindHostname, bindPort))
} yield {
try {
val newServerChannel = inboundBootstrap.bind(address)
val newServerChannel = inboundBootstrap.bind(address).sync().channel()
// Block reads until a handler actor is registered
newServerChannel.setReadable(false)
newServerChannel.config().setAutoRead(false)
channelGroup.add(newServerChannel)
serverChannel = newServerChannel
@ -508,26 +489,26 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
val port = if (settings.PortSelector == 0) None else Some(settings.PortSelector)
addressFromSocketAddress(
newServerChannel.getLocalAddress,
newServerChannel.localAddress(),
schemeIdentifier,
system.name,
Some(settings.Hostname),
port) match {
case Some(address) =>
addressFromSocketAddress(newServerChannel.getLocalAddress, schemeIdentifier, system.name, None,
addressFromSocketAddress(newServerChannel.localAddress(), schemeIdentifier, system.name, None,
None) match {
case Some(address) => boundTo = address
case None =>
throw new NettyTransportException(
s"Unknown local address type [${newServerChannel.getLocalAddress.getClass.getName}]")
s"Unknown local address type [${newServerChannel.localAddress().getClass.getName}]")
}
associationListenerPromise.future.foreach { _ =>
newServerChannel.setReadable(true)
newServerChannel.config().setAutoRead(true)
}
(address, associationListenerPromise)
case None =>
throw new NettyTransportException(
s"Unknown local address type [${newServerChannel.getLocalAddress.getClass.getName}]")
s"Unknown local address type [${newServerChannel.localAddress().getClass.getName}]")
}
} catch {
case NonFatal(e) => {
@ -545,7 +526,7 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
private[pekko] def boundAddress = boundTo
override def associate(remoteAddress: Address): Future[AssociationHandle] = {
if (!serverChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
if (!serverChannel.isActive) Future.failed(new NettyTransportException("Transport is not bound"))
else {
val bootstrap: ClientBootstrap = outboundBootstrap(remoteAddress)
@ -554,12 +535,12 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
readyChannel <- NettyFutureBridge(bootstrap.connect(socketAddress)).map { channel =>
if (EnableSsl)
blocking {
channel.getPipeline.get(classOf[SslHandler]).handshake().awaitUninterruptibly()
channel.pipeline().get(classOf[SslHandler]).handshakeFuture().awaitUninterruptibly()
}
channel.setReadable(false)
channel.config.setAutoRead(false)
channel
}
handle <- readyChannel.getPipeline.get(classOf[ClientHandler]).statusFuture
handle <- readyChannel.pipeline().get(classOf[ClientHandler]).statusFuture
} yield handle).recover {
case _: CancellationException => throw new NettyTransportExceptionNoStack("Connection was cancelled")
case NonFatal(t) =>
@ -576,22 +557,18 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
}
override def shutdown(): Future[Boolean] = {
def always(c: ChannelGroupFuture) = NettyFutureBridge(c).map(_ => true).recover { case _ => false }
def always(c: ChannelGroupFuture): Future[Boolean] = NettyFutureBridge(c).map(_ => true).recover { case _ => false }
for {
// Force flush by trying to write an empty buffer and wait for success
unbindStatus <- always(channelGroup.unbind())
lastWriteStatus <- always(channelGroup.write(ChannelBuffers.buffer(0)))
unbindStatus <- always(channelGroup.close(ChannelMatchers.isServerChannel))
lastWriteStatus <- always(channelGroup.writeAndFlush(Unpooled.EMPTY_BUFFER))
disconnectStatus <- always(channelGroup.disconnect())
closeStatus <- always(channelGroup.close())
} yield {
// Release the selectors, but don't try to kill the dispatcher
if (UseDispatcherForIo.isDefined) {
clientChannelFactory.shutdown()
serverChannelFactory.shutdown()
} else {
clientChannelFactory.releaseExternalResources()
serverChannelFactory.releaseExternalResources()
}
clientEventLoopGroup.shutdownGracefully()
serverEventLoopParentGroup.shutdownGracefully()
serverEventLoopChildGroup.shutdownGracefully()
lastWriteStatus && unbindStatus && disconnectStatus && closeStatus
}

View file

@ -15,11 +15,8 @@ package org.apache.pekko.remote.transport.netty
import java.net.InetSocketAddress
import scala.concurrent.{ Future, Promise }
import scala.annotation.nowarn
import org.jboss.netty.buffer.{ ChannelBuffer, ChannelBuffers }
import org.jboss.netty.channel._
import scala.concurrent.{ Future, Promise }
import org.apache.pekko
import pekko.actor.Address
@ -29,12 +26,12 @@ import pekko.remote.transport.AssociationHandle.{ Disassociated, HandleEvent, Ha
import pekko.remote.transport.Transport.AssociationEventListener
import pekko.util.ByteString
/**
* INTERNAL API
*/
private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEventListener]] {
override def initialValue(channel: Channel): Option[HandleEventListener] = None
def notifyListener(channel: Channel, msg: HandleEvent): Unit = get(channel).foreach { _.notify(msg) }
import io.netty.buffer.{ ByteBuf, ByteBufUtil, Unpooled }
import io.netty.channel.{ Channel, ChannelHandlerContext }
import io.netty.util.AttributeKey
private[remote] object TcpHandlers {
private val LISTENER = AttributeKey.valueOf[HandleEventListener]("listener")
}
/**
@ -43,33 +40,37 @@ private[remote] object ChannelLocalActor extends ChannelLocal[Option[HandleEvent
@nowarn("msg=deprecated")
private[remote] trait TcpHandlers extends CommonHandlers {
protected def log: LoggingAdapter
import ChannelLocalActor._
import TcpHandlers._
override def registerListener(
channel: Channel,
listener: HandleEventListener,
msg: ChannelBuffer,
remoteSocketAddress: InetSocketAddress): Unit =
ChannelLocalActor.set(channel, Some(listener))
remoteSocketAddress: InetSocketAddress): Unit = channel.attr(LISTENER).set(listener)
override def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle =
new TcpAssociationHandle(localAddress, remoteAddress, transport, channel)
override def onDisconnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
notifyListener(e.getChannel, Disassociated(AssociationHandle.Unknown))
log.debug("Remote connection to [{}] was disconnected because of {}", e.getChannel.getRemoteAddress, e)
override def onDisconnect(ctx: ChannelHandlerContext): Unit = {
notifyListener(ctx.channel(), Disassociated(AssociationHandle.Unknown))
log.debug("Remote connection to [{}] was disconnected.", ctx.channel().remoteAddress())
}
override def onMessage(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
val bytes: Array[Byte] = e.getMessage.asInstanceOf[ChannelBuffer].array()
if (bytes.length > 0) notifyListener(e.getChannel, InboundPayload(ByteString(bytes)))
override def onMessage(ctx: ChannelHandlerContext, msg: ByteBuf): Unit = {
val bytes: Array[Byte] = ByteBufUtil.getBytes(msg)
if (bytes.length > 0) notifyListener(ctx.channel(), InboundPayload(ByteString(bytes)))
}
override def onException(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
notifyListener(e.getChannel, Disassociated(AssociationHandle.Unknown))
log.warning("Remote connection to [{}] failed with {}", e.getChannel.getRemoteAddress, e.getCause)
e.getChannel.close() // No graceful close here
override def onException(ctx: ChannelHandlerContext, e: Throwable): Unit = {
notifyListener(ctx.channel(), Disassociated(AssociationHandle.Unknown))
log.warning("Remote connection to [{}] failed with {}", ctx.channel().remoteAddress(), e.getCause)
ctx.channel().close() // No graceful close here
}
private def notifyListener(channel: Channel, event: HandleEvent): Unit = {
val listener = channel.attr(LISTENER).get()
if (listener ne null) {
listener.notify(event)
}
}
}
@ -84,8 +85,8 @@ private[remote] class TcpServerHandler(
extends ServerHandler(_transport, _associationListenerFuture)
with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initInbound(e.getChannel, e.getChannel.getRemoteAddress, null)
override def onConnect(ctx: ChannelHandlerContext): Unit =
initInbound(ctx.channel(), ctx.channel().remoteAddress())
}
@ -97,8 +98,8 @@ private[remote] class TcpClientHandler(_transport: NettyTransport, remoteAddress
extends ClientHandler(_transport, remoteAddress)
with TcpHandlers {
override def onConnect(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
initOutbound(e.getChannel, e.getChannel.getRemoteAddress, null)
override def onConnect(ctx: ChannelHandlerContext): Unit =
initOutbound(ctx.channel(), ctx.channel().remoteAddress())
}
@ -118,7 +119,7 @@ private[remote] class TcpAssociationHandle(
override def write(payload: ByteString): Boolean =
if (channel.isWritable && channel.isOpen) {
channel.write(ChannelBuffers.wrappedBuffer(payload.asByteBuffer))
channel.writeAndFlush(Unpooled.wrappedBuffer(payload.asByteBuffer))
true
} else false