diff --git a/akka-actor/src/main/scala/akka/dispatch/Mailbox.scala b/akka-actor/src/main/scala/akka/dispatch/Mailbox.scala index e7dc0682d2..e055c4b327 100644 --- a/akka-actor/src/main/scala/akka/dispatch/Mailbox.scala +++ b/akka-actor/src/main/scala/akka/dispatch/Mailbox.scala @@ -388,12 +388,15 @@ private[akka] trait DefaultSystemMessageQueue { self: Mailbox ⇒ } @tailrec - final def systemDrain(newContents: SystemMessage): SystemMessage = { - val head = systemQueueGet - if (systemQueuePut(head, newContents)) SystemMessage.reverse(head) else systemDrain(newContents) + final def systemDrain(newContents: SystemMessage): SystemMessage = systemQueueGet match { + case NoMessage ⇒ null + case head ⇒ if (systemQueuePut(head, newContents)) SystemMessage.reverse(head) else systemDrain(newContents) } - def hasSystemMessages: Boolean = systemQueueGet ne null + def hasSystemMessages: Boolean = systemQueueGet match { + case null | NoMessage ⇒ false + case _ ⇒ true + } } diff --git a/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala b/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala index 5ecdcf3948..cd026e1902 100644 --- a/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala +++ b/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala @@ -7,7 +7,6 @@ import language.postfixOps import java.lang.ref.WeakReference import java.util.concurrent.locks.ReentrantLock -import java.util.LinkedList import scala.annotation.tailrec import com.typesafe.config.Config import akka.actor.{ ActorInitializationException, ExtensionIdProvider, ExtensionId, Extension, ExtendedActorSystem, ActorRef, ActorCell } @@ -15,25 +14,24 @@ import akka.dispatch.{ MessageQueue, MailboxType, TaskInvocation, SystemMessage, import scala.concurrent.duration._ import akka.util.Switch import scala.concurrent.duration.Duration -import scala.concurrent.Awaitable -import akka.actor.ActorContext import scala.util.control.NonFatal +import java.util.concurrent.TimeUnit /* * Locking rules: * - * While not suspendSwitch, messages are processed (!isActive) or queued - * thread-locally (isActive). While suspendSwitch, messages are queued - * thread-locally. When resuming, all messages are atomically scooped from all - * non-active threads and queued on the resuming thread's queue, to be - * processed immediately. Processing a queue checks suspend before each - * invocation, leaving the active state if suspendSwitch. For this to work - * reliably, the active flag needs to be set atomically with the initial check - * for suspend. Scooping up messages means replacing the ThreadLocal's contents - * with an empty new NestingQueue. + * Normal messages are always queued thread locally. + * Processing a queue checks suspendSwitch before each invocation, not processing + * if suspendSwitch. + * When resuming an actor, all messages are atomically scooped from all threads and + * queued on the resuming thread's queue, to be processed immediately. + * Scooping up messages means replacing the ThreadLocal contents with an empty + * new MessageQueue. * * All accesses to the queue must be done under the suspendSwitch-switch's lock, so * within one of its methods taking a closure argument. + * + * System messages always go directly to the actors SystemMessageQueue which isn't thread local. */ private[testkit] object CallingThreadDispatcherQueues extends ExtensionId[CallingThreadDispatcherQueues] with ExtensionIdProvider { @@ -45,19 +43,19 @@ private[testkit] class CallingThreadDispatcherQueues extends Extension { // PRIVATE DATA - private var queues = Map[CallingThreadMailbox, Set[WeakReference[NestingQueue]]]() + private var queues = Map[CallingThreadMailbox, Set[WeakReference[MessageQueue]]]() private var lastGC = 0l // we have to forget about long-gone threads sometime private def gc { - queues = (Map.newBuilder[CallingThreadMailbox, Set[WeakReference[NestingQueue]]] /: queues) { + queues = (Map.newBuilder[CallingThreadMailbox, Set[WeakReference[MessageQueue]]] /: queues) { case (m, (k, v)) ⇒ val nv = v filter (_.get ne null) if (nv.isEmpty) m else m += (k -> nv) }.result } - protected[akka] def registerQueue(mbox: CallingThreadMailbox, q: NestingQueue): Unit = synchronized { + protected[akka] def registerQueue(mbox: CallingThreadMailbox, q: MessageQueue): Unit = synchronized { if (queues contains mbox) { val newSet = queues(mbox) + new WeakReference(q) queues += mbox -> newSet @@ -80,8 +78,7 @@ private[testkit] class CallingThreadDispatcherQueues extends Extension { * given mailbox. When this method returns, the queue will be entered * (active). */ - protected[akka] def gatherFromAllOtherQueues(mbox: CallingThreadMailbox, own: NestingQueue): Unit = synchronized { - if (!own.isActive) own.enter + protected[akka] def gatherFromAllOtherQueues(mbox: CallingThreadMailbox, own: MessageQueue): Unit = synchronized { if (queues contains mbox) { for { ref ← queues(mbox) @@ -89,11 +86,11 @@ private[testkit] class CallingThreadDispatcherQueues extends Extension { if (q ne null) && (q ne own) } { val owner = mbox.actor.self - var msg = q.q.dequeue() + var msg = q.dequeue() while (msg ne null) { // this is safe because this method is only ever called while holding the suspendSwitch monitor - own.q.enqueue(owner, msg) - msg = q.q.dequeue() + own.enqueue(owner, msg) + msg = q.dequeue() } } } @@ -153,7 +150,6 @@ class CallingThreadDispatcher( actor.mailbox match { case mbox: CallingThreadMailbox ⇒ val queue = mbox.queue - queue.enter runQueue(mbox, queue) case x ⇒ throw ActorInitializationException("expected CallingThreadMailbox, got " + x.getClass) } @@ -181,14 +177,12 @@ class CallingThreadDispatcher( actor.mailbox match { case mbox: CallingThreadMailbox ⇒ val queue = mbox.queue - val wasActive = queue.isActive val switched = mbox.suspendSwitch.switchOff { CallingThreadDispatcherQueues(actor.system).gatherFromAllOtherQueues(mbox, queue) mbox.resume() } - if (switched && !wasActive) { + if (switched) runQueue(mbox, queue) - } case m ⇒ m.systemEnqueue(actor.self, Resume(causedByFailure = null)) } } @@ -197,11 +191,7 @@ class CallingThreadDispatcher( receiver.mailbox match { case mbox: CallingThreadMailbox ⇒ mbox.systemEnqueue(receiver.self, message) - val queue = mbox.queue - if (!queue.isActive) { - queue.enter - runQueue(mbox, queue) - } + runQueue(mbox, mbox.queue) case m ⇒ m.systemEnqueue(receiver.self, message) } } @@ -211,14 +201,11 @@ class CallingThreadDispatcher( case mbox: CallingThreadMailbox ⇒ val queue = mbox.queue val execute = mbox.suspendSwitch.fold { - queue.q.enqueue(receiver.self, handle) + queue.enqueue(receiver.self, handle) false } { - queue.q.enqueue(receiver.self, handle) - if (!queue.isActive) { - queue.enter - true - } else false + queue.enqueue(receiver.self, handle) + true } if (execute) runQueue(mbox, queue) case m ⇒ m.enqueue(receiver.self, handle) @@ -228,64 +215,86 @@ class CallingThreadDispatcher( protected[akka] override def executeTask(invocation: TaskInvocation) { invocation.run } /* - * This method must be called with this thread's queue, which must already - * have been entered (active). When this method returns, the queue will be - * inactive. + * This method must be called with this thread's queue. * * If the catch block is executed, then a non-empty mailbox may be stalled as * there is no-one who cares to execute it before the next message is sent or * it is suspendSwitch and resumed. */ @tailrec - private def runQueue(mbox: CallingThreadMailbox, queue: NestingQueue, interruptedex: InterruptedException = null) { - var intex = interruptedex; - assert(queue.isActive) - mbox.ctdLock.lock - val recurse = try { - mbox.processAllSystemMessages() - val handle = mbox.suspendSwitch.fold[Envelope] { - queue.leave - null - } { - val ret = if (mbox.isClosed) null else queue.q.dequeue() - if (ret eq null) queue.leave - ret - } - if (handle ne null) { - try { - if (Mailbox.debug) println(mbox.actor.self + " processing message " + handle) - mbox.actor.invoke(handle) - if (Thread.interrupted()) { // clear interrupted flag before we continue - intex = new InterruptedException("Interrupted during message processing") - log.error(intex, "Interrupted during message processing") - } - true - } catch { - case ie: InterruptedException ⇒ - log.error(ie, "Interrupted during message processing") - Thread.interrupted() // clear interrupted flag before continuing - intex = ie - true - case NonFatal(e) ⇒ - log.error(e, "Error during message processing") - queue.leave - false - } - } else if (queue.isActive) { - queue.leave - false - } else false - } catch { - case NonFatal(e) ⇒ queue.leave; throw e - } finally { - mbox.ctdLock.unlock + private def runQueue(mbox: CallingThreadMailbox, queue: MessageQueue, interruptedEx: InterruptedException = null) { + def checkThreadInterruption(intEx: InterruptedException): InterruptedException = { + if (Thread.interrupted()) { // clear interrupted flag before we continue, exception will be thrown later + val ie = new InterruptedException("Interrupted during message processing") + log.error(ie, "Interrupted during message processing") + ie + } else intEx } - if (recurse) { - runQueue(mbox, queue, intex) - } else { - if (intex ne null) { + + def throwInterruptionIfExistsOrSet(intEx: InterruptedException): Unit = { + val ie = checkThreadInterruption(intEx) + if (ie ne null) { Thread.interrupted() // clear interrupted flag before throwing according to java convention - throw intex + throw ie + } + } + + @tailrec + def process(intEx: InterruptedException): InterruptedException = { + var intex = intEx + val recurse = { + mbox.processAllSystemMessages() + val handle = mbox.suspendSwitch.fold[Envelope](null) { + if (mbox.isClosed) null else queue.dequeue() + } + if (handle ne null) { + try { + if (Mailbox.debug) println(mbox.actor.self + " processing message " + handle) + mbox.actor.invoke(handle) + intex = checkThreadInterruption(intex) + true + } catch { + case ie: InterruptedException ⇒ + log.error(ie, "Interrupted during message processing") + Thread.interrupted() // clear interrupted flag before we continue, exception will be thrown later + intex = ie + true + case NonFatal(e) ⇒ + log.error(e, "Error during message processing") + false + } + } else false + } + if (recurse) process(intex) + else intex + } + + // if we own the lock then we shouldn't do anything since we are processing + // this actors mailbox at some other level on our call stack + if (!mbox.ctdLock.isHeldByCurrentThread) { + var intex = interruptedEx + val gotLock = try { + mbox.ctdLock.tryLock(50, TimeUnit.MILLISECONDS) + } catch { + case ie: InterruptedException ⇒ + Thread.interrupted() // clear interrupted flag before we continue, exception will be thrown later + intex = ie + false + } + if (gotLock) { + val ie = try { + process(intex) + } finally { + mbox.ctdLock.unlock + } + throwInterruptionIfExistsOrSet(ie) + } else { + // if we didn't get the lock and our mailbox still has messages, then we need to try again + if (mbox.hasSystemMessages || mbox.hasMessages) { + runQueue(mbox, queue, intex) + } else { + throwInterruptionIfExistsOrSet(intex) + } } } } @@ -299,31 +308,23 @@ class CallingThreadDispatcherConfigurator(config: Config, prerequisites: Dispatc override def dispatcher(): MessageDispatcher = instance } -class NestingQueue(val q: MessageQueue) { - @volatile - private var active = false - def enter { if (active) sys.error("already active") else active = true } - def leave { if (!active) sys.error("not active") else active = false } - def isActive = active -} - class CallingThreadMailbox(_receiver: akka.actor.Cell, val mailboxType: MailboxType) extends Mailbox(null) with DefaultSystemMessageQueue { val system = _receiver.system val self = _receiver.self - private val q = new ThreadLocal[NestingQueue]() { + private val q = new ThreadLocal[MessageQueue]() { override def initialValue = { - val queue = new NestingQueue(mailboxType.create(Some(self), Some(system))) + val queue = mailboxType.create(Some(self), Some(system)) CallingThreadDispatcherQueues(system).registerQueue(CallingThreadMailbox.this, queue) queue } } - override def enqueue(receiver: ActorRef, msg: Envelope): Unit = q.get.q.enqueue(receiver, msg) + override def enqueue(receiver: ActorRef, msg: Envelope): Unit = q.get.enqueue(receiver, msg) override def dequeue(): Envelope = throw new UnsupportedOperationException("CallingThreadMailbox cannot dequeue normally") - override def hasMessages: Boolean = q.get.q.hasMessages + override def hasMessages: Boolean = q.get.hasMessages override def numberOfMessages: Int = 0 def queue = q.get @@ -341,7 +342,7 @@ class CallingThreadMailbox(_receiver: akka.actor.Cell, val mailboxType: MailboxT val qq = queue CallingThreadDispatcherQueues(actor.system).gatherFromAllOtherQueues(this, qq) super.cleanUp() - qq.q.cleanUp(actor.self, actor.systemImpl.deadLetterQueue) + qq.cleanUp(actor.self, actor.systemImpl.deadLetterQueue) q.remove() } }