From be9abae1e32fda97fa15713ebf83f7b2d3c5f25e Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 15 Jan 2013 18:08:45 +0100 Subject: [PATCH] tcp connection actors, see #2886 --- akka-io/src/main/resources/reference.conf | 11 +- akka-io/src/main/scala/akka/io/Tcp.scala | 87 ++-- .../main/scala/akka/io/TcpConnection.scala | 225 ++++++++++ .../scala/akka/io/TcpIncomingConnection.scala | 26 ++ .../src/main/scala/akka/io/TcpManager.scala | 4 +- .../scala/akka/io/TcpOutgoingConnection.scala | 51 +++ .../akka/io/ThreadLocalDirectBuffer.scala | 32 ++ .../scala/akka/io/TcpConnectionSpec.scala | 414 ++++++++++++++++++ 8 files changed, 804 insertions(+), 46 deletions(-) create mode 100644 akka-io/src/main/scala/akka/io/TcpConnection.scala create mode 100644 akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala create mode 100644 akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala create mode 100644 akka-io/src/main/scala/akka/io/ThreadLocalDirectBuffer.scala create mode 100644 akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala diff --git a/akka-io/src/main/resources/reference.conf b/akka-io/src/main/resources/reference.conf index 1b502d4c4f..ebcd1fa2e9 100644 --- a/akka-io/src/main/resources/reference.conf +++ b/akka-io/src/main/resources/reference.conf @@ -54,8 +54,17 @@ akka { # Fully qualified config path which holds the dispatcher configuration # for the selector management actors management-dispatcher = "akka.actor.default-dispatcher" + + # The size of the thread-local direct buffers used to read or write + # network data from the kernel. Those buffer directly add to the footprint + # of the threads from the dispatcher tcp connection actors are using. + direct-buffer-size = 524288 + + # The duration a connection actor waits for a `Register` message from + # its commander before aborting the connection. + register-timeout = 5s } } -} \ No newline at end of file +} diff --git a/akka-io/src/main/scala/akka/io/Tcp.scala b/akka-io/src/main/scala/akka/io/Tcp.scala index cd9859c2d1..9d7e754b98 100644 --- a/akka-io/src/main/scala/akka/io/Tcp.scala +++ b/akka-io/src/main/scala/akka/io/Tcp.scala @@ -31,31 +31,30 @@ object Tcp extends ExtensionKey[TcpExt] { case class Bind(handler: ActorRef, address: InetSocketAddress, backlog: Int = 100, - options: immutable.Seq[SO.SocketOption] = Nil) extends Command + options: immutable.Seq[SocketOption] = Nil) extends Command case object Unbind extends Command case class Register(handler: ActorRef) extends Command - object SO { + /** + * SocketOption is a package of data (from the user) and associated + * behavior (how to apply that to a socket). + */ + sealed trait SocketOption { /** - * SocketOption is a package of data (from the user) and associated - * behavior (how to apply that to a socket). + * Action to be taken for this option before calling bind() */ - sealed trait SocketOption { - /** - * Action to be taken for this option before calling bind() - */ - def beforeBind(s: ServerSocket): Unit = () - /** - * Action to be taken for this option before calling connect() - */ - def beforeConnect(s: Socket): Unit = () - /** - * Action to be taken for this option after connect returned (i.e. on - * the slave socket for servers). - */ - def afterConnect(s: Socket): Unit = () - } - + def beforeBind(s: ServerSocket): Unit = () + /** + * Action to be taken for this option before calling connect() + */ + def beforeConnect(s: Socket): Unit = () + /** + * Action to be taken for this option after connect returned (i.e. on + * the slave socket for servers). + */ + def afterConnect(s: Socket): Unit = () + } + object SO { // shared socket options /** @@ -109,7 +108,7 @@ object Tcp extends ExtensionKey[TcpExt] { * For more information see [[java.net.Socket.setSendBufferSize]] */ case class SendBufferSize(size: Int) extends SocketOption { - require(size > 0, "ReceiveBufferSize must be > 0") + require(size > 0, "SendBufferSize must be > 0") override def afterConnect(s: Socket): Unit = s.setSendBufferSize(size) } @@ -139,39 +138,37 @@ object Tcp extends ExtensionKey[TcpExt] { } // TODO: what about close reasons? - case object Close extends Command - case object ConfirmedClose extends Command - case object Abort extends Command + sealed trait CloseCommand extends Command - trait Write extends Command { - def data: ByteString - def ack: AnyRef - def nack: AnyRef - } + case object Close extends CloseCommand + case object ConfirmedClose extends CloseCommand + case object Abort extends CloseCommand + + case class Write(data: ByteString, ack: AnyRef) extends Command object Write { - def apply(_data: ByteString): Write = new Write { - def data: ByteString = _data - def ack: AnyRef = null - def nack: AnyRef = null - } + val Empty: Write = Write(ByteString.empty, null) + def apply(data: ByteString): Write = + if (data.isEmpty) Empty else Write(data, null) } + case object StopReading extends Command case object ResumeReading extends Command /// EVENTS sealed trait Event - case object Bound extends Event - case class Received(data: ByteString) extends Event - case class Connected(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) extends Event + case class Connected(remoteAddress: InetSocketAddress, localAddress: InetSocketAddress) extends Event case class CommandFailed(cmd: Command) extends Event + case object Bound extends Event + case object Unbound extends Event - sealed trait Closed extends Event - case object PeerClosed extends Closed - case object ActivelyClosed extends Closed - case object ConfirmedClosed extends Closed - case class Error(cause: Throwable) extends Closed + sealed trait ConnectionClosed extends Event + case object Closed extends ConnectionClosed + case object Aborted extends ConnectionClosed + case object ConfirmedClosed extends ConnectionClosed + case object PeerClosed extends ConnectionClosed + case class ErrorClose(cause: Throwable) extends ConnectionClosed /// INTERNAL case class RegisterClientChannel(channel: SocketChannel) @@ -208,9 +205,13 @@ class TcpExt(system: ExtendedActorSystem) extends IO.Extension { val SelectorDispatcher = getString("selector-dispatcher") val WorkerDispatcher = getString("worker-dispatcher") val ManagementDispatcher = getString("management-dispatcher") + val DirectBufferSize = getInt("direct-buffer-size") + val RegisterTimeout = + if (getString("register-timeout") == "infinite") Duration.Undefined + else Duration(getMilliseconds("register-timeout"), MILLISECONDS) } val manager = system.asInstanceOf[ActorSystemImpl].systemActorOf( Props[TcpManager].withDispatcher(Settings.ManagementDispatcher), "IO-TCP") -} \ No newline at end of file +} diff --git a/akka-io/src/main/scala/akka/io/TcpConnection.scala b/akka-io/src/main/scala/akka/io/TcpConnection.scala new file mode 100644 index 0000000000..707e1664de --- /dev/null +++ b/akka-io/src/main/scala/akka/io/TcpConnection.scala @@ -0,0 +1,225 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ + +package akka.io + +import java.net.InetSocketAddress +import java.io.IOException +import java.nio.channels.SocketChannel +import scala.util.control.NonFatal +import scala.collection.immutable +import scala.concurrent.duration._ +import akka.actor._ +import akka.util.ByteString +import Tcp._ + +/** + * Base class for TcpIncomingConnection and TcpOutgoingConnection. + */ +abstract class TcpConnection(val selector: ActorRef, + val channel: SocketChannel) extends Actor with ThreadLocalDirectBuffer with ActorLogging { + + channel.configureBlocking(false) + + var pendingWrite: Write = Write.Empty // a write "queue" of size 1 for holding one unfinished write command + def writePending = pendingWrite ne Write.Empty + + def registerTimeout = Tcp(context.system).Settings.RegisterTimeout + + // STATES + + /** connection established, waiting for registration from user handler */ + def waitingForRegistration(commander: ActorRef): Receive = { + case Register(handler) ⇒ + log.debug("{} registered as connection handler", handler) + selector ! ReadInterest + + context.setReceiveTimeout(Duration.Undefined) + context.watch(handler) // sign death pact + + context.become(connected(handler)) + + case cmd: CloseCommand ⇒ + handleClose(commander, closeResponse(cmd)) + + case ReceiveTimeout ⇒ + // TODO: just shutting down, as we do here, presents a race condition to the user + // Should we introduce a dedicated `Registered` event message to notify the user of successful registration? + log.warning("Configured registration timeout of {} expired, stopping", registerTimeout) + context.stop(self) + } + + /** normal connected state */ + def connected(handler: ActorRef): Receive = { + case StopReading ⇒ selector ! StopReading + case ResumeReading ⇒ selector ! ReadInterest + case ChannelReadable ⇒ doRead(handler) + + case write: Write if writePending ⇒ + log.debug("Dropping write because queue is full") + handler ! CommandFailed(write) + + case write: Write ⇒ doWrite(handler, write) + case ChannelWritable ⇒ doWrite(handler, pendingWrite) + + case cmd: CloseCommand ⇒ handleClose(handler, closeResponse(cmd)) + } + + /** connection is closing but a write has to be finished first */ + def closingWithPendingWrite(handler: ActorRef, closedEvent: ConnectionClosed): Receive = { + case StopReading ⇒ selector ! StopReading + case ResumeReading ⇒ selector ! ReadInterest + case ChannelReadable ⇒ doRead(handler) + + case ChannelWritable ⇒ + doWrite(handler, pendingWrite) + if (!writePending) // writing is now finished + handleClose(handler, closedEvent) + + case Abort ⇒ handleClose(handler, Aborted) + } + + /** connection is closed on our side and we're waiting from confirmation from the other side */ + def closing(handler: ActorRef): Receive = { + case StopReading ⇒ selector ! StopReading + case ResumeReading ⇒ selector ! ReadInterest + case ChannelReadable ⇒ doRead(handler) + case Abort ⇒ handleClose(handler, Aborted) + } + + // AUXILIARIES and IMPLEMENTATION + + /** use in subclasses to start the common machinery above once a channel is connected */ + def completeConnect(commander: ActorRef, options: immutable.Seq[SocketOption]): Unit = { + options.foreach(_.afterConnect(channel.socket)) + + commander ! Connected( + channel.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress], + channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]) + + context.setReceiveTimeout(registerTimeout) + context.become(waitingForRegistration(commander)) + } + + def doRead(handler: ActorRef): Unit = { + val buffer = directBuffer() + + try { + log.debug("Trying to read from channel") + val readBytes = channel.read(buffer) + buffer.flip() + + if (readBytes > 0) { + log.debug("Read {} bytes", readBytes) + handler ! Received(ByteString(buffer).take(readBytes)) + if (readBytes == buffer.capacity()) + // directly try reading more because we exhausted our buffer + self ! ChannelReadable + else selector ! ReadInterest + } else if (readBytes == 0) { + log.debug("Read nothing. Registering read interest with selector") + selector ! ReadInterest + } else if (readBytes == -1) { + log.debug("Read returned end-of-stream") + doCloseConnection(handler, closeReason) + } else throw new IllegalStateException("Unexpected value returned from read: " + readBytes) + + } catch { + case e: IOException ⇒ handleError(handler, e) + } + } + + def doWrite(handler: ActorRef, write: Write): Unit = { + val data = write.data + + val buffer = directBuffer() + data.copyToBuffer(buffer) + buffer.flip() + + try { + log.debug("Trying to write to channel") + val writtenBytes = channel.write(buffer) + log.debug("Wrote {} bytes", writtenBytes) + pendingWrite = consume(write, writtenBytes) + + if (writePending) selector ! WriteInterest // still data to write + else if (write.ack != null) handler ! write.ack // everything written + } catch { + case e: IOException ⇒ handleError(handler, e) + } + } + + def closeReason = + if (channel.socket.isOutputShutdown) ConfirmedClosed + else PeerClosed + + def handleClose(handler: ActorRef, closedEvent: ConnectionClosed): Unit = + if (closedEvent == Aborted) { // close instantly + log.debug("Got Abort command. RESETing connection.") + doCloseConnection(handler, closedEvent) + + } else if (writePending) { // finish writing first + log.debug("Got Close command but write is still pending.") + context.become(closingWithPendingWrite(handler, closedEvent)) + + } else if (closedEvent == ConfirmedClosed) { // shutdown output and wait for confirmation + log.debug("Got ConfirmedClose command, sending FIN.") + channel.socket.shutdownOutput() + context.become(closing(handler)) + + } else { // close now + log.debug("Got Close command, closing connection.") + doCloseConnection(handler, closedEvent) + } + + def doCloseConnection(handler: ActorRef, closedEvent: ConnectionClosed): Unit = { + if (closedEvent == Aborted) abort() + else channel.close() + + handler ! closedEvent + context.stop(self) + } + + def closeResponse(closeCommand: CloseCommand): ConnectionClosed = + closeCommand match { + case Close ⇒ Closed + case Abort ⇒ Aborted + case ConfirmedClose ⇒ ConfirmedClosed + } + + def handleError(handler: ActorRef, exception: IOException): Unit = { + exception.setStackTrace(Array.empty) + handler ! ErrorClose(exception) + throw exception + } + + def abort(): Unit = { + try channel.socket.setSoLinger(true, 0) // causes the following close() to send TCP RST + catch { + case NonFatal(e) ⇒ + // setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574 + // (also affected: OS/X Java 1.6.0_37) + log.debug("setSoLinger(true, 0) failed with {}", e) + } + channel.close() + } + + override def postStop(): Unit = + if (channel.isOpen) + abort() + + /** Returns a new write with `numBytes` removed from the front */ + def consume(write: Write, numBytes: Int): Write = + write match { + case Write.Empty if numBytes == 0 ⇒ write + case _ ⇒ + numBytes match { + case 0 ⇒ write + case x if x == write.data.length ⇒ Write.Empty + case _ ⇒ + require(numBytes > 0 && numBytes < write.data.length) + write.copy(data = write.data.drop(numBytes)) + } + } +} diff --git a/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala b/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala new file mode 100644 index 0000000000..009621f032 --- /dev/null +++ b/akka-io/src/main/scala/akka/io/TcpIncomingConnection.scala @@ -0,0 +1,26 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ + +package akka.io + +import java.nio.channels.SocketChannel +import scala.collection.immutable +import akka.actor.ActorRef +import Tcp.SocketOption + +/** + * An actor handling the connection state machine for an incoming, already connected + * SocketChannel. + */ +class TcpIncomingConnection(_selector: ActorRef, + _channel: SocketChannel, + handler: ActorRef, + options: immutable.Seq[SocketOption]) extends TcpConnection(_selector, _channel) { + + context.watch(handler) // sign death pact + + completeConnect(handler, options) + + def receive = PartialFunction.empty +} diff --git a/akka-io/src/main/scala/akka/io/TcpManager.scala b/akka-io/src/main/scala/akka/io/TcpManager.scala index 344a861580..350a70eead 100644 --- a/akka-io/src/main/scala/akka/io/TcpManager.scala +++ b/akka-io/src/main/scala/akka/io/TcpManager.scala @@ -46,8 +46,8 @@ class TcpManager extends Actor { val selectorPool = context.actorOf(Props.empty.withRouter(RandomRouter(settings.NrOfSelectors))) def receive = { - case c: Connect ⇒ selectorPool forward c - case b: Bind ⇒ selectorPool forward b + case c: Connect ⇒ selectorPool forward c + case b: Bind ⇒ selectorPool forward b case Reject(command, commander) ⇒ commander ! CommandFailed(command) } } diff --git a/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala b/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala new file mode 100644 index 0000000000..6df56c696d --- /dev/null +++ b/akka-io/src/main/scala/akka/io/TcpOutgoingConnection.scala @@ -0,0 +1,51 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ + +package akka.io + +import java.net.InetSocketAddress +import java.io.IOException +import java.nio.channels.SocketChannel +import scala.collection.immutable +import akka.actor.ActorRef +import Tcp._ + +/** + * An actor handling the connection state machine for an outgoing connection + * to be established. + */ +class TcpOutgoingConnection(_selector: ActorRef, + commander: ActorRef, + remoteAddress: InetSocketAddress, + localAddress: Option[InetSocketAddress], + options: immutable.Seq[SocketOption]) + extends TcpConnection(_selector, SocketChannel.open()) { + context.watch(commander) // sign death pact + + localAddress.foreach(channel.socket.bind) + options.foreach(_.beforeConnect(channel.socket)) + + log.debug("Attempting connection to {}", remoteAddress) + if (channel.connect(remoteAddress)) + completeConnect(commander, options) + else { + selector ! RegisterClientChannel(channel) + context.become(connecting(commander, options)) + } + + def receive: Receive = PartialFunction.empty + + def connecting(commander: ActorRef, options: immutable.Seq[SocketOption]): Receive = { + case ChannelConnectable ⇒ + try { + val connected = channel.finishConnect() + assert(connected, "Connectable channel failed to connect") + log.debug("Connection established") + completeConnect(commander, options) + } catch { + case e: IOException ⇒ handleError(commander, e) + } + } + +} diff --git a/akka-io/src/main/scala/akka/io/ThreadLocalDirectBuffer.scala b/akka-io/src/main/scala/akka/io/ThreadLocalDirectBuffer.scala new file mode 100644 index 0000000000..ab0b38510c --- /dev/null +++ b/akka-io/src/main/scala/akka/io/ThreadLocalDirectBuffer.scala @@ -0,0 +1,32 @@ +/** + * Copyright (C) 2009-2013 Typesafe Inc. + */ + +package akka.io + +import java.nio.ByteBuffer +import akka.actor.Actor + +/** + * Allows an actor to get a thread local direct buffer of a size defined in the + * configuration of the actor system. An underlying assumption is that all of + * the threads which call `getDirectBuffer` are owned by the actor system. + */ +trait ThreadLocalDirectBuffer { _: Actor ⇒ + def directBuffer(): ByteBuffer = { + val result = ThreadLocalDirectBuffer.threadLocalBuffer.get() + if (result == null) { + val size = Tcp(context.system).Settings.DirectBufferSize + val newBuffer = ByteBuffer.allocateDirect(size) + ThreadLocalDirectBuffer.threadLocalBuffer.set(newBuffer) + newBuffer + } else { + result.clear() + result + } + } +} + +object ThreadLocalDirectBuffer { + private val threadLocalBuffer = new ThreadLocal[ByteBuffer] +} diff --git a/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala new file mode 100644 index 0000000000..2a5d3690f7 --- /dev/null +++ b/akka-io/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -0,0 +1,414 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ + +package akka.io + +import scala.annotation.tailrec + +import java.nio.channels.{ SelectionKey, SocketChannel, ServerSocketChannel } +import java.nio.ByteBuffer +import java.nio.channels.spi.SelectorProvider +import java.io.IOException +import java.net._ +import scala.collection.immutable +import scala.concurrent.duration._ +import scala.util.control.NonFatal +import akka.actor.{ ActorRef, Props, Actor, Terminated } +import akka.testkit.{ TestProbe, TestActorRef, AkkaSpec } +import akka.util.ByteString +import Tcp._ +import java.util.concurrent.CountDownLatch + +class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") { + val port = 45679 + val localhost = InetAddress.getLocalHost + val serverAddress = new InetSocketAddress(localhost, port) + + "An outgoing connection" must { + // common behavior + + "set socket options before connecting" in withLocalServer() { localServer ⇒ + val userHandler = TestProbe() + val selector = TestProbe() + val connectionActor = + createConnectionActor(selector.ref, userHandler.ref, options = Vector(SO.ReuseAddress(true))) + val clientChannel = connectionActor.underlyingActor.channel + clientChannel.socket.getReuseAddress must be(true) + } + + "set socket options after connecting" in withLocalServer() { localServer ⇒ + val userHandler = TestProbe() + val selector = TestProbe() + val connectionActor = + createConnectionActor(selector.ref, userHandler.ref, options = Vector(SO.KeepAlive(true))) + val clientChannel = connectionActor.underlyingActor.channel + clientChannel.socket.getKeepAlive must be(false) // only set after connection is established + selector.send(connectionActor, ChannelConnectable) + clientChannel.socket.getKeepAlive must be(true) + } + + "send incoming data to user" in withEstablishedConnection() { setup ⇒ + import setup._ + serverSideChannel.write(ByteBuffer.wrap("testdata".getBytes("ASCII"))) + // emulate selector behavior + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsgPF(remaining) { + case Received(data) if data.decodeString("ASCII") == "testdata" ⇒ + } + // have two packets in flight before the selector notices + serverSideChannel.write(ByteBuffer.wrap("testdata2".getBytes("ASCII"))) + serverSideChannel.write(ByteBuffer.wrap("testdata3".getBytes("ASCII"))) + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsgPF(remaining) { + case Received(data) if data.decodeString("ASCII") == "testdata2testdata3" ⇒ + } + } + + "write data to network (and acknowledge)" in withEstablishedConnection() { setup ⇒ + import setup._ + serverSideChannel.configureBlocking(false) + object Ack + val write = Write(ByteString("testdata"), Ack) + val buffer = ByteBuffer.allocate(100) + serverSideChannel.read(buffer) must be(0) + + // emulate selector behavior + connectionHandler.send(connectionActor, write) + connectionHandler.expectMsg(Ack) + serverSideChannel.read(buffer) must be(8) + buffer.flip() + ByteString(buffer).take(8).decodeString("ASCII") must be("testdata") + } + + "stop writing in cases of backpressure and resume afterwards" in + withEstablishedConnection(setSmallRcvBuffer) { setup ⇒ + import setup._ + object Ack1 + object Ack2 + + //serverSideChannel.configureBlocking(false) + clientSideChannel.socket.setSendBufferSize(1024) + + // producing backpressure by sending much more than currently fits into + // our send buffer + val firstWrite = writeCmd(Ack1) + + // try to write the buffer but since the SO_SNDBUF is too small + // it will have to keep the rest of the piece and send it + // when possible + connectionHandler.send(connectionActor, firstWrite) + selector.expectMsg(WriteInterest) + + // send another write which should fail immediately + // because we don't store more than one piece in flight + val secondWrite = writeCmd(Ack2) + connectionHandler.send(connectionActor, secondWrite) + connectionHandler.expectMsg(CommandFailed(secondWrite)) + + // there will be immediately more space in the send buffer because + // some data will have been sent by now, so we assume we can write + // again, but still it can't write everything + selector.send(connectionActor, ChannelWritable) + + // both buffers should now be filled so no more writing + // is possible + setup.pullFromServerSide(TestSize) + connectionHandler.expectMsg(Ack1) + } + + "respect StopReading and ResumeReading" in withEstablishedConnection() { setup ⇒ + import setup._ + connectionHandler.send(connectionActor, StopReading) + + // the selector interprets StopReading to deregister interest + // for reading + selector.expectMsg(StopReading) + connectionHandler.send(connectionActor, ResumeReading) + selector.expectMsg(ReadInterest) + } + + "close the connection" in withEstablishedConnection(setSmallRcvBuffer) { setup ⇒ + import setup._ + + // we should test here that a pending write command is properly finished first + object Ack + // set an artificially small send buffer size so that the write is queued + // inside the connection actor + clientSideChannel.socket.setSendBufferSize(1024) + + // we send a write and a close command directly afterwards + connectionHandler.send(connectionActor, writeCmd(Ack)) + connectionHandler.send(connectionActor, Close) + + setup.pullFromServerSide(TestSize) + connectionHandler.expectMsg(Ack) + connectionHandler.expectMsg(Closed) + connectionActor.isTerminated must be(true) + + val buffer = ByteBuffer.allocate(1) + serverSideChannel.read(buffer) must be(-1) + } + + "abort the connection" in withEstablishedConnection() { setup ⇒ + import setup._ + + connectionHandler.send(connectionActor, Abort) + connectionHandler.expectMsg(Aborted) + + assertThisConnectionActorTerminated() + + val buffer = ByteBuffer.allocate(1) + val thrown = evaluating { serverSideChannel.read(buffer) } must produce[IOException] + thrown.getMessage must be("Connection reset by peer") + } + + "close the connection and confirm" in withEstablishedConnection(setSmallRcvBuffer) { setup ⇒ + import setup._ + + // we should test here that a pending write command is properly finished first + object Ack + // set an artificially small send buffer size so that the write is queued + // inside the connection actor + clientSideChannel.socket.setSendBufferSize(1024) + + // we send a write and a close command directly afterwards + connectionHandler.send(connectionActor, writeCmd(Ack)) + connectionHandler.send(connectionActor, ConfirmedClose) + + connectionHandler.expectNoMsg(100.millis) + setup.pullFromServerSide(TestSize) + connectionHandler.expectMsg(Ack) + + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectNoMsg(100.millis) // not yet + + val buffer = ByteBuffer.allocate(1) + serverSideChannel.read(buffer) must be(-1) + serverSideChannel.close() + + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsg(ConfirmedClosed) + + assertThisConnectionActorTerminated() + } + + "report when peer closed the connection" in withEstablishedConnection() { setup ⇒ + import setup._ + + serverSideChannel.close() + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsg(PeerClosed) + + assertThisConnectionActorTerminated() + } + "report when peer aborted the connection" in withEstablishedConnection() { setup ⇒ + import setup._ + + abortClose(serverSideChannel) + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsgPF(remaining) { + case ErrorClose(exc: IOException) ⇒ exc.getMessage must be("Connection reset by peer") + } + // wait a while + connectionHandler.expectNoMsg(200.millis) + + assertThisConnectionActorTerminated() + } + "report when peer closed the connection when trying to write" in withEstablishedConnection() { setup ⇒ + import setup._ + + abortClose(serverSideChannel) + connectionHandler.send(connectionActor, Write(ByteString("testdata"))) + connectionHandler.expectMsgPF(remaining) { + case ErrorClose(_: IOException) ⇒ // ok + } + + assertThisConnectionActorTerminated() + } + + // error conditions + "report failed connection attempt while not registered" in withLocalServer() { localServer ⇒ + val userHandler = TestProbe() + val selector = TestProbe() + val connectionActor = createConnectionActor(selector.ref, userHandler.ref) + val clientSideChannel = connectionActor.underlyingActor.channel + selector.expectMsg(RegisterClientChannel(clientSideChannel)) + + // close instead of accept + localServer.close() + selector.send(connectionActor, ChannelConnectable) + userHandler.expectMsgPF() { + case ErrorClose(e) ⇒ e.getMessage must be("Connection reset by peer") + } + + assertActorTerminated(connectionActor) + } + + "report failed connection attempt when target is unreachable" in { + val userHandler = TestProbe() + val selector = TestProbe() + val connectionActor = createConnectionActor(selector.ref, userHandler.ref, serverAddress = new InetSocketAddress("127.0.0.1", 63186)) + val clientSideChannel = connectionActor.underlyingActor.channel + selector.expectMsg(RegisterClientChannel(clientSideChannel)) + val sel = SelectorProvider.provider().openSelector() + val key = clientSideChannel.register(sel, SelectionKey.OP_CONNECT | SelectionKey.OP_READ) + sel.select(200) + + key.isConnectable must be(true) + selector.send(connectionActor, ChannelConnectable) + userHandler.expectMsgPF() { + case ErrorClose(e) ⇒ e.getMessage must be("Connection refused") + } + + assertActorTerminated(connectionActor) + } + + "time out when Connected isn't answered with Register" in withLocalServer() { localServer ⇒ + val userHandler = TestProbe() + val selector = TestProbe() + val connectionActor = createConnectionActor(selector.ref, userHandler.ref) + val clientSideChannel = connectionActor.underlyingActor.channel + selector.expectMsg(RegisterClientChannel(clientSideChannel)) + localServer.accept() + selector.send(connectionActor, ChannelConnectable) + userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress])) + + assertActorTerminated(connectionActor) + } + + "close the connection when user handler dies while connecting" in withLocalServer() { localServer ⇒ + val userHandler = system.actorOf(Props(new Actor { + def receive = PartialFunction.empty + })) + val selector = TestProbe() + val connectionActor = createConnectionActor(selector.ref, userHandler) + val clientSideChannel = connectionActor.underlyingActor.channel + selector.expectMsg(RegisterClientChannel(clientSideChannel)) + system.stop(userHandler) + assertActorTerminated(connectionActor) + } + + "close the connection when connection handler dies while connected" in withEstablishedConnection() { setup ⇒ + import setup._ + watch(connectionHandler.ref) + watch(connectionActor) + system.stop(connectionHandler.ref) + expectMsgType[Terminated].actor must be(connectionHandler.ref) + expectMsgType[Terminated].actor must be(connectionActor) + } + } + + def withLocalServer(setServerSocketOptions: ServerSocketChannel ⇒ Unit = _ ⇒ ())(body: ServerSocketChannel ⇒ Any): Unit = { + val localServer = ServerSocketChannel.open() + try { + setServerSocketOptions(localServer) + localServer.socket.bind(serverAddress) + localServer.configureBlocking(false) + body(localServer) + } finally localServer.close() + } + + case class Setup( + userHandler: TestProbe, + connectionHandler: TestProbe, + selector: TestProbe, + connectionActor: TestActorRef[TcpOutgoingConnection], + clientSideChannel: SocketChannel, + serverSideChannel: SocketChannel) { + + val buffer = ByteBuffer.allocate(TestSize) + @tailrec final def pullFromServerSide(remaining: Int): Unit = + if (remaining > 0) { + if (selector.msgAvailable) { + selector.expectMsg(WriteInterest) + selector.send(connectionActor, ChannelWritable) + } + buffer.clear() + val read = serverSideChannel.read(buffer) + if (read == 0) + throw new IllegalStateException("Didn't make any progress") + else if (read == -1) + throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining) + + pullFromServerSide(remaining - read) + } + + def assertThisConnectionActorTerminated(): Unit = { + assertActorTerminated(connectionActor) + clientSideChannel must not be ('open) + } + } + def withEstablishedConnection(setServerSocketOptions: ServerSocketChannel ⇒ Unit = _ ⇒ ())(body: Setup ⇒ Any): Unit = withLocalServer(setServerSocketOptions) { localServer ⇒ + val userHandler = TestProbe() + val connectionHandler = TestProbe() + val selector = TestProbe() + val connectionActor = createConnectionActor(selector.ref, userHandler.ref) + val clientSideChannel = connectionActor.underlyingActor.channel + + selector.expectMsg(RegisterClientChannel(clientSideChannel)) + + localServer.configureBlocking(true) + val serverSideChannel = localServer.accept() + + serverSideChannel must not be (null) + selector.send(connectionActor, ChannelConnectable) + userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress])) + userHandler.send(connectionActor, Register(connectionHandler.ref)) + selector.expectMsg(ReadInterest) + + body { + Setup( + userHandler, + connectionHandler, + selector, + connectionActor, + clientSideChannel, + serverSideChannel) + } + } + + val TestSize = 10000 + + def writeCmd(ack: AnyRef) = + Write(ByteString(Array.fill[Byte](TestSize)(0)), ack) + + def setSmallRcvBuffer(channel: ServerSocketChannel): Unit = + channel.socket.setReceiveBufferSize(1024) + + def createConnectionActor( + selector: ActorRef, + commander: ActorRef, + serverAddress: InetSocketAddress = serverAddress, + localAddress: Option[InetSocketAddress] = None, + options: immutable.Seq[Tcp.SocketOption] = Nil): TestActorRef[TcpOutgoingConnection] = { + + TestActorRef( + new TcpOutgoingConnection(selector, commander, serverAddress, localAddress, options) { + override def postRestart(reason: Throwable) { + // ensure we never restart + context.stop(self) + } + }) + } + + def abortClose(channel: SocketChannel): Unit = { + try channel.socket.setSoLinger(true, 0) // causes the following close() to send TCP RST + catch { + case NonFatal(e) ⇒ + // setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574 + // (also affected: OS/X Java 1.6.0_37) + log.debug("setSoLinger(true, 0) failed with {}", e) + } + channel.close() + } + + def abort(channel: SocketChannel) { + channel.socket.setSoLinger(true, 0) + channel.close() + } + def assertActorTerminated(connectionActor: TestActorRef[TcpOutgoingConnection]): Unit = { + watch(connectionActor) + expectMsgType[Terminated].actor must be(connectionActor) + } +}