tcp connection actors, see #2886

This commit is contained in:
Johannes Rudolph 2013-01-15 18:08:45 +01:00
parent 284e64f7c5
commit be9abae1e3
8 changed files with 804 additions and 46 deletions

View file

@ -54,8 +54,17 @@ akka {
# Fully qualified config path which holds the dispatcher configuration
# for the selector management actors
management-dispatcher = "akka.actor.default-dispatcher"
# The size of the thread-local direct buffers used to read or write
# network data from the kernel. Those buffer directly add to the footprint
# of the threads from the dispatcher tcp connection actors are using.
direct-buffer-size = 524288
# The duration a connection actor waits for a `Register` message from
# its commander before aborting the connection.
register-timeout = 5s
}
}
}
}

View file

@ -31,31 +31,30 @@ object Tcp extends ExtensionKey[TcpExt] {
case class Bind(handler: ActorRef,
address: InetSocketAddress,
backlog: Int = 100,
options: immutable.Seq[SO.SocketOption] = Nil) extends Command
options: immutable.Seq[SocketOption] = Nil) extends Command
case object Unbind extends Command
case class Register(handler: ActorRef) extends Command
object SO {
/**
* SocketOption is a package of data (from the user) and associated
* behavior (how to apply that to a socket).
*/
sealed trait SocketOption {
/**
* SocketOption is a package of data (from the user) and associated
* behavior (how to apply that to a socket).
* Action to be taken for this option before calling bind()
*/
sealed trait SocketOption {
/**
* Action to be taken for this option before calling bind()
*/
def beforeBind(s: ServerSocket): Unit = ()
/**
* Action to be taken for this option before calling connect()
*/
def beforeConnect(s: Socket): Unit = ()
/**
* Action to be taken for this option after connect returned (i.e. on
* the slave socket for servers).
*/
def afterConnect(s: Socket): Unit = ()
}
def beforeBind(s: ServerSocket): Unit = ()
/**
* Action to be taken for this option before calling connect()
*/
def beforeConnect(s: Socket): Unit = ()
/**
* Action to be taken for this option after connect returned (i.e. on
* the slave socket for servers).
*/
def afterConnect(s: Socket): Unit = ()
}
object SO {
// shared socket options
/**
@ -109,7 +108,7 @@ object Tcp extends ExtensionKey[TcpExt] {
* For more information see [[java.net.Socket.setSendBufferSize]]
*/
case class SendBufferSize(size: Int) extends SocketOption {
require(size > 0, "ReceiveBufferSize must be > 0")
require(size > 0, "SendBufferSize must be > 0")
override def afterConnect(s: Socket): Unit = s.setSendBufferSize(size)
}
@ -139,39 +138,37 @@ object Tcp extends ExtensionKey[TcpExt] {
}
// TODO: what about close reasons?
case object Close extends Command
case object ConfirmedClose extends Command
case object Abort extends Command
sealed trait CloseCommand extends Command
trait Write extends Command {
def data: ByteString
def ack: AnyRef
def nack: AnyRef
}
case object Close extends CloseCommand
case object ConfirmedClose extends CloseCommand
case object Abort extends CloseCommand
case class Write(data: ByteString, ack: AnyRef) extends Command
object Write {
def apply(_data: ByteString): Write = new Write {
def data: ByteString = _data
def ack: AnyRef = null
def nack: AnyRef = null
}
val Empty: Write = Write(ByteString.empty, null)
def apply(data: ByteString): Write =
if (data.isEmpty) Empty else Write(data, null)
}
case object StopReading extends Command
case object ResumeReading extends Command
/// EVENTS
sealed trait Event
case object Bound extends Event
case class Received(data: ByteString) extends Event
case class Connected(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) extends Event
case class Connected(remoteAddress: InetSocketAddress, localAddress: InetSocketAddress) extends Event
case class CommandFailed(cmd: Command) extends Event
case object Bound extends Event
case object Unbound extends Event
sealed trait Closed extends Event
case object PeerClosed extends Closed
case object ActivelyClosed extends Closed
case object ConfirmedClosed extends Closed
case class Error(cause: Throwable) extends Closed
sealed trait ConnectionClosed extends Event
case object Closed extends ConnectionClosed
case object Aborted extends ConnectionClosed
case object ConfirmedClosed extends ConnectionClosed
case object PeerClosed extends ConnectionClosed
case class ErrorClose(cause: Throwable) extends ConnectionClosed
/// INTERNAL
case class RegisterClientChannel(channel: SocketChannel)
@ -208,9 +205,13 @@ class TcpExt(system: ExtendedActorSystem) extends IO.Extension {
val SelectorDispatcher = getString("selector-dispatcher")
val WorkerDispatcher = getString("worker-dispatcher")
val ManagementDispatcher = getString("management-dispatcher")
val DirectBufferSize = getInt("direct-buffer-size")
val RegisterTimeout =
if (getString("register-timeout") == "infinite") Duration.Undefined
else Duration(getMilliseconds("register-timeout"), MILLISECONDS)
}
val manager = system.asInstanceOf[ActorSystemImpl].systemActorOf(
Props[TcpManager].withDispatcher(Settings.ManagementDispatcher), "IO-TCP")
}
}

View file

@ -0,0 +1,225 @@
/**
* Copyright (C) 2009-2013 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.io
import java.net.InetSocketAddress
import java.io.IOException
import java.nio.channels.SocketChannel
import scala.util.control.NonFatal
import scala.collection.immutable
import scala.concurrent.duration._
import akka.actor._
import akka.util.ByteString
import Tcp._
/**
* Base class for TcpIncomingConnection and TcpOutgoingConnection.
*/
abstract class TcpConnection(val selector: ActorRef,
val channel: SocketChannel) extends Actor with ThreadLocalDirectBuffer with ActorLogging {
channel.configureBlocking(false)
var pendingWrite: Write = Write.Empty // a write "queue" of size 1 for holding one unfinished write command
def writePending = pendingWrite ne Write.Empty
def registerTimeout = Tcp(context.system).Settings.RegisterTimeout
// STATES
/** connection established, waiting for registration from user handler */
def waitingForRegistration(commander: ActorRef): Receive = {
case Register(handler)
log.debug("{} registered as connection handler", handler)
selector ! ReadInterest
context.setReceiveTimeout(Duration.Undefined)
context.watch(handler) // sign death pact
context.become(connected(handler))
case cmd: CloseCommand
handleClose(commander, closeResponse(cmd))
case ReceiveTimeout
// TODO: just shutting down, as we do here, presents a race condition to the user
// Should we introduce a dedicated `Registered` event message to notify the user of successful registration?
log.warning("Configured registration timeout of {} expired, stopping", registerTimeout)
context.stop(self)
}
/** normal connected state */
def connected(handler: ActorRef): Receive = {
case StopReading selector ! StopReading
case ResumeReading selector ! ReadInterest
case ChannelReadable doRead(handler)
case write: Write if writePending
log.debug("Dropping write because queue is full")
handler ! CommandFailed(write)
case write: Write doWrite(handler, write)
case ChannelWritable doWrite(handler, pendingWrite)
case cmd: CloseCommand handleClose(handler, closeResponse(cmd))
}
/** connection is closing but a write has to be finished first */
def closingWithPendingWrite(handler: ActorRef, closedEvent: ConnectionClosed): Receive = {
case StopReading selector ! StopReading
case ResumeReading selector ! ReadInterest
case ChannelReadable doRead(handler)
case ChannelWritable
doWrite(handler, pendingWrite)
if (!writePending) // writing is now finished
handleClose(handler, closedEvent)
case Abort handleClose(handler, Aborted)
}
/** connection is closed on our side and we're waiting from confirmation from the other side */
def closing(handler: ActorRef): Receive = {
case StopReading selector ! StopReading
case ResumeReading selector ! ReadInterest
case ChannelReadable doRead(handler)
case Abort handleClose(handler, Aborted)
}
// AUXILIARIES and IMPLEMENTATION
/** use in subclasses to start the common machinery above once a channel is connected */
def completeConnect(commander: ActorRef, options: immutable.Seq[SocketOption]): Unit = {
options.foreach(_.afterConnect(channel.socket))
commander ! Connected(
channel.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress],
channel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress])
context.setReceiveTimeout(registerTimeout)
context.become(waitingForRegistration(commander))
}
def doRead(handler: ActorRef): Unit = {
val buffer = directBuffer()
try {
log.debug("Trying to read from channel")
val readBytes = channel.read(buffer)
buffer.flip()
if (readBytes > 0) {
log.debug("Read {} bytes", readBytes)
handler ! Received(ByteString(buffer).take(readBytes))
if (readBytes == buffer.capacity())
// directly try reading more because we exhausted our buffer
self ! ChannelReadable
else selector ! ReadInterest
} else if (readBytes == 0) {
log.debug("Read nothing. Registering read interest with selector")
selector ! ReadInterest
} else if (readBytes == -1) {
log.debug("Read returned end-of-stream")
doCloseConnection(handler, closeReason)
} else throw new IllegalStateException("Unexpected value returned from read: " + readBytes)
} catch {
case e: IOException handleError(handler, e)
}
}
def doWrite(handler: ActorRef, write: Write): Unit = {
val data = write.data
val buffer = directBuffer()
data.copyToBuffer(buffer)
buffer.flip()
try {
log.debug("Trying to write to channel")
val writtenBytes = channel.write(buffer)
log.debug("Wrote {} bytes", writtenBytes)
pendingWrite = consume(write, writtenBytes)
if (writePending) selector ! WriteInterest // still data to write
else if (write.ack != null) handler ! write.ack // everything written
} catch {
case e: IOException handleError(handler, e)
}
}
def closeReason =
if (channel.socket.isOutputShutdown) ConfirmedClosed
else PeerClosed
def handleClose(handler: ActorRef, closedEvent: ConnectionClosed): Unit =
if (closedEvent == Aborted) { // close instantly
log.debug("Got Abort command. RESETing connection.")
doCloseConnection(handler, closedEvent)
} else if (writePending) { // finish writing first
log.debug("Got Close command but write is still pending.")
context.become(closingWithPendingWrite(handler, closedEvent))
} else if (closedEvent == ConfirmedClosed) { // shutdown output and wait for confirmation
log.debug("Got ConfirmedClose command, sending FIN.")
channel.socket.shutdownOutput()
context.become(closing(handler))
} else { // close now
log.debug("Got Close command, closing connection.")
doCloseConnection(handler, closedEvent)
}
def doCloseConnection(handler: ActorRef, closedEvent: ConnectionClosed): Unit = {
if (closedEvent == Aborted) abort()
else channel.close()
handler ! closedEvent
context.stop(self)
}
def closeResponse(closeCommand: CloseCommand): ConnectionClosed =
closeCommand match {
case Close Closed
case Abort Aborted
case ConfirmedClose ConfirmedClosed
}
def handleError(handler: ActorRef, exception: IOException): Unit = {
exception.setStackTrace(Array.empty)
handler ! ErrorClose(exception)
throw exception
}
def abort(): Unit = {
try channel.socket.setSoLinger(true, 0) // causes the following close() to send TCP RST
catch {
case NonFatal(e)
// setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574
// (also affected: OS/X Java 1.6.0_37)
log.debug("setSoLinger(true, 0) failed with {}", e)
}
channel.close()
}
override def postStop(): Unit =
if (channel.isOpen)
abort()
/** Returns a new write with `numBytes` removed from the front */
def consume(write: Write, numBytes: Int): Write =
write match {
case Write.Empty if numBytes == 0 write
case _
numBytes match {
case 0 write
case x if x == write.data.length Write.Empty
case _
require(numBytes > 0 && numBytes < write.data.length)
write.copy(data = write.data.drop(numBytes))
}
}
}

View file

@ -0,0 +1,26 @@
/**
* Copyright (C) 2009-2013 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.io
import java.nio.channels.SocketChannel
import scala.collection.immutable
import akka.actor.ActorRef
import Tcp.SocketOption
/**
* An actor handling the connection state machine for an incoming, already connected
* SocketChannel.
*/
class TcpIncomingConnection(_selector: ActorRef,
_channel: SocketChannel,
handler: ActorRef,
options: immutable.Seq[SocketOption]) extends TcpConnection(_selector, _channel) {
context.watch(handler) // sign death pact
completeConnect(handler, options)
def receive = PartialFunction.empty
}

View file

@ -46,8 +46,8 @@ class TcpManager extends Actor {
val selectorPool = context.actorOf(Props.empty.withRouter(RandomRouter(settings.NrOfSelectors)))
def receive = {
case c: Connect selectorPool forward c
case b: Bind selectorPool forward b
case c: Connect selectorPool forward c
case b: Bind selectorPool forward b
case Reject(command, commander) commander ! CommandFailed(command)
}
}

View file

@ -0,0 +1,51 @@
/**
* Copyright (C) 2009-2013 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.io
import java.net.InetSocketAddress
import java.io.IOException
import java.nio.channels.SocketChannel
import scala.collection.immutable
import akka.actor.ActorRef
import Tcp._
/**
* An actor handling the connection state machine for an outgoing connection
* to be established.
*/
class TcpOutgoingConnection(_selector: ActorRef,
commander: ActorRef,
remoteAddress: InetSocketAddress,
localAddress: Option[InetSocketAddress],
options: immutable.Seq[SocketOption])
extends TcpConnection(_selector, SocketChannel.open()) {
context.watch(commander) // sign death pact
localAddress.foreach(channel.socket.bind)
options.foreach(_.beforeConnect(channel.socket))
log.debug("Attempting connection to {}", remoteAddress)
if (channel.connect(remoteAddress))
completeConnect(commander, options)
else {
selector ! RegisterClientChannel(channel)
context.become(connecting(commander, options))
}
def receive: Receive = PartialFunction.empty
def connecting(commander: ActorRef, options: immutable.Seq[SocketOption]): Receive = {
case ChannelConnectable
try {
val connected = channel.finishConnect()
assert(connected, "Connectable channel failed to connect")
log.debug("Connection established")
completeConnect(commander, options)
} catch {
case e: IOException handleError(commander, e)
}
}
}

View file

@ -0,0 +1,32 @@
/**
* Copyright (C) 2009-2013 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.io
import java.nio.ByteBuffer
import akka.actor.Actor
/**
* Allows an actor to get a thread local direct buffer of a size defined in the
* configuration of the actor system. An underlying assumption is that all of
* the threads which call `getDirectBuffer` are owned by the actor system.
*/
trait ThreadLocalDirectBuffer { _: Actor
def directBuffer(): ByteBuffer = {
val result = ThreadLocalDirectBuffer.threadLocalBuffer.get()
if (result == null) {
val size = Tcp(context.system).Settings.DirectBufferSize
val newBuffer = ByteBuffer.allocateDirect(size)
ThreadLocalDirectBuffer.threadLocalBuffer.set(newBuffer)
newBuffer
} else {
result.clear()
result
}
}
}
object ThreadLocalDirectBuffer {
private val threadLocalBuffer = new ThreadLocal[ByteBuffer]
}

View file

@ -0,0 +1,414 @@
/**
* Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.io
import scala.annotation.tailrec
import java.nio.channels.{ SelectionKey, SocketChannel, ServerSocketChannel }
import java.nio.ByteBuffer
import java.nio.channels.spi.SelectorProvider
import java.io.IOException
import java.net._
import scala.collection.immutable
import scala.concurrent.duration._
import scala.util.control.NonFatal
import akka.actor.{ ActorRef, Props, Actor, Terminated }
import akka.testkit.{ TestProbe, TestActorRef, AkkaSpec }
import akka.util.ByteString
import Tcp._
import java.util.concurrent.CountDownLatch
class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") {
val port = 45679
val localhost = InetAddress.getLocalHost
val serverAddress = new InetSocketAddress(localhost, port)
"An outgoing connection" must {
// common behavior
"set socket options before connecting" in withLocalServer() { localServer
val userHandler = TestProbe()
val selector = TestProbe()
val connectionActor =
createConnectionActor(selector.ref, userHandler.ref, options = Vector(SO.ReuseAddress(true)))
val clientChannel = connectionActor.underlyingActor.channel
clientChannel.socket.getReuseAddress must be(true)
}
"set socket options after connecting" in withLocalServer() { localServer
val userHandler = TestProbe()
val selector = TestProbe()
val connectionActor =
createConnectionActor(selector.ref, userHandler.ref, options = Vector(SO.KeepAlive(true)))
val clientChannel = connectionActor.underlyingActor.channel
clientChannel.socket.getKeepAlive must be(false) // only set after connection is established
selector.send(connectionActor, ChannelConnectable)
clientChannel.socket.getKeepAlive must be(true)
}
"send incoming data to user" in withEstablishedConnection() { setup
import setup._
serverSideChannel.write(ByteBuffer.wrap("testdata".getBytes("ASCII")))
// emulate selector behavior
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectMsgPF(remaining) {
case Received(data) if data.decodeString("ASCII") == "testdata"
}
// have two packets in flight before the selector notices
serverSideChannel.write(ByteBuffer.wrap("testdata2".getBytes("ASCII")))
serverSideChannel.write(ByteBuffer.wrap("testdata3".getBytes("ASCII")))
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectMsgPF(remaining) {
case Received(data) if data.decodeString("ASCII") == "testdata2testdata3"
}
}
"write data to network (and acknowledge)" in withEstablishedConnection() { setup
import setup._
serverSideChannel.configureBlocking(false)
object Ack
val write = Write(ByteString("testdata"), Ack)
val buffer = ByteBuffer.allocate(100)
serverSideChannel.read(buffer) must be(0)
// emulate selector behavior
connectionHandler.send(connectionActor, write)
connectionHandler.expectMsg(Ack)
serverSideChannel.read(buffer) must be(8)
buffer.flip()
ByteString(buffer).take(8).decodeString("ASCII") must be("testdata")
}
"stop writing in cases of backpressure and resume afterwards" in
withEstablishedConnection(setSmallRcvBuffer) { setup
import setup._
object Ack1
object Ack2
//serverSideChannel.configureBlocking(false)
clientSideChannel.socket.setSendBufferSize(1024)
// producing backpressure by sending much more than currently fits into
// our send buffer
val firstWrite = writeCmd(Ack1)
// try to write the buffer but since the SO_SNDBUF is too small
// it will have to keep the rest of the piece and send it
// when possible
connectionHandler.send(connectionActor, firstWrite)
selector.expectMsg(WriteInterest)
// send another write which should fail immediately
// because we don't store more than one piece in flight
val secondWrite = writeCmd(Ack2)
connectionHandler.send(connectionActor, secondWrite)
connectionHandler.expectMsg(CommandFailed(secondWrite))
// there will be immediately more space in the send buffer because
// some data will have been sent by now, so we assume we can write
// again, but still it can't write everything
selector.send(connectionActor, ChannelWritable)
// both buffers should now be filled so no more writing
// is possible
setup.pullFromServerSide(TestSize)
connectionHandler.expectMsg(Ack1)
}
"respect StopReading and ResumeReading" in withEstablishedConnection() { setup
import setup._
connectionHandler.send(connectionActor, StopReading)
// the selector interprets StopReading to deregister interest
// for reading
selector.expectMsg(StopReading)
connectionHandler.send(connectionActor, ResumeReading)
selector.expectMsg(ReadInterest)
}
"close the connection" in withEstablishedConnection(setSmallRcvBuffer) { setup
import setup._
// we should test here that a pending write command is properly finished first
object Ack
// set an artificially small send buffer size so that the write is queued
// inside the connection actor
clientSideChannel.socket.setSendBufferSize(1024)
// we send a write and a close command directly afterwards
connectionHandler.send(connectionActor, writeCmd(Ack))
connectionHandler.send(connectionActor, Close)
setup.pullFromServerSide(TestSize)
connectionHandler.expectMsg(Ack)
connectionHandler.expectMsg(Closed)
connectionActor.isTerminated must be(true)
val buffer = ByteBuffer.allocate(1)
serverSideChannel.read(buffer) must be(-1)
}
"abort the connection" in withEstablishedConnection() { setup
import setup._
connectionHandler.send(connectionActor, Abort)
connectionHandler.expectMsg(Aborted)
assertThisConnectionActorTerminated()
val buffer = ByteBuffer.allocate(1)
val thrown = evaluating { serverSideChannel.read(buffer) } must produce[IOException]
thrown.getMessage must be("Connection reset by peer")
}
"close the connection and confirm" in withEstablishedConnection(setSmallRcvBuffer) { setup
import setup._
// we should test here that a pending write command is properly finished first
object Ack
// set an artificially small send buffer size so that the write is queued
// inside the connection actor
clientSideChannel.socket.setSendBufferSize(1024)
// we send a write and a close command directly afterwards
connectionHandler.send(connectionActor, writeCmd(Ack))
connectionHandler.send(connectionActor, ConfirmedClose)
connectionHandler.expectNoMsg(100.millis)
setup.pullFromServerSide(TestSize)
connectionHandler.expectMsg(Ack)
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectNoMsg(100.millis) // not yet
val buffer = ByteBuffer.allocate(1)
serverSideChannel.read(buffer) must be(-1)
serverSideChannel.close()
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectMsg(ConfirmedClosed)
assertThisConnectionActorTerminated()
}
"report when peer closed the connection" in withEstablishedConnection() { setup
import setup._
serverSideChannel.close()
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectMsg(PeerClosed)
assertThisConnectionActorTerminated()
}
"report when peer aborted the connection" in withEstablishedConnection() { setup
import setup._
abortClose(serverSideChannel)
selector.send(connectionActor, ChannelReadable)
connectionHandler.expectMsgPF(remaining) {
case ErrorClose(exc: IOException) exc.getMessage must be("Connection reset by peer")
}
// wait a while
connectionHandler.expectNoMsg(200.millis)
assertThisConnectionActorTerminated()
}
"report when peer closed the connection when trying to write" in withEstablishedConnection() { setup
import setup._
abortClose(serverSideChannel)
connectionHandler.send(connectionActor, Write(ByteString("testdata")))
connectionHandler.expectMsgPF(remaining) {
case ErrorClose(_: IOException) // ok
}
assertThisConnectionActorTerminated()
}
// error conditions
"report failed connection attempt while not registered" in withLocalServer() { localServer
val userHandler = TestProbe()
val selector = TestProbe()
val connectionActor = createConnectionActor(selector.ref, userHandler.ref)
val clientSideChannel = connectionActor.underlyingActor.channel
selector.expectMsg(RegisterClientChannel(clientSideChannel))
// close instead of accept
localServer.close()
selector.send(connectionActor, ChannelConnectable)
userHandler.expectMsgPF() {
case ErrorClose(e) e.getMessage must be("Connection reset by peer")
}
assertActorTerminated(connectionActor)
}
"report failed connection attempt when target is unreachable" in {
val userHandler = TestProbe()
val selector = TestProbe()
val connectionActor = createConnectionActor(selector.ref, userHandler.ref, serverAddress = new InetSocketAddress("127.0.0.1", 63186))
val clientSideChannel = connectionActor.underlyingActor.channel
selector.expectMsg(RegisterClientChannel(clientSideChannel))
val sel = SelectorProvider.provider().openSelector()
val key = clientSideChannel.register(sel, SelectionKey.OP_CONNECT | SelectionKey.OP_READ)
sel.select(200)
key.isConnectable must be(true)
selector.send(connectionActor, ChannelConnectable)
userHandler.expectMsgPF() {
case ErrorClose(e) e.getMessage must be("Connection refused")
}
assertActorTerminated(connectionActor)
}
"time out when Connected isn't answered with Register" in withLocalServer() { localServer
val userHandler = TestProbe()
val selector = TestProbe()
val connectionActor = createConnectionActor(selector.ref, userHandler.ref)
val clientSideChannel = connectionActor.underlyingActor.channel
selector.expectMsg(RegisterClientChannel(clientSideChannel))
localServer.accept()
selector.send(connectionActor, ChannelConnectable)
userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]))
assertActorTerminated(connectionActor)
}
"close the connection when user handler dies while connecting" in withLocalServer() { localServer
val userHandler = system.actorOf(Props(new Actor {
def receive = PartialFunction.empty
}))
val selector = TestProbe()
val connectionActor = createConnectionActor(selector.ref, userHandler)
val clientSideChannel = connectionActor.underlyingActor.channel
selector.expectMsg(RegisterClientChannel(clientSideChannel))
system.stop(userHandler)
assertActorTerminated(connectionActor)
}
"close the connection when connection handler dies while connected" in withEstablishedConnection() { setup
import setup._
watch(connectionHandler.ref)
watch(connectionActor)
system.stop(connectionHandler.ref)
expectMsgType[Terminated].actor must be(connectionHandler.ref)
expectMsgType[Terminated].actor must be(connectionActor)
}
}
def withLocalServer(setServerSocketOptions: ServerSocketChannel Unit = _ ())(body: ServerSocketChannel Any): Unit = {
val localServer = ServerSocketChannel.open()
try {
setServerSocketOptions(localServer)
localServer.socket.bind(serverAddress)
localServer.configureBlocking(false)
body(localServer)
} finally localServer.close()
}
case class Setup(
userHandler: TestProbe,
connectionHandler: TestProbe,
selector: TestProbe,
connectionActor: TestActorRef[TcpOutgoingConnection],
clientSideChannel: SocketChannel,
serverSideChannel: SocketChannel) {
val buffer = ByteBuffer.allocate(TestSize)
@tailrec final def pullFromServerSide(remaining: Int): Unit =
if (remaining > 0) {
if (selector.msgAvailable) {
selector.expectMsg(WriteInterest)
selector.send(connectionActor, ChannelWritable)
}
buffer.clear()
val read = serverSideChannel.read(buffer)
if (read == 0)
throw new IllegalStateException("Didn't make any progress")
else if (read == -1)
throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining)
pullFromServerSide(remaining - read)
}
def assertThisConnectionActorTerminated(): Unit = {
assertActorTerminated(connectionActor)
clientSideChannel must not be ('open)
}
}
def withEstablishedConnection(setServerSocketOptions: ServerSocketChannel Unit = _ ())(body: Setup Any): Unit = withLocalServer(setServerSocketOptions) { localServer
val userHandler = TestProbe()
val connectionHandler = TestProbe()
val selector = TestProbe()
val connectionActor = createConnectionActor(selector.ref, userHandler.ref)
val clientSideChannel = connectionActor.underlyingActor.channel
selector.expectMsg(RegisterClientChannel(clientSideChannel))
localServer.configureBlocking(true)
val serverSideChannel = localServer.accept()
serverSideChannel must not be (null)
selector.send(connectionActor, ChannelConnectable)
userHandler.expectMsg(Connected(serverAddress, clientSideChannel.socket.getLocalSocketAddress.asInstanceOf[InetSocketAddress]))
userHandler.send(connectionActor, Register(connectionHandler.ref))
selector.expectMsg(ReadInterest)
body {
Setup(
userHandler,
connectionHandler,
selector,
connectionActor,
clientSideChannel,
serverSideChannel)
}
}
val TestSize = 10000
def writeCmd(ack: AnyRef) =
Write(ByteString(Array.fill[Byte](TestSize)(0)), ack)
def setSmallRcvBuffer(channel: ServerSocketChannel): Unit =
channel.socket.setReceiveBufferSize(1024)
def createConnectionActor(
selector: ActorRef,
commander: ActorRef,
serverAddress: InetSocketAddress = serverAddress,
localAddress: Option[InetSocketAddress] = None,
options: immutable.Seq[Tcp.SocketOption] = Nil): TestActorRef[TcpOutgoingConnection] = {
TestActorRef(
new TcpOutgoingConnection(selector, commander, serverAddress, localAddress, options) {
override def postRestart(reason: Throwable) {
// ensure we never restart
context.stop(self)
}
})
}
def abortClose(channel: SocketChannel): Unit = {
try channel.socket.setSoLinger(true, 0) // causes the following close() to send TCP RST
catch {
case NonFatal(e)
// setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574
// (also affected: OS/X Java 1.6.0_37)
log.debug("setSoLinger(true, 0) failed with {}", e)
}
channel.close()
}
def abort(channel: SocketChannel) {
channel.socket.setSoLinger(true, 0)
channel.close()
}
def assertActorTerminated(connectionActor: TestActorRef[TcpOutgoingConnection]): Unit = {
watch(connectionActor)
expectMsgType[Terminated].actor must be(connectionActor)
}
}