diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala index 431ece8ddb..0c8b817560 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -7,7 +7,7 @@ package akka.io import java.io.{ File, IOException } import java.net.{ URLClassLoader, ConnectException, InetSocketAddress, SocketException } import java.nio.ByteBuffer -import java.nio.channels.{ SelectionKey, Selector, ServerSocketChannel, SocketChannel } +import java.nio.channels._ import java.nio.channels.spi.SelectorProvider import java.nio.channels.SelectionKey._ import scala.annotation.tailrec @@ -18,7 +18,7 @@ import org.scalatest.matchers._ import akka.io.Tcp._ import akka.io.SelectionHandler._ import akka.io.Inet.SocketOption -import akka.actor.{ PoisonPill, Terminated, DeathPactException } +import akka.actor._ import akka.testkit.{ AkkaSpec, EventFilter, TestActorRef, TestProbe } import akka.util.{ Helpers, ByteString } import akka.TestUtils._ @@ -134,11 +134,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") selector.send(connectionActor, ChannelConnectable) userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress])) - // we unrealistically register the selector here so that we can observe - // the ordering between Received and ReadInterest - userHandler.send(connectionActor, Register(selector.ref)) - selector.expectMsgType[Received].data.decodeString("ASCII") must be("immediatedata") - selector.expectMsg(ReadInterest) + userHandler.send(connectionActor, Register(userHandler.ref)) + userHandler.expectMsgType[Received].data.decodeString("ASCII") must be("immediatedata") + interestCallReceiver.expectMsg(OP_READ) } } @@ -246,7 +244,7 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") // it will have to keep the rest of the piece and send it // when possible writer.send(connectionActor, firstWrite) - selector.expectMsg(WriteInterest) + interestCallReceiver.expectMsg(OP_WRITE) // send another write which should fail immediately // because we don't store more than one piece in flight @@ -275,9 +273,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") connectionHandler.send(connectionActor, SuspendReading) // the selector interprets StopReading to deregister interest for reading - selector.expectMsg(DisableReadInterest) + interestCallReceiver.expectMsg(-OP_READ) connectionHandler.send(connectionActor, ResumeReading) - selector.expectMsg(ReadInterest) + interestCallReceiver.expectMsg(OP_READ) } } @@ -706,11 +704,14 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") serverSideChannel } - abstract class LocalServerTest { + abstract class LocalServerTest extends ChannelRegistry { val localServerChannel = ServerSocketChannel.open() val userHandler = TestProbe() val selector = TestProbe() + var registerCallReceiver = TestProbe() + var interestCallReceiver = TestProbe() + def run(body: ⇒ Unit): Unit = { try { setServerSocketOptions() @@ -720,16 +721,21 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") } finally localServerChannel.close() } + def register(channel: SelectableChannel, initialOps: Int)(implicit channelActor: ActorRef): Unit = + registerCallReceiver.ref.tell(channel -> initialOps, channelActor) + def setServerSocketOptions() = () def createConnectionActor(serverAddress: InetSocketAddress = serverAddress, options: immutable.Seq[SocketOption] = Nil): TestActorRef[TcpOutgoingConnection] = { val ref = TestActorRef( - new TcpOutgoingConnection(Tcp(system), userHandler.ref, Connect(serverAddress, options = options)) { + new TcpOutgoingConnection(Tcp(system), this, userHandler.ref, Connect(serverAddress, options = options)) { override def postRestart(reason: Throwable): Unit = context.stop(self) // ensure we never restart - override def selector = LocalServerTest.this.selector.ref }) - ref ! ChannelRegistered + ref ! new ChannelRegistration { + def enableInterest(op: Int): Unit = interestCallReceiver.ref ! op + def disableInterest(op: Int): Unit = interestCallReceiver.ref ! -op + } ref } } @@ -745,7 +751,8 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") lazy val clientSideChannel = connectionActor.underlyingActor.channel override def run(body: ⇒ Unit): Unit = super.run { - selector.expectMsg(RegisterChannel(clientSideChannel, OP_CONNECT)) + registerCallReceiver.expectMsg(clientSideChannel -> OP_CONNECT) + registerCallReceiver.sender must be(connectionActor) body } } @@ -770,7 +777,7 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress])) userHandler.send(connectionActor, Register(connectionHandler.ref, keepOpenOnPeerClosed, useResumeWriting)) - selector.expectMsg(ReadInterest) + interestCallReceiver.expectMsg(OP_READ) clientSelectionKey // trigger initialization serverSelectionKey // trigger initialization @@ -781,7 +788,7 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") } } - val TestSize = 10000 + final val TestSize = 10000 // compile-time constant def writeCmd(ack: AnyRef) = Write(ByteString(Array.fill[Byte](TestSize)(0)), ack) @@ -822,12 +829,12 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") if (remainingTries <= 0) throw new AssertionError("Pulling took too many loops, remaining data: " + remaining) else if (remaining > 0) { - if (selector.msgAvailable) { - selector.expectMsg(WriteInterest) - clientSelectionKey.interestOps(SelectionKey.OP_WRITE) + if (interestCallReceiver.msgAvailable) { + interestCallReceiver.expectMsg(OP_WRITE) + clientSelectionKey.interestOps(OP_WRITE) } - serverSelectionKey.interestOps(SelectionKey.OP_READ) + serverSelectionKey.interestOps(OP_READ) nioSelector.select(10) if (nioSelector.selectedKeys().contains(clientSelectionKey)) { clientSelectionKey.interestOps(0) diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpListenerSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpListenerSpec.scala index 6bb12241b8..44757f182e 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpListenerSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpListenerSpec.scala @@ -5,15 +5,15 @@ package akka.io import java.net.Socket -import java.nio.channels.SocketChannel +import java.nio.channels.{ SelectableChannel, SocketChannel } +import java.nio.channels.SelectionKey.OP_ACCEPT import scala.concurrent.duration._ -import akka.actor.{ Terminated, SupervisorStrategy, Actor, Props } -import akka.testkit.{ TestProbe, TestActorRef, AkkaSpec } -import Tcp._ -import akka.testkit.EventFilter -import akka.io.SelectionHandler._ +import akka.actor._ +import akka.testkit.{ TestProbe, TestActorRef, AkkaSpec, EventFilter } import akka.io.TcpListener.{ RegisterIncoming, FailedRegisterIncoming } +import akka.io.SelectionHandler._ import akka.TestUtils +import Tcp._ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { @@ -22,7 +22,10 @@ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { "register its ServerSocketChannel with its selector" in new TestSetup "let the Bind commander know when binding is completed" in new TestSetup { - listener ! ChannelRegistered + listener ! new ChannelRegistration { + def disableInterest(op: Int) = () + def enableInterest(op: Int) = () + } bindCommander.expectMsgType[Bound] } @@ -39,7 +42,7 @@ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { expectWorkerForCommand expectWorkerForCommand selectorRouter.expectNoMsg(100.millis) - parent.expectMsg(AcceptInterest) + interestCallReceiver.expectMsg(OP_ACCEPT) // and pick up the last remaining connection on the next ChannelAcceptable listener ! ChannelAcceptable @@ -53,13 +56,13 @@ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { listener ! ChannelAcceptable expectWorkerForCommand selectorRouter.expectNoMsg(100.millis) - parent.expectMsg(AcceptInterest) + interestCallReceiver.expectMsg(OP_ACCEPT) attemptConnectionToEndpoint() listener ! ChannelAcceptable expectWorkerForCommand selectorRouter.expectNoMsg(100.millis) - parent.expectMsg(AcceptInterest) + interestCallReceiver.expectMsg(OP_ACCEPT) } "react to Unbind commands by replying with Unbound and stopping itself" in new TestSetup { @@ -89,19 +92,26 @@ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { val counter = Iterator.from(0) - class TestSetup { setup ⇒ + class TestSetup { val handler = TestProbe() val handlerRef = handler.ref val bindCommander = TestProbe() val parent = TestProbe() val selectorRouter = TestProbe() val endpoint = TestUtils.temporaryServerAddress() + + var registerCallReceiver = TestProbe() + var interestCallReceiver = TestProbe() + private val parentRef = TestActorRef(new ListenerParent) - parent.expectMsgType[RegisterChannel] + registerCallReceiver.expectMsg(OP_ACCEPT) def bindListener() { - listener ! ChannelRegistered + listener ! new ChannelRegistration { + def enableInterest(op: Int): Unit = interestCallReceiver.ref ! op + def disableInterest(op: Int): Unit = interestCallReceiver.ref ! -op + } bindCommander.expectMsgType[Bound] } @@ -117,15 +127,19 @@ class TcpListenerSpec extends AkkaSpec("akka.io.tcp.batch-accept-limit = 2") { chan } - private class ListenerParent extends Actor { + private class ListenerParent extends Actor with ChannelRegistry { val listener = context.actorOf( - props = Props(new TcpListener(selectorRouter.ref, Tcp(system), bindCommander.ref, Bind(handler.ref, endpoint, 100, Nil))), + props = Props(classOf[TcpListener], selectorRouter.ref, Tcp(system), this, bindCommander.ref, + Bind(handler.ref, endpoint, 100, Nil)), name = "test-listener-" + counter.next()) parent.watch(listener) def receive: Receive = { case msg ⇒ parent.ref forward msg } override def supervisorStrategy = SupervisorStrategy.stoppingStrategy + + def register(channel: SelectableChannel, initialOps: Int)(implicit channelActor: ActorRef): Unit = + registerCallReceiver.ref.tell(initialOps, channelActor) } } diff --git a/akka-actor/src/main/scala/akka/io/IO.scala b/akka-actor/src/main/scala/akka/io/IO.scala index 22c8014f2c..f876b8ba9d 100644 --- a/akka-actor/src/main/scala/akka/io/IO.scala +++ b/akka-actor/src/main/scala/akka/io/IO.scala @@ -31,7 +31,7 @@ object IO { props = Props(classOf[SelectionHandler], selectorSettings).withRouter(RandomRouter(nrOfSelectors)), name = "selectors") - def workerForCommandHandler(pf: PartialFunction[HasFailureMessage, Props]): Receive = { + final def workerForCommandHandler(pf: PartialFunction[HasFailureMessage, ChannelRegistry ⇒ Props]): Receive = { case cmd: HasFailureMessage if pf.isDefinedAt(cmd) ⇒ selectorPool ! WorkerForCommand(cmd, sender, pf(cmd)) } } diff --git a/akka-actor/src/main/scala/akka/io/SelectionHandler.scala b/akka-actor/src/main/scala/akka/io/SelectionHandler.scala index 8ef370aa61..ce4f2a3251 100644 --- a/akka-actor/src/main/scala/akka/io/SelectionHandler.scala +++ b/akka-actor/src/main/scala/akka/io/SelectionHandler.scala @@ -4,22 +4,21 @@ package akka.io -import java.util.{ Set ⇒ JSet, Iterator ⇒ JIterator } +import java.util.{ Iterator ⇒ JIterator } import java.util.concurrent.atomic.AtomicBoolean -import java.nio.channels.{ Selector, SelectableChannel, SelectionKey, CancelledKeyException, ClosedSelectorException, ClosedChannelException } +import java.nio.channels.{ SelectableChannel, SelectionKey, CancelledKeyException } import java.nio.channels.SelectionKey._ -import java.nio.channels.spi.{ AbstractSelector, SelectorProvider } -import scala.annotation.{ tailrec, switch } -import scala.util.control.NonFatal -import scala.collection.immutable -import scala.concurrent.duration._ -import akka.actor._ +import java.nio.channels.spi.SelectorProvider import com.typesafe.config.Config +import scala.annotation.tailrec +import scala.util.control.NonFatal +import scala.concurrent.ExecutionContext +import akka.event.LoggingAdapter +import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } import akka.io.IO.HasFailureMessage import akka.util.Helpers.Requiring -import akka.event.LoggingAdapter import akka.util.SerializedSuspendableExecutionContext -import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } +import akka.actor._ abstract class SelectionHandlerSettings(config: Config) { import config._ @@ -36,25 +35,145 @@ abstract class SelectionHandlerSettings(config: Config) { val TraceLogging: Boolean = getBoolean("trace-logging") def MaxChannelsPerSelector: Int +} +private[io] trait ChannelRegistry { + def register(channel: SelectableChannel, initialOps: Int)(implicit channelActor: ActorRef) +} + +private[io] trait ChannelRegistration { + def enableInterest(op: Int) + def disableInterest(op: Int) } private[io] object SelectionHandler { - case class WorkerForCommand(apiCommand: HasFailureMessage, commander: ActorRef, childProps: Props) + case class WorkerForCommand(apiCommand: HasFailureMessage, commander: ActorRef, childProps: ChannelRegistry ⇒ Props) - case class RegisterChannel(channel: SelectableChannel, initialOps: Int) - case object ChannelRegistered case class Retry(command: WorkerForCommand, retriesLeft: Int) { require(retriesLeft >= 0) } case object ChannelConnectable case object ChannelAcceptable case object ChannelReadable case object ChannelWritable - case object AcceptInterest - case object ReadInterest - case object DisableReadInterest - case object WriteInterest + + private class ChannelRegistryImpl(executionContext: ExecutionContext, log: LoggingAdapter) extends ChannelRegistry { + private[this] val selector = SelectorProvider.provider.openSelector + private[this] val wakeUp = new AtomicBoolean(false) + + final val OP_READ_AND_WRITE = OP_READ | OP_WRITE // compile-time constant + + private[this] val select = new Task { + def tryRun(): Unit = { + if (selector.select() > 0) { // This assumes select return value == selectedKeys.size + val keys = selector.selectedKeys + val iterator = keys.iterator() + while (iterator.hasNext) { + val key = iterator.next() + if (key.isValid) { + try { + // Cache because the performance implications of calling this on different platforms are not clear + val readyOps = key.readyOps() + key.interestOps(key.interestOps & ~readyOps) // prevent immediate reselection by always clearing + val connection = key.attachment.asInstanceOf[ActorRef] + readyOps match { + case OP_READ ⇒ connection ! ChannelReadable + case OP_WRITE ⇒ connection ! ChannelWritable + case OP_READ_AND_WRITE ⇒ { connection ! ChannelWritable; connection ! ChannelReadable } + case x if (x & OP_ACCEPT) > 0 ⇒ connection ! ChannelAcceptable + case x if (x & OP_CONNECT) > 0 ⇒ connection ! ChannelConnectable + case x ⇒ log.warning("Invalid readyOps: [{}]", x) + } + } catch { + case _: CancelledKeyException ⇒ + // can be ignored because this exception is triggered when the key becomes invalid + // because `channel.close()` in `TcpConnection.postStop` is called from another thread + } + } + } + keys.clear() // we need to remove the selected keys from the set, otherwise they remain selected + } + wakeUp.set(false) + } + + override def run(): Unit = + if (selector.isOpen) + try super.run() + finally executionContext.execute(this) // re-schedule select behind all currently queued tasks + } + + executionContext.execute(select) // start selection "loop" + + def register(channel: SelectableChannel, initialOps: Int)(implicit channelActor: ActorRef): Unit = + execute { + new Task { + def tryRun(): Unit = { + val key = channel.register(selector, initialOps, channelActor) + channelActor ! new ChannelRegistration { + def enableInterest(ops: Int): Unit = enableInterestOps(key, ops) + def disableInterest(ops: Int): Unit = disableInterestOps(key, ops) + } + } + } + } + + def shutdown(): Unit = + execute { + new Task { + def tryRun(): Unit = { + // thorough 'close' of the Selector + @tailrec def closeNextChannel(it: JIterator[SelectionKey]): Unit = if (it.hasNext) { + try it.next().channel.close() catch { case NonFatal(e) ⇒ log.error(e, "Error closing channel") } + closeNextChannel(it) + } + try closeNextChannel(selector.keys.iterator) + finally selector.close() + } + } + } + + // always set the interest keys on the selector thread, + // benchmarks show that not doing so results in lock contention + private def enableInterestOps(key: SelectionKey, ops: Int): Unit = + execute { + new Task { + def tryRun(): Unit = { + val currentOps = key.interestOps + val newOps = currentOps | ops + if (newOps != currentOps) key.interestOps(newOps) + } + } + } + + private def disableInterestOps(key: SelectionKey, ops: Int): Unit = + execute { + new Task { + def tryRun(): Unit = { + val currentOps = key.interestOps + val newOps = currentOps & ~ops + if (newOps != currentOps) key.interestOps(newOps) + } + } + } + + private def execute(task: Task): Unit = { + executionContext.execute(task) + if (wakeUp.compareAndSet(false, true)) // if possible avoid syscall and trade off with LOCK CMPXCHG + selector.wakeup() + } + + // FIXME: Add possibility to signal failure of task to someone + private abstract class Task extends Runnable { + def tryRun() + def run() { + try tryRun() + catch { + case _: CancelledKeyException ⇒ // ok, can be triggered while setting interest ops + case NonFatal(e) ⇒ log.error(e, "Error during selector management task: [{}]", e) + } + } + } + } } private[io] class SelectionHandler(settings: SelectionHandlerSettings) extends Actor with ActorLogging @@ -62,45 +181,34 @@ private[io] class SelectionHandler(settings: SelectionHandlerSettings) extends A import SelectionHandler._ import settings._ - final val OP_READ_AND_WRITE = OP_READ | OP_WRITE // compile-time constant - - private val wakeUp = new AtomicBoolean(false) - @volatile var childrenKeys = immutable.HashMap.empty[String, SelectionKey] - var sequenceNumber = 0 - val selectorManagementEC = { + private[this] var sequenceNumber = 0 + private[this] var childCount = 0 + private[this] val registry = { val dispatcher = context.system.dispatchers.lookup(SelectorDispatcher) - SerializedSuspendableExecutionContext(dispatcher.throughput)(dispatcher) + new ChannelRegistryImpl(SerializedSuspendableExecutionContext(dispatcher.throughput)(dispatcher), log) } - val selector = SelectorProvider.provider.openSelector - def receive: Receive = { - case WriteInterest ⇒ execute(enableInterest(OP_WRITE, sender)) - case ReadInterest ⇒ execute(enableInterest(OP_READ, sender)) - case AcceptInterest ⇒ execute(enableInterest(OP_ACCEPT, sender)) + case cmd: WorkerForCommand ⇒ spawnChildWithCapacityProtection(cmd, SelectorAssociationRetries) - case DisableReadInterest ⇒ execute(disableInterest(OP_READ, sender)) + case Retry(cmd, retriesLeft) ⇒ spawnChildWithCapacityProtection(cmd, retriesLeft) - case cmd: WorkerForCommand ⇒ spawnChildWithCapacityProtection(cmd, SelectorAssociationRetries) - - case RegisterChannel(channel, initialOps) ⇒ execute(registerChannel(channel, sender, initialOps)) - - case Retry(cmd, retriesLeft) ⇒ spawnChildWithCapacityProtection(cmd, retriesLeft) - - case Terminated(child) ⇒ execute(unregister(child)) + case _: Terminated ⇒ childCount -= 1 } - override def postStop(): Unit = execute(terminate()) + override def postStop(): Unit = registry.shutdown() // we can never recover from failures of a connection or listener child override def supervisorStrategy = SupervisorStrategy.stoppingStrategy def spawnChildWithCapacityProtection(cmd: WorkerForCommand, retriesLeft: Int): Unit = { if (TraceLogging) log.debug("Executing [{}]", cmd) - if (MaxChannelsPerSelector == -1 || childrenKeys.size < MaxChannelsPerSelector) { + if (MaxChannelsPerSelector == -1 || childCount < MaxChannelsPerSelector) { val newName = sequenceNumber.toString sequenceNumber += 1 - context watch context.actorOf(props = cmd.childProps.withDispatcher(WorkerDispatcher), name = newName) + val child = context.actorOf(props = cmd.childProps(registry).withDispatcher(WorkerDispatcher), name = newName) + childCount += 1 + if (MaxChannelsPerSelector > 0) context.watch(child) // we don't need to watch if we aren't limited } else { if (retriesLeft >= 1) { log.warning("Rejecting [{}] with [{}] retries left, retrying...", cmd, retriesLeft) @@ -111,106 +219,4 @@ private[io] class SelectionHandler(settings: SelectionHandlerSettings) extends A } } } - - //////////////// Management Tasks scheduled via the selectorManagementEC ///////////// - - def execute(task: Task): Unit = { - selectorManagementEC.execute(task) - if (wakeUp.compareAndSet(false, true)) selector.wakeup() // Avoiding syscall and trade off with LOCK CMPXCHG - } - - def registerChannel(channel: SelectableChannel, channelActor: ActorRef, initialOps: Int): Task = - new Task { - def tryRun() { - childrenKeys = childrenKeys.updated(channelActor.path.name, channel.register(selector, initialOps, channelActor)) - channelActor ! ChannelRegistered - } - } - - // Always set the interest keys on the selector thread according to benchmark - def enableInterest(ops: Int, connection: ActorRef) = - new Task { - def tryRun() { - val key = childrenKeys(connection.path.name) - val currentOps = key.interestOps - val newOps = currentOps | ops - if (newOps != currentOps) key.interestOps(newOps) - } - } - - def disableInterest(ops: Int, connection: ActorRef) = - new Task { - def tryRun() { - val key = childrenKeys(connection.path.name) - val currentOps = key.interestOps - val newOps = currentOps & ~ops - if (newOps != currentOps) key.interestOps(newOps) - } - } - - def unregister(child: ActorRef) = - new Task { def tryRun() { childrenKeys = childrenKeys - child.path.name } } - - def terminate() = new Task { - def tryRun() { - // Thorough 'close' of the Selector - @tailrec def closeNextChannel(it: JIterator[SelectionKey]): Unit = if (it.hasNext) { - try it.next().channel.close() catch { case NonFatal(e) ⇒ log.error(e, "Error closing channel") } - closeNextChannel(it) - } - try closeNextChannel(selector.keys.iterator) finally selector.close() - } - } - - val select = new Task { - def tryRun(): Unit = { - if (selector.select() > 0) { // This assumes select return value == selectedKeys.size - val keys = selector.selectedKeys - val iterator = keys.iterator() - while (iterator.hasNext) { - val key = iterator.next() - if (key.isValid) { - try { - // Cache because the performance implications of calling this on different platforms are not clear - val readyOps = key.readyOps() - key.interestOps(key.interestOps & ~readyOps) // prevent immediate reselection by always clearing - val connection = key.attachment.asInstanceOf[ActorRef] - readyOps match { - case OP_READ ⇒ connection ! ChannelReadable - case OP_WRITE ⇒ connection ! ChannelWritable - case OP_READ_AND_WRITE ⇒ { connection ! ChannelWritable; connection ! ChannelReadable } - case x if (x & OP_ACCEPT) > 0 ⇒ connection ! ChannelAcceptable - case x if (x & OP_CONNECT) > 0 ⇒ connection ! ChannelConnectable - case x ⇒ log.warning("Invalid readyOps: [{}]", x) - } - } catch { - case _: CancelledKeyException ⇒ - // can be ignored because this exception is triggered when the key becomes invalid - // because `channel.close()` in `TcpConnection.postStop` is called from another thread - } - } - } - keys.clear() // we need to remove the selected keys from the set, otherwise they remain selected - } - - wakeUp.set(false) - // FIXME what is the appropriate error-handling here, shouldn't this task be resubmitted in case of exception? - selectorManagementEC.execute(this) // re-schedules select behind all currently queued tasks - } - } - - selectorManagementEC.execute(select) // start selection "loop" - - // FIXME: Add possibility to signal failure of task to someone - abstract class Task extends Runnable { - def tryRun() - def run() { - try tryRun() - catch { - case _: CancelledKeyException ⇒ // ok, can be triggered in `enableInterest` or `disableInterest` - case _: ClosedSelectorException ⇒ // ok, expected during shutdown - case NonFatal(e) ⇒ log.error(e, "Error during selector management task: [{}]", e) - } - } - } } diff --git a/akka-actor/src/main/scala/akka/io/TcpConnection.scala b/akka-actor/src/main/scala/akka/io/TcpConnection.scala index 16ba1ca7fe..ec8b893761 100644 --- a/akka-actor/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpConnection.scala @@ -5,6 +5,7 @@ package akka.io import java.net.InetSocketAddress +import java.nio.channels.SelectionKey._ import java.io.{ FileInputStream, IOException } import java.nio.channels.{ FileChannel, SocketChannel } import java.nio.ByteBuffer @@ -24,29 +25,25 @@ import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } * * INTERNAL API */ -private[io] abstract class TcpConnection( - val channel: SocketChannel, - val tcp: TcpExt) +private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketChannel) extends Actor with ActorLogging with RequiresMessageQueue[UnboundedMessageQueueSemantics] { + import tcp.Settings._ import tcp.bufferPool import TcpConnection._ - var pendingWrite: PendingWrite = null - - // Needed to send the ConnectionClosed message in the postStop handler. - var closedMessage: CloseInformation = null + private[this] var pendingWrite: PendingWrite = _ private[this] var peerClosed = false - private[this] var keepOpenOnPeerClosed = false + private[this] var writingSuspended = false + private[this] var interestedInResume: Option[ActorRef] = None + var closedMessage: CloseInformation = _ // for ConnectionClosed message in postStop def writePending = pendingWrite ne null - def selector = context.parent - // STATES /** connection established, waiting for registration from user handler */ - def waitingForRegistration(commander: ActorRef): Receive = { + def waitingForRegistration(registration: ChannelRegistration, commander: ActorRef): Receive = { case Register(handler, keepOpenOnPeerClosed, useResumeWriting) ⇒ // up to this point we've been watching the commander, // but since registration is now complete we only need to watch the handler from here on @@ -55,15 +52,15 @@ private[io] abstract class TcpConnection( context.watch(handler) } if (TraceLogging) log.debug("[{}] registered as connection handler", handler) - this.keepOpenOnPeerClosed = keepOpenOnPeerClosed - this.useResumeWriting = useResumeWriting - doRead(handler, None) // immediately try reading + val info = ConnectionInfo(registration, handler, keepOpenOnPeerClosed, useResumeWriting) + doRead(info, None) // immediately try reading context.setReceiveTimeout(Duration.Undefined) - context.become(connected(handler)) + context.become(connected(info)) case cmd: CloseCommand ⇒ - handleClose(commander, Some(sender), cmd.event) + val info = ConnectionInfo(registration, commander, keepOpenOnPeerClosed = false, useResumeWriting = false) + handleClose(info, Some(sender), cmd.event) case ReceiveTimeout ⇒ // after sending `Register` user should watch this actor to make sure @@ -73,52 +70,50 @@ private[io] abstract class TcpConnection( } /** normal connected state */ - def connected(handler: ActorRef): Receive = handleWriteMessages(handler) orElse { - case SuspendReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doRead(handler, None) - - case cmd: CloseCommand ⇒ handleClose(handler, Some(sender), cmd.event) - } + def connected(info: ConnectionInfo): Receive = + handleWriteMessages(info) orElse { + case SuspendReading ⇒ info.registration.disableInterest(OP_READ) + case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case ChannelReadable ⇒ doRead(info, None) + case cmd: CloseCommand ⇒ handleClose(info, Some(sender), cmd.event) + } /** the peer sent EOF first, but we may still want to send */ - def peerSentEOF(handler: ActorRef): Receive = handleWriteMessages(handler) orElse { - case cmd: CloseCommand ⇒ handleClose(handler, Some(sender), cmd.event) - } + def peerSentEOF(info: ConnectionInfo): Receive = + handleWriteMessages(info) orElse { + case cmd: CloseCommand ⇒ handleClose(info, Some(sender), cmd.event) + } /** connection is closing but a write has to be finished first */ - def closingWithPendingWrite(handler: ActorRef, closeCommander: Option[ActorRef], closedEvent: ConnectionClosed): Receive = { - case SuspendReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doRead(handler, closeCommander) + def closingWithPendingWrite(info: ConnectionInfo, closeCommander: Option[ActorRef], + closedEvent: ConnectionClosed): Receive = { + case SuspendReading ⇒ info.registration.disableInterest(OP_READ) + case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case ChannelReadable ⇒ doRead(info, closeCommander) case ChannelWritable ⇒ - doWrite(handler) + doWrite(info) if (!writePending) // writing is now finished - handleClose(handler, closeCommander, closedEvent) - case SendBufferFull(remaining) ⇒ { pendingWrite = remaining; selector ! WriteInterest } - case WriteFileFinished ⇒ { pendingWrite = null; handleClose(handler, closeCommander, closedEvent) } - case WriteFileFailed(e) ⇒ handleError(handler, e) // rethrow exception from dispatcher task + handleClose(info, closeCommander, closedEvent) + case SendBufferFull(remaining) ⇒ { pendingWrite = remaining; info.registration.enableInterest(OP_WRITE) } + case WriteFileFinished ⇒ { pendingWrite = null; handleClose(info, closeCommander, closedEvent) } + case WriteFileFailed(e) ⇒ handleError(info.handler, e) // rethrow exception from dispatcher task - case Abort ⇒ handleClose(handler, Some(sender), Aborted) + case Abort ⇒ handleClose(info, Some(sender), Aborted) } /** connection is closed on our side and we're waiting from confirmation from the other side */ - def closing(handler: ActorRef, closeCommander: Option[ActorRef]): Receive = { - case SuspendReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doRead(handler, closeCommander) - case Abort ⇒ handleClose(handler, Some(sender), Aborted) + def closing(info: ConnectionInfo, closeCommander: Option[ActorRef]): Receive = { + case SuspendReading ⇒ info.registration.disableInterest(OP_READ) + case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case ChannelReadable ⇒ doRead(info, closeCommander) + case Abort ⇒ handleClose(info, Some(sender), Aborted) } - private[this] var useResumeWriting = false - private[this] var writingSuspended = false - private[this] var interestedInResume: Option[ActorRef] = None - - def handleWriteMessages(handler: ActorRef): Receive = { + def handleWriteMessages(info: ConnectionInfo): Receive = { case ChannelWritable ⇒ if (writePending) { - doWrite(handler) + doWrite(info) if (!writePending && interestedInResume.nonEmpty) { interestedInResume.get ! WritingResumed interestedInResume = None @@ -133,7 +128,7 @@ private[io] abstract class TcpConnection( } else if (writePending) { if (TraceLogging) log.debug("Dropping write because queue is full") sender ! write.failureMessage - if (useResumeWriting) writingSuspended = true + if (info.useResumeWriting) writingSuspended = true } else write match { case Write(data, ack) if data.isEmpty ⇒ @@ -141,7 +136,7 @@ private[io] abstract class TcpConnection( case _ ⇒ pendingWrite = createWrite(write) - doWrite(handler) + doWrite(info) } case ResumeWriting ⇒ @@ -161,15 +156,16 @@ private[io] abstract class TcpConnection( else sender ! CommandFailed(ResumeWriting) } else sender ! WritingResumed - case SendBufferFull(remaining) ⇒ { pendingWrite = remaining; selector ! WriteInterest } + case SendBufferFull(remaining) ⇒ { pendingWrite = remaining; info.registration.enableInterest(OP_WRITE) } case WriteFileFinished ⇒ pendingWrite = null - case WriteFileFailed(e) ⇒ handleError(handler, e) // rethrow exception from dispatcher task + case WriteFileFailed(e) ⇒ handleError(info.handler, e) // rethrow exception from dispatcher task } // AUXILIARIES and IMPLEMENTATION /** used in subclasses to start the common machinery above once a channel is connected */ - def completeConnect(commander: ActorRef, options: immutable.Traversable[SocketOption]): Unit = { + def completeConnect(registration: ChannelRegistration, commander: ActorRef, + options: immutable.Traversable[SocketOption]): Unit = { // Turn off Nagle's algorithm by default channel.socket.setTcpNoDelay(true) options.foreach(_.afterConnect(channel.socket)) @@ -179,10 +175,10 @@ private[io] abstract class TcpConnection( channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]) context.setReceiveTimeout(RegisterTimeout) - context.become(waitingForRegistration(commander)) + context.become(waitingForRegistration(registration, commander)) } - def doRead(handler: ActorRef, closeCommander: Option[ActorRef]): Unit = { + def doRead(info: ConnectionInfo, closeCommander: Option[ActorRef]): Unit = { @tailrec def innerRead(buffer: ByteBuffer, remainingLimit: Int): ReadResult = if (remainingLimit > 0) { // never read more than the configured limit @@ -193,7 +189,7 @@ private[io] abstract class TcpConnection( buffer.flip() if (TraceLogging) log.debug("Read [{}] bytes.", readBytes) - if (readBytes > 0) handler ! Received(ByteString(buffer)) + if (readBytes > 0) info.handler ! Received(ByteString(buffer)) readBytes match { case `maxBufferSpace` ⇒ innerRead(buffer, remainingLimit - maxBufferSpace) @@ -206,65 +202,63 @@ private[io] abstract class TcpConnection( val buffer = bufferPool.acquire() try innerRead(buffer, ReceivedMessageSizeLimit) match { - case AllRead ⇒ selector ! ReadInterest + case AllRead ⇒ info.registration.enableInterest(OP_READ) case MoreDataWaiting ⇒ self ! ChannelReadable case EndOfStream if channel.socket.isOutputShutdown ⇒ if (TraceLogging) log.debug("Read returned end-of-stream, our side already closed") - doCloseConnection(handler, closeCommander, ConfirmedClosed) + doCloseConnection(info.handler, closeCommander, ConfirmedClosed) case EndOfStream ⇒ if (TraceLogging) log.debug("Read returned end-of-stream, our side not yet closed") - handleClose(handler, closeCommander, PeerClosed) + handleClose(info, closeCommander, PeerClosed) } catch { - case e: IOException ⇒ handleError(handler, e) + case e: IOException ⇒ handleError(info.handler, e) } finally bufferPool.release(buffer) } - def doWrite(handler: ActorRef): Unit = - pendingWrite = pendingWrite.doWrite(handler) + def doWrite(info: ConnectionInfo): Unit = pendingWrite = pendingWrite.doWrite(info) def closeReason = if (channel.socket.isOutputShutdown) ConfirmedClosed else PeerClosed - def handleClose(handler: ActorRef, closeCommander: Option[ActorRef], closedEvent: ConnectionClosed): Unit = closedEvent match { + def handleClose(info: ConnectionInfo, closeCommander: Option[ActorRef], + closedEvent: ConnectionClosed): Unit = closedEvent match { case Aborted ⇒ if (TraceLogging) log.debug("Got Abort command. RESETing connection.") - doCloseConnection(handler, closeCommander, closedEvent) - case PeerClosed if keepOpenOnPeerClosed ⇒ + doCloseConnection(info.handler, closeCommander, closedEvent) + case PeerClosed if info.keepOpenOnPeerClosed ⇒ // report that peer closed the connection - handler ! PeerClosed + info.handler ! PeerClosed // used to check if peer already closed its side later peerClosed = true - context.become(peerSentEOF(handler)) + context.become(peerSentEOF(info)) case _ if writePending ⇒ // finish writing first if (TraceLogging) log.debug("Got Close command but write is still pending.") - context.become(closingWithPendingWrite(handler, closeCommander, closedEvent)) + context.become(closingWithPendingWrite(info, closeCommander, closedEvent)) case ConfirmedClosed ⇒ // shutdown output and wait for confirmation if (TraceLogging) log.debug("Got ConfirmedClose command, sending FIN.") channel.socket.shutdownOutput() if (peerClosed) // if peer closed first, the socket is now fully closed - doCloseConnection(handler, closeCommander, closedEvent) - else context.become(closing(handler, closeCommander)) + doCloseConnection(info.handler, closeCommander, closedEvent) + else context.become(closing(info, closeCommander)) case _ ⇒ // close now if (TraceLogging) log.debug("Got Close command, closing connection.") - doCloseConnection(handler, closeCommander, closedEvent) + doCloseConnection(info.handler, closeCommander, closedEvent) } def doCloseConnection(handler: ActorRef, closeCommander: Option[ActorRef], closedEvent: ConnectionClosed): Unit = { if (closedEvent == Aborted) abort() else channel.close() - closedMessage = CloseInformation(Set(handler) ++ closeCommander, closedEvent) - context.stop(self) } def handleError(handler: ActorRef, exception: IOException): Nothing = { closedMessage = CloseInformation(Set(handler), ErrorClosed(extractMsg(exception))) - throw exception } + @tailrec private[this] def extractMsg(t: Throwable): String = if (t == null) "unknown" else { @@ -330,7 +324,7 @@ private[io] abstract class TcpConnection( def release(): Unit = bufferPool.release(buffer) - def doWrite(handler: ActorRef): PendingWrite = { + def doWrite(info: ConnectionInfo): PendingWrite = { @tailrec def innerWrite(pendingWrite: PendingBufferWrite): PendingWrite = { val toWrite = pendingWrite.buffer.remaining() require(toWrite != 0) @@ -342,7 +336,7 @@ private[io] abstract class TcpConnection( if (pendingWrite.hasData) if (writtenBytes == toWrite) innerWrite(nextWrite) // wrote complete buffer, try again now else { - selector ! WriteInterest + info.registration.enableInterest(OP_WRITE) nextWrite } // try again later else { // everything written @@ -355,7 +349,7 @@ private[io] abstract class TcpConnection( } try innerWrite(this) - catch { case e: IOException ⇒ handleError(handler, e) } + catch { case e: IOException ⇒ handleError(info.handler, e) } } def hasData = buffer.hasRemaining || remainingData.nonEmpty def consume(writtenBytes: Int): PendingBufferWrite = @@ -374,7 +368,7 @@ private[io] abstract class TcpConnection( fileChannel: FileChannel, alreadyWritten: Long) extends PendingWrite { - def doWrite(handler: ActorRef): PendingWrite = { + def doWrite(info: ConnectionInfo): PendingWrite = { tcp.fileIoDispatcher.execute(writeFileRunnable(this)) this } @@ -392,6 +386,7 @@ private[io] abstract class TcpConnection( def remainingBytes = write.count - alreadyWritten def currentPosition = write.position + alreadyWritten } + private[io] def writeFileRunnable(pendingWrite: PendingWriteFile): Runnable = new Runnable { def run(): Unit = try { @@ -425,9 +420,15 @@ private[io] object TcpConnection { * Used to transport information to the postStop method to notify * interested party about a connection close. */ - case class CloseInformation( - notificationsTo: Set[ActorRef], - closedEvent: Event) + case class CloseInformation(notificationsTo: Set[ActorRef], closedEvent: Event) + + /** + * Groups required connection-related data that are only available once the connection has been fully established. + */ + case class ConnectionInfo(registration: ChannelRegistration, + handler: ActorRef, + keepOpenOnPeerClosed: Boolean, + useResumeWriting: Boolean) // INTERNAL MESSAGES @@ -444,7 +445,7 @@ private[io] object TcpConnection { def ack: Any def wantsAck = !ack.isInstanceOf[NoAck] - def doWrite(handler: ActorRef): PendingWrite + def doWrite(info: ConnectionInfo): PendingWrite /** Release any open resources */ def release(): Unit diff --git a/akka-actor/src/main/scala/akka/io/TcpIncomingConnection.scala b/akka-actor/src/main/scala/akka/io/TcpIncomingConnection.scala index e1bcc0e399..a63163327a 100644 --- a/akka-actor/src/main/scala/akka/io/TcpIncomingConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpIncomingConnection.scala @@ -8,7 +8,6 @@ import java.nio.channels.SocketChannel import scala.collection.immutable import akka.actor.ActorRef import akka.io.Inet.SocketOption -import akka.io.SelectionHandler.{ ChannelRegistered, RegisterChannel } /** * An actor handling the connection state machine for an incoming, already connected @@ -16,17 +15,18 @@ import akka.io.SelectionHandler.{ ChannelRegistered, RegisterChannel } * * INTERNAL API */ -private[io] class TcpIncomingConnection(_channel: SocketChannel, - _tcp: TcpExt, - handler: ActorRef, +private[io] class TcpIncomingConnection(_tcp: TcpExt, + _channel: SocketChannel, + registry: ChannelRegistry, + bindHandler: ActorRef, options: immutable.Traversable[SocketOption]) - extends TcpConnection(_channel, _tcp) { + extends TcpConnection(_tcp, _channel) { - context.watch(handler) // sign death pact + context.watch(bindHandler) // sign death pact - context.parent ! RegisterChannel(channel, 0) + registry.register(channel, initialOps = 0) def receive = { - case ChannelRegistered ⇒ completeConnect(handler, options) + case registration: ChannelRegistration ⇒ completeConnect(registration, bindHandler, options) } } diff --git a/akka-actor/src/main/scala/akka/io/TcpListener.scala b/akka-actor/src/main/scala/akka/io/TcpListener.scala index 35db90fa25..946c63c07e 100644 --- a/akka-actor/src/main/scala/akka/io/TcpListener.scala +++ b/akka-actor/src/main/scala/akka/io/TcpListener.scala @@ -5,13 +5,13 @@ package akka.io import java.nio.channels.{ SocketChannel, SelectionKey, ServerSocketChannel } +import java.net.InetSocketAddress import scala.annotation.tailrec import scala.util.control.NonFatal import akka.actor.{ Props, ActorLogging, ActorRef, Actor } import akka.io.SelectionHandler._ import akka.io.Tcp._ import akka.io.IO.HasFailureMessage -import java.net.InetSocketAddress import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } /** @@ -30,12 +30,13 @@ private[io] object TcpListener { /** * INTERNAL API */ -private[io] class TcpListener( - val selectorRouter: ActorRef, - val tcp: TcpExt, - val bindCommander: ActorRef, - val bind: Bind) +private[io] class TcpListener(selectorRouter: ActorRef, + tcp: TcpExt, + channelRegistry: ChannelRegistry, + bindCommander: ActorRef, + bind: Bind) extends Actor with ActorLogging with RequiresMessageQueue[UnboundedMessageQueueSemantics] { + import TcpListener._ import tcp.Settings._ @@ -53,7 +54,7 @@ private[io] class TcpListener( case isa: InetSocketAddress ⇒ isa case x ⇒ throw new IllegalArgumentException(s"bound to unknown SocketAddress [$x]") } - context.parent ! RegisterChannel(channel, SelectionKey.OP_ACCEPT) + channelRegistry.register(channel, SelectionKey.OP_ACCEPT) log.debug("Successfully bound to {}", ret) ret } catch { @@ -66,14 +67,14 @@ private[io] class TcpListener( override def supervisorStrategy = IO.connectionSupervisorStrategy def receive: Receive = { - case ChannelRegistered ⇒ + case registration: ChannelRegistration ⇒ bindCommander ! Bound(channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]) - context.become(bound) + context.become(bound(registration)) } - def bound: Receive = { + def bound(registration: ChannelRegistration): Receive = { case ChannelAcceptable ⇒ - acceptAllPending(BatchAcceptLimit) + acceptAllPending(registration, BatchAcceptLimit) case FailedRegisterIncoming(socketChannel) ⇒ log.warning("Could not register incoming connection since selector capacity limit is reached, closing connection") @@ -90,7 +91,7 @@ private[io] class TcpListener( context.stop(self) } - @tailrec final def acceptAllPending(limit: Int): Unit = { + @tailrec final def acceptAllPending(registration: ChannelRegistration, limit: Int): Unit = { val socketChannel = if (limit > 0) { try channel.accept() @@ -101,9 +102,11 @@ private[io] class TcpListener( if (socketChannel != null) { log.debug("New connection accepted") socketChannel.configureBlocking(false) - selectorRouter ! WorkerForCommand(RegisterIncoming(socketChannel), self, Props(classOf[TcpIncomingConnection], socketChannel, tcp, bind.handler, bind.options)) - acceptAllPending(limit - 1) - } else context.parent ! AcceptInterest + def props(registry: ChannelRegistry) = + Props(classOf[TcpIncomingConnection], tcp, socketChannel, registry, bind.handler, bind.options) + selectorRouter ! WorkerForCommand(RegisterIncoming(socketChannel), self, props) + acceptAllPending(registration, limit - 1) + } else registration.enableInterest(SelectionKey.OP_ACCEPT) } override def postStop() { @@ -116,5 +119,4 @@ private[io] class TcpListener( case NonFatal(e) ⇒ log.error(e, "Error closing ServerSocketChannel") } } - } diff --git a/akka-actor/src/main/scala/akka/io/TcpManager.scala b/akka-actor/src/main/scala/akka/io/TcpManager.scala index 05dde136bb..f59b78b369 100644 --- a/akka-actor/src/main/scala/akka/io/TcpManager.scala +++ b/akka-actor/src/main/scala/akka/io/TcpManager.scala @@ -48,8 +48,13 @@ import akka.io.IO.SelectorBasedManager private[io] class TcpManager(tcp: TcpExt) extends SelectorBasedManager(tcp.Settings, tcp.Settings.NrOfSelectors) with ActorLogging { def receive = workerForCommandHandler { - case c: Connect ⇒ Props(classOf[TcpOutgoingConnection], tcp, sender, c) - case b: Bind ⇒ Props(classOf[TcpListener], selectorPool, tcp, sender, b) + case c: Connect ⇒ + val commander = sender + registry ⇒ Props(classOf[TcpOutgoingConnection], tcp, registry, commander, c) + + case b: Bind ⇒ + val commander = sender + registry ⇒ Props(classOf[TcpListener], selectorPool, tcp, registry, commander, b) } } diff --git a/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala b/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala index bbbe6cfc30..2e10faefa1 100644 --- a/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala @@ -19,9 +19,10 @@ import scala.collection.immutable * INTERNAL API */ private[io] class TcpOutgoingConnection(_tcp: TcpExt, + channelRegistry: ChannelRegistry, commander: ActorRef, connect: Connect) - extends TcpConnection(TcpOutgoingConnection.newSocketChannel(), _tcp) { + extends TcpConnection(_tcp, TcpOutgoingConnection.newSocketChannel()) { import connect._ @@ -29,25 +30,25 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt, localAddress.foreach(channel.socket.bind) options.foreach(_.beforeConnect(channel.socket)) - selector ! RegisterChannel(channel, SelectionKey.OP_CONNECT) + channelRegistry.register(channel, SelectionKey.OP_CONNECT) def receive: Receive = { - case ChannelRegistered ⇒ + case registration: ChannelRegistration ⇒ log.debug("Attempting connection to [{}]", remoteAddress) if (channel.connect(remoteAddress)) - completeConnect(commander, options) - else { - context.become(connecting(commander, options)) - } + completeConnect(registration, commander, options) + else + context.become(connecting(registration, commander, options)) } - def connecting(commander: ActorRef, options: immutable.Traversable[SocketOption]): Receive = { + def connecting(registration: ChannelRegistration, commander: ActorRef, + options: immutable.Traversable[SocketOption]): Receive = { case ChannelConnectable ⇒ try { val connected = channel.finishConnect() assert(connected, "Connectable channel failed to connect") log.debug("Connection established") - completeConnect(commander, options) + completeConnect(registration, commander, options) } catch { case e: IOException ⇒ if (tcp.Settings.TraceLogging) log.debug("Could not establish connection due to {}", e) @@ -55,7 +56,6 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt, throw e } } - } /** diff --git a/akka-actor/src/main/scala/akka/io/UdpConnectedManager.scala b/akka-actor/src/main/scala/akka/io/UdpConnectedManager.scala index c6e3cf1e8a..cb2f6126cf 100644 --- a/akka-actor/src/main/scala/akka/io/UdpConnectedManager.scala +++ b/akka-actor/src/main/scala/akka/io/UdpConnectedManager.scala @@ -10,12 +10,13 @@ import akka.io.UdpConnected.Connect /** * INTERNAL API */ -private[io] class UdpConnectedManager(udpConn: UdpConnectedExt) extends SelectorBasedManager(udpConn.settings, udpConn.settings.NrOfSelectors) { +private[io] class UdpConnectedManager(udpConn: UdpConnectedExt) + extends SelectorBasedManager(udpConn.settings, udpConn.settings.NrOfSelectors) { def receive = workerForCommandHandler { case c: Connect ⇒ val commander = sender - Props(classOf[UdpConnection], udpConn, commander, c) + registry ⇒ Props(classOf[UdpConnection], udpConn, registry, commander, c) } } diff --git a/akka-actor/src/main/scala/akka/io/UdpConnection.scala b/akka-actor/src/main/scala/akka/io/UdpConnection.scala index 5407d16bd7..59bacc495d 100644 --- a/akka-actor/src/main/scala/akka/io/UdpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/UdpConnection.scala @@ -3,28 +3,26 @@ */ package akka.io -import akka.actor.{ Actor, ActorLogging, ActorRef } -import akka.io.SelectionHandler._ -import akka.io.UdpConnected._ -import akka.util.ByteString import java.nio.ByteBuffer import java.nio.channels.DatagramChannel import java.nio.channels.SelectionKey._ import scala.annotation.tailrec import scala.util.control.NonFatal +import akka.actor.{ Actor, ActorLogging, ActorRef } import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } +import akka.util.ByteString +import akka.io.SelectionHandler._ +import akka.io.UdpConnected._ /** * INTERNAL API */ -private[io] class UdpConnection( - val udpConn: UdpConnectedExt, - val commander: ActorRef, - val connect: Connect) +private[io] class UdpConnection(udpConn: UdpConnectedExt, + channelRegistry: ChannelRegistry, + commander: ActorRef, + connect: Connect) extends Actor with ActorLogging with RequiresMessageQueue[UnboundedMessageQueueSemantics] { - def selector: ActorRef = context.parent - import connect._ import udpConn._ import udpConn.settings._ @@ -50,19 +48,19 @@ private[io] class UdpConnection( } datagramChannel } - selector ! RegisterChannel(channel, OP_READ) + channelRegistry.register(channel, OP_READ) log.debug("Successfully connected to [{}]", remoteAddress) def receive = { - case ChannelRegistered ⇒ + case registration: ChannelRegistration ⇒ commander ! Connected - context.become(connected, discardOld = true) + context.become(connected(registration), discardOld = true) } - def connected: Receive = { - case StopReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doRead(handler) + def connected(registration: ChannelRegistration): Receive = { + case StopReading ⇒ registration.disableInterest(OP_READ) + case ResumeReading ⇒ registration.enableInterest(OP_READ) + case ChannelReadable ⇒ doRead(registration, handler) case Close ⇒ log.debug("Closing UDP connection to [{}]", remoteAddress) @@ -81,12 +79,12 @@ private[io] class UdpConnection( case send: Send ⇒ pendingSend = (send, sender) - selector ! WriteInterest + registration.enableInterest(OP_WRITE) case ChannelWritable ⇒ doWrite() } - def doRead(handler: ActorRef): Unit = { + def doRead(registration: ChannelRegistration, handler: ActorRef): Unit = { @tailrec def innerRead(readsLeft: Int, buffer: ByteBuffer): Unit = { buffer.clear() buffer.limit(DirectBufferSize) @@ -99,13 +97,12 @@ private[io] class UdpConnection( } val buffer = bufferPool.acquire() try innerRead(BatchReceiveLimit, buffer) finally { - selector ! ReadInterest + registration.enableInterest(OP_READ) bufferPool.release(buffer) } } final def doWrite(): Unit = { - val buffer = udpConn.bufferPool.acquire() try { val (send, commander) = pendingSend @@ -122,10 +119,9 @@ private[io] class UdpConnection( udpConn.bufferPool.release(buffer) pendingSend = null } - } - override def postStop() { + override def postStop(): Unit = if (channel.isOpen) { log.debug("Closing DatagramChannel after being stopped") try channel.close() @@ -133,6 +129,4 @@ private[io] class UdpConnection( case NonFatal(e) ⇒ log.error(e, "Error closing DatagramChannel") } } - } - } diff --git a/akka-actor/src/main/scala/akka/io/UdpListener.scala b/akka-actor/src/main/scala/akka/io/UdpListener.scala index ef3cf9b3ee..667958e5b7 100644 --- a/akka-actor/src/main/scala/akka/io/UdpListener.scala +++ b/akka-actor/src/main/scala/akka/io/UdpListener.scala @@ -3,25 +3,25 @@ */ package akka.io -import akka.actor.{ ActorLogging, Actor, ActorRef } -import akka.io.SelectionHandler._ -import akka.io.Udp._ -import akka.util.ByteString import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.DatagramChannel import java.nio.channels.SelectionKey._ import scala.annotation.tailrec import scala.util.control.NonFatal -import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } +import akka.actor.{ ActorLogging, Actor, ActorRef } +import akka.dispatch.{ RequiresMessageQueue, UnboundedMessageQueueSemantics } +import akka.util.ByteString +import akka.io.SelectionHandler._ +import akka.io.Udp._ /** * INTERNAL API */ -private[io] class UdpListener( - val udp: UdpExt, - val bindCommander: ActorRef, - val bind: Bind) +private[io] class UdpListener(val udp: UdpExt, + channelRegistry: ChannelRegistry, + bindCommander: ActorRef, + bind: Bind) extends Actor with ActorLogging with WithUdpSend with RequiresMessageQueue[UnboundedMessageQueueSemantics] { import udp.bufferPool @@ -43,7 +43,7 @@ private[io] class UdpListener( case isa: InetSocketAddress ⇒ isa case x ⇒ throw new IllegalArgumentException(s"bound to unknown SocketAddress [$x]") } - context.parent ! RegisterChannel(channel, OP_READ) + channelRegistry.register(channel, OP_READ) log.debug("Successfully bound to [{}]", ret) ret } catch { @@ -54,15 +54,15 @@ private[io] class UdpListener( } def receive: Receive = { - case ChannelRegistered ⇒ + case registration: ChannelRegistration ⇒ bindCommander ! Bound(channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]) - context.become(readHandlers orElse sendHandlers, discardOld = true) + context.become(readHandlers(registration) orElse sendHandlers(registration), discardOld = true) } - def readHandlers: Receive = { - case StopReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doReceive(bind.handler) + def readHandlers(registration: ChannelRegistration): Receive = { + case StopReading ⇒ registration.disableInterest(OP_READ) + case ResumeReading ⇒ registration.enableInterest(OP_READ) + case ChannelReadable ⇒ doReceive(registration, bind.handler) case Unbind ⇒ log.debug("Unbinding endpoint [{}]", bind.localAddress) @@ -73,7 +73,7 @@ private[io] class UdpListener( } finally context.stop(self) } - def doReceive(handler: ActorRef): Unit = { + def doReceive(registration: ChannelRegistration, handler: ActorRef): Unit = { @tailrec def innerReceive(readsLeft: Int, buffer: ByteBuffer) { buffer.clear() buffer.limit(DirectBufferSize) @@ -90,11 +90,11 @@ private[io] class UdpListener( val buffer = bufferPool.acquire() try innerReceive(BatchReceiveLimit, buffer) finally { bufferPool.release(buffer) - selector ! ReadInterest + registration.enableInterest(OP_READ) } } - override def postStop() { + override def postStop(): Unit = { if (channel.isOpen) { log.debug("Closing DatagramChannel after being stopped") try channel.close() diff --git a/akka-actor/src/main/scala/akka/io/UdpManager.scala b/akka-actor/src/main/scala/akka/io/UdpManager.scala index 73f7ba95e9..a9eae918fa 100644 --- a/akka-actor/src/main/scala/akka/io/UdpManager.scala +++ b/akka-actor/src/main/scala/akka/io/UdpManager.scala @@ -49,10 +49,11 @@ private[io] class UdpManager(udp: UdpExt) extends SelectorBasedManager(udp.setti def receive = workerForCommandHandler { case b: Bind ⇒ val commander = sender - Props(classOf[UdpListener], udp, commander, b) + registry ⇒ Props(classOf[UdpListener], udp, registry, commander, b) + case SimpleSender(options) ⇒ val commander = sender - Props(classOf[UdpSender], udp, options, commander) + registry ⇒ Props(classOf[UdpSender], udp, registry, commander, options) } } diff --git a/akka-actor/src/main/scala/akka/io/UdpSender.scala b/akka-actor/src/main/scala/akka/io/UdpSender.scala index 74450c2991..921414e6e9 100644 --- a/akka-actor/src/main/scala/akka/io/UdpSender.scala +++ b/akka-actor/src/main/scala/akka/io/UdpSender.scala @@ -3,26 +3,23 @@ */ package akka.io -import akka.actor._ import java.nio.channels.DatagramChannel -import akka.io.Udp._ -import akka.io.SelectionHandler.{ ChannelRegistered, RegisterChannel } import scala.collection.immutable -import akka.io.Inet.SocketOption import scala.util.control.NonFatal -import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue } +import akka.dispatch.{ RequiresMessageQueue, UnboundedMessageQueueSemantics } +import akka.io.Inet.SocketOption +import akka.io.Udp._ +import akka.actor._ /** * INTERNAL API */ -private[io] class UdpSender( - val udp: UdpExt, - options: immutable.Traversable[SocketOption], - val commander: ActorRef) +private[io] class UdpSender(val udp: UdpExt, + channelRegistry: ChannelRegistry, + commander: ActorRef, + options: immutable.Traversable[SocketOption]) extends Actor with ActorLogging with WithUdpSend with RequiresMessageQueue[UnboundedMessageQueueSemantics] { - def selector: ActorRef = context.parent - val channel = { val datagramChannel = DatagramChannel.open datagramChannel.configureBlocking(false) @@ -32,11 +29,11 @@ private[io] class UdpSender( datagramChannel } - selector ! RegisterChannel(channel, 0) + channelRegistry.register(channel, initialOps = 0) def receive: Receive = { - case ChannelRegistered ⇒ - context.become(sendHandlers, discardOld = true) + case registration: ChannelRegistration ⇒ + context.become(sendHandlers(registration), discardOld = true) commander ! SimpleSendReady } diff --git a/akka-actor/src/main/scala/akka/io/WithUdpSend.scala b/akka-actor/src/main/scala/akka/io/WithUdpSend.scala index df39efa59c..aadfedd527 100644 --- a/akka-actor/src/main/scala/akka/io/WithUdpSend.scala +++ b/akka-actor/src/main/scala/akka/io/WithUdpSend.scala @@ -3,10 +3,10 @@ */ package akka.io +import java.nio.channels.{ SelectionKey, DatagramChannel } import akka.actor.{ ActorRef, ActorLogging, Actor } import akka.io.Udp.{ CommandFailed, Send } import akka.io.SelectionHandler._ -import java.nio.channels.DatagramChannel /** * INTERNAL API @@ -21,15 +21,13 @@ private[io] trait WithUdpSend { var retriedSend = false def hasWritePending = pendingSend ne null - def selector: ActorRef def channel: DatagramChannel def udp: UdpExt val settings = udp.settings import settings._ - def sendHandlers: Receive = { - + def sendHandlers(registration: ChannelRegistration): Receive = { case send: Send if hasWritePending ⇒ if (TraceLogging) log.debug("Dropping write because queue is full") sender ! CommandFailed(send) @@ -41,14 +39,12 @@ private[io] trait WithUdpSend { case send: Send ⇒ pendingSend = send pendingCommander = sender - doSend() - - case ChannelWritable ⇒ if (hasWritePending) doSend() + doSend(registration) + case ChannelWritable ⇒ if (hasWritePending) doSend(registration) } - final def doSend(): Unit = { - + final def doSend(registration: ChannelRegistration): Unit = { val buffer = udp.bufferPool.acquire() try { buffer.clear() @@ -65,7 +61,7 @@ private[io] trait WithUdpSend { pendingSend = null pendingCommander = null } else { - selector ! WriteInterest + registration.enableInterest(SelectionKey.OP_WRITE) retriedSend = true } } else { @@ -74,10 +70,8 @@ private[io] trait WithUdpSend { pendingSend = null pendingCommander = null } - } finally { udp.bufferPool.release(buffer) } - } }