fix issues discussed in the pull request

This commit is contained in:
Johannes Rudolph 2013-01-17 14:45:50 +01:00
parent e11c3fe6bb
commit 18aecef4bd
4 changed files with 131 additions and 73 deletions

View file

@ -13,26 +13,35 @@ import scala.concurrent.duration._
import akka.actor._
import akka.util.ByteString
import Tcp._
import annotation.tailrec
/**
* Base class for TcpIncomingConnection and TcpOutgoingConnection.
*/
abstract class TcpConnection(val selector: ActorRef,
val channel: SocketChannel) extends Actor with ThreadLocalDirectBuffer with ActorLogging {
val tcp = Tcp(context.system)
channel.configureBlocking(false)
var pendingWrite: Write = Write.Empty // a write "queue" of size 1 for holding one unfinished write command
var pendingWriteCommander: ActorRef = null
// Needed to send the ConnectionClosed message in the postStop handler.
// First element is the handler, second the particular close message.
var closedMessage: (ActorRef, ConnectionClosed) = null
def writePending = pendingWrite ne Write.Empty
def registerTimeout = Tcp(context.system).Settings.RegisterTimeout
def registerTimeout = tcp.Settings.RegisterTimeout
def traceLoggingEnabled = tcp.Settings.TraceLogging
// STATES
/** connection established, waiting for registration from user handler */
def waitingForRegistration(commander: ActorRef): Receive = {
case Register(handler)
log.debug("{} registered as connection handler", handler)
if (traceLoggingEnabled) log.debug("{} registered as connection handler", handler)
selector ! ReadInterest
context.setReceiveTimeout(Duration.Undefined)
@ -44,8 +53,8 @@ abstract class TcpConnection(val selector: ActorRef,
handleClose(commander, closeResponse(cmd))
case ReceiveTimeout
// TODO: just shutting down, as we do here, presents a race condition to the user
// Should we introduce a dedicated `Registered` event message to notify the user of successful registration?
// after sending `Register` user should watch this actor to make sure
// it didn't die because of the timeout
log.warning("Configured registration timeout of {} expired, stopping", registerTimeout)
context.stop(self)
}
@ -57,11 +66,18 @@ abstract class TcpConnection(val selector: ActorRef,
case ChannelReadable doRead(handler)
case write: Write if writePending
log.debug("Dropping write because queue is full")
handler ! CommandFailed(write)
if (traceLoggingEnabled) log.debug("Dropping write because queue is full")
sender ! CommandFailed(write)
case write: Write doWrite(handler, write)
case ChannelWritable doWrite(handler, pendingWrite)
case write: Write if write.data.isEmpty
if (write.wantsAck)
sender ! write.ack
case write: Write
pendingWriteCommander = sender
pendingWrite = write
doWrite(handler)
case ChannelWritable doWrite(handler)
case cmd: CloseCommand handleClose(handler, closeResponse(cmd))
}
@ -73,7 +89,7 @@ abstract class TcpConnection(val selector: ActorRef,
case ChannelReadable doRead(handler)
case ChannelWritable
doWrite(handler, pendingWrite)
doWrite(handler)
if (!writePending) // writing is now finished
handleClose(handler, closedEvent)
@ -111,17 +127,17 @@ abstract class TcpConnection(val selector: ActorRef,
buffer.flip()
if (readBytes > 0) {
log.debug("Read {} bytes", readBytes)
handler ! Received(ByteString(buffer).take(readBytes))
if (traceLoggingEnabled) log.debug("Read {} bytes", readBytes)
handler ! Received(ByteString(buffer))
if (readBytes == buffer.capacity())
// directly try reading more because we exhausted our buffer
self ! ChannelReadable
else selector ! ReadInterest
} else if (readBytes == 0) {
log.debug("Read nothing. Registering read interest with selector")
if (traceLoggingEnabled) log.debug("Read nothing. Registering read interest with selector")
selector ! ReadInterest
} else if (readBytes == -1) {
log.debug("Read returned end-of-stream")
if (traceLoggingEnabled) log.debug("Read returned end-of-stream")
doCloseConnection(handler, closeReason)
} else throw new IllegalStateException("Unexpected value returned from read: " + readBytes)
@ -130,7 +146,8 @@ abstract class TcpConnection(val selector: ActorRef,
}
}
def doWrite(handler: ActorRef, write: Write): Unit = {
def doWrite(handler: ActorRef): Unit = {
val write = pendingWrite
val data = write.data
val buffer = directBuffer()
@ -138,13 +155,15 @@ abstract class TcpConnection(val selector: ActorRef,
buffer.flip()
try {
log.debug("Trying to write to channel")
val writtenBytes = channel.write(buffer)
log.debug("Wrote {} bytes", writtenBytes)
if (traceLoggingEnabled) log.debug("Wrote {} bytes", writtenBytes)
pendingWrite = consume(write, writtenBytes)
if (writePending) selector ! WriteInterest // still data to write
else if (write.ack != null) handler ! write.ack // everything written
else if (write.wantsAck) {
pendingWriteCommander ! write.ack
pendingWriteCommander = null
} // everything written
} catch {
case e: IOException handleError(handler, e)
}
@ -156,20 +175,20 @@ abstract class TcpConnection(val selector: ActorRef,
def handleClose(handler: ActorRef, closedEvent: ConnectionClosed): Unit =
if (closedEvent == Aborted) { // close instantly
log.debug("Got Abort command. RESETing connection.")
if (traceLoggingEnabled) log.debug("Got Abort command. RESETing connection.")
doCloseConnection(handler, closedEvent)
} else if (writePending) { // finish writing first
log.debug("Got Close command but write is still pending.")
if (traceLoggingEnabled) log.debug("Got Close command but write is still pending.")
context.become(closingWithPendingWrite(handler, closedEvent))
} else if (closedEvent == ConfirmedClosed) { // shutdown output and wait for confirmation
log.debug("Got ConfirmedClose command, sending FIN.")
if (traceLoggingEnabled) log.debug("Got ConfirmedClose command, sending FIN.")
channel.socket.shutdownOutput()
context.become(closing(handler))
} else { // close now
log.debug("Got Close command, closing connection.")
if (traceLoggingEnabled) log.debug("Got Close command, closing connection.")
doCloseConnection(handler, closedEvent)
}
@ -177,7 +196,8 @@ abstract class TcpConnection(val selector: ActorRef,
if (closedEvent == Aborted) abort()
else channel.close()
handler ! closedEvent
closedMessage = (handler, closedEvent)
context.stop(self)
}
@ -189,10 +209,18 @@ abstract class TcpConnection(val selector: ActorRef,
}
def handleError(handler: ActorRef, exception: IOException): Unit = {
exception.setStackTrace(Array.empty)
handler ! ErrorClose(exception)
closedMessage = (handler, ErrorClose(extractMsg(exception)))
throw exception
}
@tailrec private[this] def extractMsg(t: Throwable): String =
if (t == null) "unknown"
else {
t.getMessage match {
case null | "" extractMsg(t.getCause)
case msg msg
}
}
def abort(): Unit = {
try channel.socket.setSoLinger(true, 0) // causes the following close() to send TCP RST
@ -200,26 +228,34 @@ abstract class TcpConnection(val selector: ActorRef,
case NonFatal(e)
// setSoLinger can fail due to http://bugs.sun.com/view_bug.do?bug_id=6799574
// (also affected: OS/X Java 1.6.0_37)
log.debug("setSoLinger(true, 0) failed with {}", e)
if (traceLoggingEnabled) log.debug("setSoLinger(true, 0) failed with {}", e)
}
channel.close()
}
override def postStop(): Unit =
override def postStop(): Unit = {
if (closedMessage != null) {
val msg = closedMessage._2
closedMessage._1 ! msg
if (writePending)
pendingWriteCommander ! msg
}
if (channel.isOpen)
abort()
}
override def postRestart(reason: Throwable): Unit =
throw new IllegalStateException("Restarting not supported for connection actors.")
/** Returns a new write with `numBytes` removed from the front */
def consume(write: Write, numBytes: Int): Write =
write match {
case Write.Empty if numBytes == 0 write
numBytes match {
case 0 write
case x if x == write.data.length Write.Empty
case _
numBytes match {
case 0 write
case x if x == write.data.length Write.Empty
case _
require(numBytes > 0 && numBytes < write.data.length)
write.copy(data = write.data.drop(numBytes))
}
require(numBytes > 0 && numBytes < write.data.length)
write.copy(data = write.data.drop(numBytes))
}
}