diff --git a/akka-actor/src/main/scala/akka/dispatch/BatchingExecutor.scala b/akka-actor/src/main/scala/akka/dispatch/BatchingExecutor.scala index 68099df513..afeb6b93dd 100644 --- a/akka-actor/src/main/scala/akka/dispatch/BatchingExecutor.scala +++ b/akka-actor/src/main/scala/akka/dispatch/BatchingExecutor.scala @@ -13,6 +13,8 @@ import java.util.concurrent.{ Callable, Executor, TimeUnit } import scala.concurrent.util.Duration import scala.concurrent._ +private[akka] trait Batchable { self: Runnable ⇒ } + /** * Mixin trait for an Executor * which groups multiple nested `Runnable.run()` calls @@ -40,17 +42,14 @@ import scala.concurrent._ */ private[akka] trait BatchingExecutor extends Executor { - // invariant: if "_tasksLocal.get ne null" then we are inside - // BatchingRunnable.run; if it is null, we are outside + // invariant: if "_tasksLocal.get ne null" then we are inside BatchingRunnable.run; if it is null, we are outside private val _tasksLocal = new ThreadLocal[List[Runnable]]() // only valid to call if _tasksLocal.get ne null - private def push(runnable: Runnable): Unit = - _tasksLocal.set(runnable :: _tasksLocal.get) + private def push(runnable: Runnable): Unit = _tasksLocal.set(runnable :: _tasksLocal.get) // only valid to call if _tasksLocal.get ne null - private def nonEmpty(): Boolean = - _tasksLocal.get.nonEmpty + private def nonEmpty(): Boolean = _tasksLocal.get.nonEmpty // only valid to call if _tasksLocal.get ne null private def pop(): Runnable = { @@ -59,38 +58,18 @@ private[akka] trait BatchingExecutor extends Executor { tasks.head } - private class BatchingBlockContext(previous: BlockContext) extends BlockContext { - - override def blockOn[T](thunk: ⇒ T)(implicit permission: CanAwait): T = { - // if we know there will be blocking, we don't want to - // keep tasks queued up because it could deadlock. - _tasksLocal.get match { - case null | Nil ⇒ - // null = not inside a BatchingRunnable - // Nil = inside a BatchingRunnable, but nothing is queued up - case list ⇒ - // inside a BatchingRunnable and there's a queue; - // make a new BatchingRunnable and send it to - // another thread - _tasksLocal set Nil - unbatchedExecute(new BatchingRunnable(list)) - } - - // now delegate the blocking to the previous BC - previous.blockOn(thunk) - } - } - - private class BatchingRunnable(val initial: List[Runnable]) extends Runnable { + private class Batch(val initial: List[Runnable]) extends Runnable with BlockContext { + private var parentBlockContext: BlockContext = _ // this method runs in the delegate ExecutionContext's thread override def run(): Unit = { require(_tasksLocal.get eq null) - val bc = new BatchingBlockContext(BlockContext.current) - BlockContext.withBlockContext(bc) { + val prevBlockContext = BlockContext.current + BlockContext.withBlockContext(this) { try { + parentBlockContext = prevBlockContext _tasksLocal set initial - while (nonEmpty) { + while (nonEmpty()) { val next = pop() try { next.run() @@ -102,34 +81,44 @@ private[akka] trait BatchingExecutor extends Executor { // up to the invoking executor val remaining = _tasksLocal.get _tasksLocal set Nil - unbatchedExecute(new BatchingRunnable(remaining)) + unbatchedExecute(new Batch(remaining)) //TODO what if this submission fails? throw t // rethrow } } } finally { _tasksLocal.remove() + parentBlockContext = null require(_tasksLocal.get eq null) } } } - } - private[this] def unbatchedExecute(r: Runnable): Unit = super.execute(r) + override def blockOn[T](thunk: ⇒ T)(implicit permission: CanAwait): T = { + // if we know there will be blocking, we don't want to keep tasks queued up because it could deadlock. + { + val tasks = _tasksLocal.get + _tasksLocal.remove() + if ((tasks ne null) && tasks.nonEmpty) + unbatchedExecute(new Batch(tasks)) + } - abstract override def execute(runnable: Runnable): Unit = { - _tasksLocal.get match { - case null ⇒ - // outside BatchingRunnable.run: start a new batch - unbatchedExecute(runnable) - case _ ⇒ - // inside BatchingRunnable.run - if (batchable(runnable)) - push(runnable) // add to existing batch - else - unbatchedExecute(runnable) // bypass batching mechanism + // now delegate the blocking to the previous BC + require(parentBlockContext ne null) + parentBlockContext.blockOn(thunk) } } + protected def unbatchedExecute(r: Runnable): Unit = super.execute(r) + + abstract override def execute(runnable: Runnable): Unit = { + if (batchable(runnable)) { // If we can batch the runnable + _tasksLocal.get match { + case null ⇒ unbatchedExecute(new Batch(List(runnable))) // If we aren't in batching mode yet, enqueue batch + case some ⇒ push(runnable) // If we are already in batching mode, add to batch + } + } else unbatchedExecute(runnable) // If not batchable, just delegate to underlying + } + /** Override this to define which runnables will be batched. */ - def batchable(runnable: Runnable): Boolean + def batchable(runnable: Runnable): Boolean = runnable.isInstanceOf[Batchable] }