diff --git a/akka-io/src/main/resources/reference.conf b/akka-io/src/main/resources/reference.conf index 3595cd12f5..a907706b02 100644 --- a/akka-io/src/main/resources/reference.conf +++ b/akka-io/src/main/resources/reference.conf @@ -23,10 +23,12 @@ akka { # these will use one select loop on the selector-dispatcher. nr-of-selectors = 1 - # Maximum number of open channels supported by this TCP module; there is + # Maximum number of open channels supported by this TCP module; there is # no intrinsic general limit, this setting is meant to enable DoS - # protection by limiting the number of concurrently connected clients. - # Set to 0 to disable. + # protection by limiting the number of concurrently connected clients. + # Also note that this is a "soft" limit; in certain cases the implementation + # will accept a few connections more than the number configured here. + # Set to 0 for "unlimited". max-channels = 256000 # The select loop can be used in two modes: diff --git a/akka-io/src/main/scala/akka/io/DirectByteBufferPool.scala b/akka-io/src/main/scala/akka/io/DirectByteBufferPool.scala index 4fd558a162..3a1f00b20e 100644 --- a/akka-io/src/main/scala/akka/io/DirectByteBufferPool.scala +++ b/akka-io/src/main/scala/akka/io/DirectByteBufferPool.scala @@ -32,7 +32,7 @@ class DirectByteBufferPool(bufferSize: Int, maxPoolSize: Int) { @volatile private[this] var pool: List[ByteBuffer] = Nil @volatile private[this] var poolSize: Int = 0 - private[this] def allocate(size: Int): ByteBuffer = + private def allocate(size: Int): ByteBuffer = ByteBuffer.allocateDirect(size) def acquire(size: Int = bufferSize): ByteBuffer = { @@ -44,8 +44,12 @@ class DirectByteBufferPool(bufferSize: Int, maxPoolSize: Int) { if (buf.capacity() <= bufferSize && poolSize < maxPoolSize) addBufferToPool(buf) + // TODO: check whether limiting the spin count in the following two methods is beneficial + // (e.g. never limit more than 1000 times), since both methods could fall back to not + // using the buffer at all (take fallback: create a new buffer, add fallback: just drop) + @tailrec - final def takeBufferFromPool(): ByteBuffer = + private def takeBufferFromPool(): ByteBuffer = if (state.compareAndSet(Unlocked, Locked)) try pool match { case Nil ⇒ allocate(bufferSize) // we have no more buffer available, so create a new one @@ -57,7 +61,7 @@ class DirectByteBufferPool(bufferSize: Int, maxPoolSize: Int) { else takeBufferFromPool() // spin while locked @tailrec - final def addBufferToPool(buf: ByteBuffer): Unit = + private def addBufferToPool(buf: ByteBuffer): Unit = if (state.compareAndSet(Unlocked, Locked)) { buf.clear() // ensure that we never have dirty buffers in the pool pool = buf :: pool diff --git a/akka-io/src/main/scala/akka/io/Tcp.scala b/akka-io/src/main/scala/akka/io/Tcp.scala index 4107ee3e82..fc8beae4f9 100644 --- a/akka-io/src/main/scala/akka/io/Tcp.scala +++ b/akka-io/src/main/scala/akka/io/Tcp.scala @@ -17,6 +17,7 @@ import java.net.ServerSocket import scala.concurrent.duration._ import scala.collection.immutable import akka.actor.ActorSystem +import com.typesafe.config.Config object Tcp extends ExtensionKey[TcpExt] { @@ -126,24 +127,16 @@ object Tcp extends ExtensionKey[TcpExt] { } } - case class Stats(channelsOpened: Long, channelsClosed: Long, selectorStats: Seq[SelectorStats]) { - def channelsOpen = channelsOpened - channelsClosed - } - - case class SelectorStats(channelsOpened: Long, channelsClosed: Long) { - def channelsOpen = channelsOpened - channelsClosed - } - /// COMMANDS sealed trait Command case class Connect(remoteAddress: InetSocketAddress, localAddress: Option[InetSocketAddress] = None, - options: immutable.Seq[SocketOption] = Nil) extends Command + options: immutable.Traversable[SocketOption] = Nil) extends Command case class Bind(handler: ActorRef, endpoint: InetSocketAddress, backlog: Int = 100, - options: immutable.Seq[SocketOption] = Nil) extends Command + options: immutable.Traversable[SocketOption] = Nil) extends Command case class Register(handler: ActorRef) extends Command case object Unbind extends Command @@ -172,8 +165,6 @@ object Tcp extends ExtensionKey[TcpExt] { case object StopReading extends Command case object ResumeReading extends Command - case object GetStats extends Command - /// EVENTS sealed trait Event @@ -193,10 +184,9 @@ object Tcp extends ExtensionKey[TcpExt] { /// INTERNAL case class RegisterOutgoingConnection(channel: SocketChannel) case class RegisterServerSocketChannel(channel: ServerSocketChannel) - case class RegisterIncomingConnection(channel: SocketChannel, handler: ActorRef, options: immutable.Seq[SocketOption]) - case class CreateConnection(channel: SocketChannel, handler: ActorRef, options: immutable.Seq[SocketOption]) - case class Reject(command: Command, retriesLeft: Int, commander: ActorRef) - case class Retry(command: Command, retriesLeft: Int, commander: ActorRef) + case class RegisterIncomingConnection(channel: SocketChannel, handler: ActorRef, + options: immutable.Traversable[SocketOption]) extends Command + case class Retry(command: Command, retriesLeft: Int) { require(retriesLeft >= 0) } case object ChannelConnectable case object ChannelAcceptable case object ChannelReadable @@ -208,22 +198,26 @@ object Tcp extends ExtensionKey[TcpExt] { class TcpExt(system: ExtendedActorSystem) extends IO.Extension { - object Settings { - val config = system.settings.config.getConfig("akka.io.tcp") + val Settings = new Settings(system.settings.config.getConfig("akka.io.tcp")) + class Settings private[TcpExt] (config: Config) { import config._ val NrOfSelectors = getInt("nr-of-selectors") val MaxChannels = getInt("max-channels") - val SelectTimeout = - if (getString("select-timeout") == "infinite") Duration.Inf - else Duration(getMilliseconds("select-timeout"), MILLISECONDS) + val SelectTimeout = getString("select-timeout") match { + case "infinite" ⇒ Duration.Inf + case x ⇒ Duration(x) + } + if (getString("select-timeout") == "infinite") Duration.Inf + else Duration(getMilliseconds("select-timeout"), MILLISECONDS) val SelectorAssociationRetries = getInt("selector-association-retries") val BatchAcceptLimit = getInt("batch-accept-limit") val DirectBufferSize = getInt("direct-buffer-size") val MaxDirectBufferPoolSize = getInt("max-direct-buffer-pool-size") - val RegisterTimeout = - if (getString("register-timeout") == "infinite") Duration.Undefined - else Duration(getMilliseconds("register-timeout"), MILLISECONDS) + val RegisterTimeout = getString("register-timeout") match { + case "infinite" ⇒ Duration.Undefined + case x ⇒ Duration(x) + } val SelectorDispatcher = getString("selector-dispatcher") val WorkerDispatcher = getString("worker-dispatcher") val ManagementDispatcher = getString("management-dispatcher") @@ -238,8 +232,11 @@ class TcpExt(system: ExtendedActorSystem) extends IO.Extension { val MaxChannelsPerSelector = MaxChannels / NrOfSelectors } - val manager = system.asInstanceOf[ActorSystemImpl].systemActorOf( - Props.empty.withDispatcher(Settings.ManagementDispatcher), "IO-TCP") + val manager = { + system.asInstanceOf[ActorSystemImpl].systemActorOf( + props = Props(new TcpManager(this)).withDispatcher(Settings.ManagementDispatcher), + name = "IO-TCP") + } val bufferPool = new DirectByteBufferPool(Settings.DirectBufferSize, Settings.MaxDirectBufferPoolSize) } diff --git a/akka-io/src/main/scala/akka/io/TcpConnection.scala b/akka-io/src/main/scala/akka/io/TcpConnection.scala index 03069f0d01..265b005fa8 100644 --- a/akka-io/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-io/src/main/scala/akka/io/TcpConnection.scala @@ -20,11 +20,9 @@ import java.nio.ByteBuffer * Base class for TcpIncomingConnection and TcpOutgoingConnection. */ abstract class TcpConnection(val selector: ActorRef, - val channel: SocketChannel) extends Actor with ActorLogging with WithBufferPool { - val tcp = Tcp(context.system) - - channel.configureBlocking(false) - + val channel: SocketChannel, + val tcp: TcpExt) extends Actor with ActorLogging with WithBufferPool { + import tcp.Settings._ var pendingWrite: PendingWrite = null // Needed to send the ConnectionClosed message in the postStop handler. @@ -33,15 +31,12 @@ abstract class TcpConnection(val selector: ActorRef, def writePending = pendingWrite ne null - def registerTimeout = tcp.Settings.RegisterTimeout - def traceLoggingEnabled = tcp.Settings.TraceLogging - // STATES /** connection established, waiting for registration from user handler */ def waitingForRegistration(commander: ActorRef): Receive = { case Register(handler) ⇒ - if (traceLoggingEnabled) log.debug("{} registered as connection handler", handler) + if (TraceLogging) log.debug("{} registered as connection handler", handler) selector ! ReadInterest context.setReceiveTimeout(Duration.Undefined) @@ -55,7 +50,7 @@ abstract class TcpConnection(val selector: ActorRef, case ReceiveTimeout ⇒ // after sending `Register` user should watch this actor to make sure // it didn't die because of the timeout - log.warning("Configured registration timeout of {} expired, stopping", registerTimeout) + log.warning("Configured registration timeout of {} expired, stopping", RegisterTimeout) context.stop(self) } @@ -66,7 +61,7 @@ abstract class TcpConnection(val selector: ActorRef, case ChannelReadable ⇒ doRead(handler) case write: Write if writePending ⇒ - if (traceLoggingEnabled) log.debug("Dropping write because queue is full") + if (TraceLogging) log.debug("Dropping write because queue is full") sender ! CommandFailed(write) case write: Write if write.data.isEmpty ⇒ @@ -107,14 +102,14 @@ abstract class TcpConnection(val selector: ActorRef, // AUXILIARIES and IMPLEMENTATION /** use in subclasses to start the common machinery above once a channel is connected */ - def completeConnect(commander: ActorRef, options: immutable.Seq[SocketOption]): Unit = { + def completeConnect(commander: ActorRef, options: immutable.Traversable[SocketOption]): Unit = { options.foreach(_.afterConnect(channel.socket)) commander ! Connected( channel.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress], channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]) - context.setReceiveTimeout(registerTimeout) + context.setReceiveTimeout(RegisterTimeout) context.become(waitingForRegistration(commander)) } @@ -126,7 +121,7 @@ abstract class TcpConnection(val selector: ActorRef, buffer.flip() if (readBytes > 0) { - if (traceLoggingEnabled) log.debug("Read {} bytes", readBytes) + if (TraceLogging) log.debug("Read {} bytes", readBytes) handler ! Received(ByteString(buffer)) releaseBuffer(buffer) @@ -135,10 +130,10 @@ abstract class TcpConnection(val selector: ActorRef, self ! ChannelReadable else selector ! ReadInterest } else if (readBytes == 0) { - if (traceLoggingEnabled) log.debug("Read nothing. Registering read interest with selector") + if (TraceLogging) log.debug("Read nothing. Registering read interest with selector") selector ! ReadInterest } else if (readBytes == -1) { - if (traceLoggingEnabled) log.debug("Read returned end-of-stream") + if (TraceLogging) log.debug("Read returned end-of-stream") doCloseConnection(handler, closeReason) } else throw new IllegalStateException("Unexpected value returned from read: " + readBytes) @@ -150,7 +145,7 @@ abstract class TcpConnection(val selector: ActorRef, def doWrite(handler: ActorRef): Unit = { try { val writtenBytes = channel.write(pendingWrite.buffer) - if (traceLoggingEnabled) log.debug("Wrote {} bytes to channel", writtenBytes) + if (TraceLogging) log.debug("Wrote {} bytes to channel", writtenBytes) if (pendingWrite.hasData) selector ! WriteInterest // still data to write else if (pendingWrite.wantsAck) { // everything written @@ -169,20 +164,20 @@ abstract class TcpConnection(val selector: ActorRef, def handleClose(handler: ActorRef, closedEvent: ConnectionClosed): Unit = if (closedEvent == Aborted) { // close instantly - if (traceLoggingEnabled) log.debug("Got Abort command. RESETing connection.") + if (TraceLogging) log.debug("Got Abort command. RESETing connection.") doCloseConnection(handler, closedEvent) } else if (writePending) { // finish writing first - if (traceLoggingEnabled) log.debug("Got Close command but write is still pending.") + if (TraceLogging) log.debug("Got Close command but write is still pending.") context.become(closingWithPendingWrite(handler, closedEvent)) } else if (closedEvent == ConfirmedClosed) { // shutdown output and wait for confirmation - if (traceLoggingEnabled) log.debug("Got ConfirmedClose command, sending FIN.") + if (TraceLogging) log.debug("Got ConfirmedClose command, sending FIN.") channel.socket.shutdownOutput() context.become(closing(handler)) } else { // close now - if (traceLoggingEnabled) log.debug("Got Close command, closing connection.") + if (TraceLogging) log.debug("Got Close command, closing connection.") doCloseConnection(handler, closedEvent) } @@ -222,7 +217,7 @@ abstract class TcpConnection(val selector: ActorRef, case NonFatal(e) ⇒ // setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574 // (also affected: OS/X Java 1.6.0_37) - if (traceLoggingEnabled) log.debug("setSoLinger(true, 0) failed with {}", e) + if (TraceLogging) log.debug("setSoLinger(true, 0) failed with {}", e) } channel.close() } diff --git a/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala b/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala index 009621f032..902dde0233 100644 --- a/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala +++ b/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala @@ -15,8 +15,10 @@ import Tcp.SocketOption */ class TcpIncomingConnection(_selector: ActorRef, _channel: SocketChannel, + _tcp: TcpExt, handler: ActorRef, - options: immutable.Seq[SocketOption]) extends TcpConnection(_selector, _channel) { + options: immutable.Traversable[SocketOption]) + extends TcpConnection(_selector, _channel, _tcp) { context.watch(handler) // sign death pact diff --git a/akka-io/src/main/scala/akka/io/TcpListener.scala b/akka-io/src/main/scala/akka/io/TcpListener.scala index 1f339be702..cf40d35a23 100644 --- a/akka-io/src/main/scala/akka/io/TcpListener.scala +++ b/akka-io/src/main/scala/akka/io/TcpListener.scala @@ -12,15 +12,15 @@ import scala.util.control.NonFatal import akka.actor.{ ActorLogging, ActorRef, Actor } import Tcp._ -class TcpListener(manager: ActorRef, - selector: ActorRef, +class TcpListener(selector: ActorRef, handler: ActorRef, endpoint: InetSocketAddress, backlog: Int, bindCommander: ActorRef, - options: immutable.Seq[SocketOption]) extends Actor with ActorLogging { + settings: TcpExt#Settings, + options: immutable.Traversable[SocketOption]) extends Actor with ActorLogging { - val batchAcceptLimit = Tcp(context.system).Settings.BatchAcceptLimit + context.watch(handler) // sign death pact val channel = { val serverSocketChannel = ServerSocketChannel.open serverSocketChannel.configureBlocking(false) @@ -30,7 +30,6 @@ class TcpListener(manager: ActorRef, serverSocketChannel } selector ! RegisterServerSocketChannel(channel) - context.watch(bindCommander) // sign death pact log.debug("Successfully bound to {}", endpoint) def receive: Receive = { @@ -41,7 +40,14 @@ class TcpListener(manager: ActorRef, def bound: Receive = { case ChannelAcceptable ⇒ - acceptAllPending(batchAcceptLimit) + acceptAllPending(settings.BatchAcceptLimit) + + case CommandFailed(RegisterIncomingConnection(socketChannel, _, _)) ⇒ + log.warning("Could not register incoming connection since capacity limit is reached, closing connection") + try socketChannel.close() + catch { + case NonFatal(e) ⇒ log.error(e, "Error closing channel") + } case Unbind ⇒ log.debug("Unbinding endpoint {}", endpoint) @@ -60,15 +66,19 @@ class TcpListener(manager: ActorRef, } if (socketChannel != null) { log.debug("New connection accepted") - manager ! RegisterIncomingConnection(socketChannel, handler, options) - selector ! AcceptInterest + socketChannel.configureBlocking(false) + context.parent ! RegisterIncomingConnection(socketChannel, handler, options) acceptAllPending(limit - 1) } - } + } else selector ! AcceptInterest override def postStop() { - try channel.close() - catch { + try { + if (channel.isOpen) { + log.debug("Closing serverSocketChannel after being stopped") + channel.close() + } + } catch { case NonFatal(e) ⇒ log.error(e, "Error closing ServerSocketChannel") } } diff --git a/akka-io/src/main/scala/akka/io/TcpManager.scala b/akka-io/src/main/scala/akka/io/TcpManager.scala index bba033dc9b..6f2046d75f 100644 --- a/akka-io/src/main/scala/akka/io/TcpManager.scala +++ b/akka-io/src/main/scala/akka/io/TcpManager.scala @@ -4,12 +4,8 @@ package akka.io -import scala.concurrent.Future -import scala.concurrent.duration._ import akka.actor.{ ActorLogging, Actor, Props } import akka.routing.RandomRouter -import akka.util.Timeout -import akka.pattern.{ ask, pipe } import Tcp._ /** @@ -47,36 +43,13 @@ import Tcp._ * with a [[akka.io.Tcp.CommandFailed]] message. This message contains the original command for reference. * */ -class TcpManager extends Actor with ActorLogging { - val settings = Tcp(context.system).Settings - val selectorNr = Iterator.from(0) +class TcpManager(tcp: TcpExt) extends Actor with ActorLogging { val selectorPool = context.actorOf( - props = Props(new TcpSelector(self)).withRouter(RandomRouter(settings.NrOfSelectors)), - name = selectorNr.next().toString) + props = Props(new TcpSelector(self, tcp)).withRouter(RandomRouter(tcp.Settings.NrOfSelectors)), + name = "selectors") def receive = { - case RegisterIncomingConnection(channel, handler, options) ⇒ - selectorPool ! CreateConnection(channel, handler, options) - - case c: Connect ⇒ - selectorPool forward c - - case b: Bind ⇒ - selectorPool forward b - - case Reject(command, 0, commander) ⇒ - log.warning("Command '{}' failed since all {} selectors are at capacity", command, context.children.size) - commander ! CommandFailed(command) - - case Reject(command, retriesLeft, commander) ⇒ - log.warning("Command '{}' rejected by {} with {} retries left, retrying...", command, sender, retriesLeft) - selectorPool ! Retry(command, retriesLeft - 1, commander) - - case GetStats ⇒ - import context.dispatcher - implicit val timeout: Timeout = 1 second span - val seqFuture = Future.traverse(context.children)(_.ask(GetStats).mapTo[SelectorStats]) - seqFuture.map(s ⇒ Stats(s.map(_.channelsOpen).sum, s.map(_.channelsClosed).sum, s.toSeq)) pipeTo sender + case x @ (_: Connect | _: Bind) ⇒ selectorPool forward x } } diff --git a/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala b/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala index cfdc7a2af0..b33e6605e8 100644 --- a/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala +++ b/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala @@ -16,11 +16,13 @@ import Tcp._ * to be established. */ class TcpOutgoingConnection(_selector: ActorRef, + _tcp: TcpExt, commander: ActorRef, remoteAddress: InetSocketAddress, localAddress: Option[InetSocketAddress], - options: immutable.Seq[SocketOption]) - extends TcpConnection(_selector, SocketChannel.open()) { + options: immutable.Traversable[SocketOption]) + extends TcpConnection(_selector, TcpOutgoingConnection.newSocketChannel(), _tcp) { + context.watch(commander) // sign death pact localAddress.foreach(channel.socket.bind) @@ -36,7 +38,7 @@ class TcpOutgoingConnection(_selector: ActorRef, def receive: Receive = PartialFunction.empty - def connecting(commander: ActorRef, options: immutable.Seq[SocketOption]): Receive = { + def connecting(commander: ActorRef, options: immutable.Traversable[SocketOption]): Receive = { case ChannelConnectable ⇒ try { val connected = channel.finishConnect() @@ -49,3 +51,11 @@ class TcpOutgoingConnection(_selector: ActorRef, } } + +object TcpOutgoingConnection { + private def newSocketChannel() = { + val channel = SocketChannel.open() + channel.configureBlocking(false) + channel + } +} diff --git a/akka-io/src/main/scala/akka/io/TcpSelector.scala b/akka-io/src/main/scala/akka/io/TcpSelector.scala index 189755c079..48f1c5ab36 100644 --- a/akka-io/src/main/scala/akka/io/TcpSelector.scala +++ b/akka-io/src/main/scala/akka/io/TcpSelector.scala @@ -14,69 +14,58 @@ import scala.concurrent.duration._ import akka.actor._ import Tcp._ -class TcpSelector(manager: ActorRef) extends Actor with ActorLogging { - @volatile var childrenKeys = HashMap.empty[String, SelectionKey] - var channelsOpened = 0L - var channelsClosed = 0L - val sequenceNumber = Iterator.from(0) - val settings = Tcp(context.system).Settings - val selectorManagementDispatcher = context.system.dispatchers.lookup(settings.SelectorDispatcher) - val selector = SelectorProvider.provider.openSelector - val doSelect: () ⇒ Int = - settings.SelectTimeout match { - case Duration.Zero ⇒ () ⇒ selector.selectNow() - case Duration.Inf ⇒ () ⇒ selector.select() - case x ⇒ val millis = x.toMillis; () ⇒ selector.select(millis) - } +class TcpSelector(manager: ActorRef, tcp: TcpExt) extends Actor with ActorLogging { + import tcp.Settings._ - selectorManagementDispatcher.execute(select) // start selection "loop" + @volatile var childrenKeys = HashMap.empty[String, SelectionKey] + val sequenceNumber = Iterator.from(0) + val selectorManagementDispatcher = context.system.dispatchers.lookup(SelectorDispatcher) + val selector = SelectorProvider.provider.openSelector + val OP_READ_AND_WRITE = OP_READ + OP_WRITE // compile-time constant def receive: Receive = { case WriteInterest ⇒ execute(enableInterest(OP_WRITE, sender)) case ReadInterest ⇒ execute(enableInterest(OP_READ, sender)) case AcceptInterest ⇒ execute(enableInterest(OP_ACCEPT, sender)) - case CreateConnection(channel, handler, options) ⇒ - val connection = context.actorOf( - props = Props( - creator = () ⇒ new TcpIncomingConnection(self, channel, handler, options), - dispatcher = settings.WorkerDispatcher), - name = nextName) - execute(registerIncomingConnection(channel, handler)) - context.watch(connection) - channelsOpened += 1 + case cmd: RegisterIncomingConnection ⇒ + handleIncomingConnection(cmd, SelectorAssociationRetries) case cmd: Connect ⇒ - handleConnect(cmd, settings.SelectorAssociationRetries, sender) + handleConnect(cmd, SelectorAssociationRetries) - case Retry(cmd: Connect, retriesLeft, commander) ⇒ - handleConnect(cmd, retriesLeft, commander) + case cmd: Bind ⇒ + handleBind(cmd, SelectorAssociationRetries) case RegisterOutgoingConnection(channel) ⇒ execute(registerOutgoingConnection(channel, sender)) - case cmd: Bind ⇒ - handleBind(cmd, settings.SelectorAssociationRetries, sender) - - case Retry(cmd: Bind, retriesLeft, commander) ⇒ - handleBind(cmd, retriesLeft, commander) - case RegisterServerSocketChannel(channel) ⇒ execute(registerListener(channel, sender)) + case Retry(command, 0) ⇒ + log.warning("Command '{}' failed since all selectors are at capacity", command) + sender ! CommandFailed(command) + + case Retry(cmd: RegisterIncomingConnection, retriesLeft) ⇒ + handleIncomingConnection(cmd, retriesLeft) + + case Retry(cmd: Connect, retriesLeft) ⇒ + handleConnect(cmd, retriesLeft) + + case Retry(cmd: Bind, retriesLeft) ⇒ + handleBind(cmd, retriesLeft) + case Terminated(child) ⇒ execute(unregister(child)) - channelsClosed += 1 - - case GetStats ⇒ - sender ! SelectorStats(channelsOpened, channelsClosed) } override def postStop() { try { - import scala.collection.JavaConverters._ - selector.keys.asScala.foreach(_.channel.close()) - selector.close() + try { + val iterator = selector.keys.iterator + while (iterator.hasNext) iterator.next().channel.close() + } finally selector.close() } catch { case NonFatal(e) ⇒ log.error(e, "Error closing selector or key") } @@ -85,35 +74,43 @@ class TcpSelector(manager: ActorRef) extends Actor with ActorLogging { // we can never recover from failures of a connection or listener child override def supervisorStrategy = SupervisorStrategy.stoppingStrategy - def handleConnect(cmd: Connect, retriesLeft: Int, commander: ActorRef): Unit = { + def handleIncomingConnection(cmd: RegisterIncomingConnection, retriesLeft: Int): Unit = + withCapacityProtection(cmd, retriesLeft) { + import cmd._ + val connection = spawnChild(() ⇒ new TcpIncomingConnection(self, channel, tcp, handler, options)) + execute(registerIncomingConnection(channel, connection)) + } + + def handleConnect(cmd: Connect, retriesLeft: Int): Unit = + withCapacityProtection(cmd, retriesLeft) { + import cmd._ + val commander = sender + spawnChild(() ⇒ new TcpOutgoingConnection(self, tcp, commander, remoteAddress, localAddress, options)) + } + + def handleBind(cmd: Bind, retriesLeft: Int): Unit = + withCapacityProtection(cmd, retriesLeft) { + import cmd._ + val commander = sender + spawnChild(() ⇒ new TcpListener(self, handler, endpoint, backlog, commander, tcp.Settings, options)) + } + + def withCapacityProtection(cmd: Command, retriesLeft: Int)(body: ⇒ Unit): Unit = { log.debug("Executing {}", cmd) - if (canHandleMoreChannels) { - val connection = context.actorOf( - props = Props( - creator = () ⇒ new TcpOutgoingConnection(self, commander, cmd.remoteAddress, cmd.localAddress, cmd.options), - dispatcher = settings.WorkerDispatcher), - name = nextName) - context.watch(connection) - channelsOpened += 1 - } else sender ! Reject(cmd, retriesLeft, commander) + if (MaxChannelsPerSelector == 0 || childrenKeys.size < MaxChannelsPerSelector) { + body + } else { + log.warning("Rejecting '{}' with {} retries left, retrying...", cmd, retriesLeft) + context.parent forward Retry(cmd, retriesLeft - 1) + } } - def handleBind(cmd: Bind, retriesLeft: Int, commander: ActorRef): Unit = { - log.debug("Executing {}", cmd) - if (canHandleMoreChannels) { - val listener = context.actorOf( - props = Props( - creator = () ⇒ new TcpListener(manager, self, cmd.handler, cmd.endpoint, cmd.backlog, commander, cmd.options), - dispatcher = settings.WorkerDispatcher), - name = nextName) - context.watch(listener) - channelsOpened += 1 - } else sender ! Reject(cmd, retriesLeft, commander) - } - - def nextName = sequenceNumber.next().toString - - def canHandleMoreChannels = childrenKeys.size < settings.MaxChannelsPerSelector + def spawnChild(creator: () ⇒ Actor) = + context.watch { + context.actorOf( + props = Props(creator, dispatcher = WorkerDispatcher), + name = sequenceNumber.next().toString) + } //////////////// Management Tasks scheduled via the selectorManagementDispatcher ///////////// @@ -172,18 +169,28 @@ class TcpSelector(manager: ActorRef) extends Actor with ActorLogging { } val select = new Task { + val doSelect: () ⇒ Int = + SelectTimeout match { + case Duration.Zero ⇒ () ⇒ selector.selectNow() + case Duration.Inf ⇒ () ⇒ selector.select() + case x ⇒ val millis = x.toMillis; () ⇒ selector.select(millis) + } def tryRun() { if (doSelect() > 0) { val keys = selector.selectedKeys val iterator = keys.iterator() while (iterator.hasNext) { val key = iterator.next - val connection = key.attachment.asInstanceOf[ActorRef] if (key.isValid) { - if (key.isReadable) connection ! ChannelReadable - if (key.isWritable) connection ! ChannelWritable - else if (key.isAcceptable) connection ! ChannelAcceptable - else if (key.isConnectable) connection ! ChannelConnectable + val connection = key.attachment.asInstanceOf[ActorRef] + key.readyOps match { + case OP_READ ⇒ connection ! ChannelReadable + case OP_WRITE ⇒ connection ! ChannelWritable + case OP_READ_AND_WRITE ⇒ connection ! ChannelWritable; connection ! ChannelReadable + case x if (x & OP_ACCEPT) > 0 ⇒ connection ! ChannelAcceptable + case x if (x & OP_CONNECT) > 0 ⇒ connection ! ChannelConnectable + case x ⇒ log.warning("Invalid readyOps: {}", x) + } key.interestOps(0) // prevent immediate reselection by always clearing } else log.warning("Invalid selection key: {}", key) } @@ -193,12 +200,15 @@ class TcpSelector(manager: ActorRef) extends Actor with ActorLogging { } } + selectorManagementDispatcher.execute(select) // start selection "loop" + abstract class Task extends Runnable { def tryRun() def run() { try tryRun() catch { - case NonFatal(e) ⇒ log.error(e, "Error during selector management task: {}", e) + case _: java.nio.channels.ClosedSelectorException ⇒ // ok, expected during shutdown + case NonFatal(e) ⇒ log.error(e, "Error during selector management task: {}", e) } } } diff --git a/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala index 0265781c90..9dc925a390 100644 --- a/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala +++ b/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -20,7 +20,7 @@ import akka.util.ByteString import Tcp._ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") { - val serverAddress = TemporaryServerAddress.get("127.0.0.1") + val serverAddress = TemporaryServerAddress("127.0.0.1") "An outgoing connection" must { // common behavior @@ -245,7 +245,7 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") assertActorTerminated(connectionActor) } - val UnboundAddress = TemporaryServerAddress.get("127.0.0.1") + val UnboundAddress = TemporaryServerAddress("127.0.0.1") "report failed connection attempt when target is unreachable" in withUnacceptedConnection(connectionActorCons = createConnectionActor(serverAddress = UnboundAddress)) { setup ⇒ import setup._ @@ -397,7 +397,7 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") commander: ActorRef): TestActorRef[TcpOutgoingConnection] = { TestActorRef( - new TcpOutgoingConnection(selector, commander, serverAddress, localAddress, options) { + new TcpOutgoingConnection(selector, Tcp(system), commander, serverAddress, localAddress, options) { override def postRestart(reason: Throwable) { // ensure we never restart context.stop(self) diff --git a/akka-io/src/test/scala/akka/io/TcpListenerSpec.scala b/akka-io/src/test/scala/akka/io/TcpListenerSpec.scala index 0619038c54..8275170b8f 100644 --- a/akka-io/src/test/scala/akka/io/TcpListenerSpec.scala +++ b/akka-io/src/test/scala/akka/io/TcpListenerSpec.scala @@ -4,56 +4,85 @@ package akka.io -import java.net.{ Socket, InetSocketAddress } -import java.nio.channels.ServerSocketChannel +import java.net.Socket import scala.concurrent.duration._ -import scala.util.Success +import akka.actor.{ Terminated, SupervisorStrategy, Actor, Props } import akka.testkit.{ TestProbe, TestActorRef, AkkaSpec } -import akka.util.Timeout -import akka.pattern.ask import Tcp._ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { + "A TcpListener" must { - val manager = TestProbe() - val selector = TestProbe() - val handler = TestProbe() - val handlerRef = handler.ref - val bindCommander = TestProbe() - val endpoint = TemporaryServerAddress.get("127.0.0.1") - val listener = TestActorRef(new TcpListener(manager.ref, selector.ref, handler.ref, endpoint, 100, - bindCommander.ref, Nil)) - var serverSocketChannel: Option[ServerSocketChannel] = None "register its ServerSocketChannel with its selector" in { - val RegisterServerSocketChannel(channel) = selector.receiveOne(Duration.Zero) - serverSocketChannel = Some(channel) + val setup = ListenerSetup(autoBind = false) + import setup._ + + selector.expectMsgType[RegisterServerSocketChannel] } "let the Bind commander know when binding is completed" in { + ListenerSetup() + } + + "accept acceptable connections register them with its parent" in { + val setup = ListenerSetup() + import setup._ + + new Socket(endpoint.getHostName, endpoint.getPort) + new Socket(endpoint.getHostName, endpoint.getPort) + new Socket(endpoint.getHostName, endpoint.getPort) + + listener ! ChannelAcceptable + + val handlerRef = handler.ref + parent.expectMsgPF() { case RegisterIncomingConnection(_, `handlerRef`, Nil) ⇒ } + parent.expectMsgPF() { case RegisterIncomingConnection(_, `handlerRef`, Nil) ⇒ } + parent.expectNoMsg(100.millis) + + listener ! ChannelAcceptable + + parent.expectMsgPF() { case RegisterIncomingConnection(_, `handlerRef`, Nil) ⇒ } + } + + "react to Unbind commands by replying with Unbound and stopping itself" in { + val setup = ListenerSetup() + import setup._ + + val unbindCommander = TestProbe() + unbindCommander.send(listener, Unbind) + + unbindCommander.expectMsg(Unbound) + parent.expectMsgType[Terminated].actor must be(listener) + } + } + + val counter = Iterator.from(0) + + case class ListenerSetup(autoBind: Boolean = true) { + val selector = TestProbe() + val handler = TestProbe() + val bindCommander = TestProbe() + val parent = TestProbe() + val endpoint = TemporaryServerAddress("127.0.0.1") + private val parentRef = TestActorRef(new ListenerParent) + + if (autoBind) { listener ! Bound bindCommander.expectMsg(Bound) } - "accept two acceptable connections at once and register them with the manager" in { - new Socket("localhost", endpoint.getPort) - new Socket("localhost", endpoint.getPort) - new Socket("localhost", endpoint.getPort) - listener ! ChannelAcceptable - val RegisterIncomingConnection(_, `handlerRef`, Nil) = manager.receiveOne(Duration.Zero) - val RegisterIncomingConnection(_, `handlerRef`, Nil) = manager.receiveOne(Duration.Zero) - } - - "accept one more connection and register it with the manager" in { - listener ! ChannelAcceptable - val RegisterIncomingConnection(_, `handlerRef`, Nil) = manager.receiveOne(Duration.Zero) - } - - "react to Unbind commands by closing the ServerSocketChannel, replying with Unbound and stopping itself" in { - implicit val timeout: Timeout = 1 second span - listener.ask(Unbind).value must equal(Some(Success(Unbound))) - serverSocketChannel.get.isOpen must equal(false) - listener.isTerminated must equal(true) + def listener = parentRef.underlyingActor.listener + private class ListenerParent extends Actor { + val listener = context.actorOf( + props = Props(new TcpListener(selector.ref, handler.ref, endpoint, 100, bindCommander.ref, + Tcp(system).Settings, Nil)), + name = "test-listener-" + counter.next()) + parent.watch(listener) + def receive: Receive = { + case msg ⇒ parent.ref forward msg + } + override def supervisorStrategy = SupervisorStrategy.stoppingStrategy } } diff --git a/akka-io/src/test/scala/akka/io/TemporaryServerAddress.scala b/akka-io/src/test/scala/akka/io/TemporaryServerAddress.scala index f006ed70e5..20b4e3c16b 100644 --- a/akka-io/src/test/scala/akka/io/TemporaryServerAddress.scala +++ b/akka-io/src/test/scala/akka/io/TemporaryServerAddress.scala @@ -1,10 +1,15 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ + package akka.io import java.nio.channels.ServerSocketChannel import java.net.InetSocketAddress object TemporaryServerAddress { - def get(address: String): InetSocketAddress = { + + def apply(address: String = "127.0.0.1"): InetSocketAddress = { val serverSocket = ServerSocketChannel.open() serverSocket.socket.bind(new InetSocketAddress(address, 0)) val port = serverSocket.socket.getLocalPort