From da987138dd1d38ec041e8a5bee7933da2a3771e1 Mon Sep 17 00:00:00 2001 From: Derek Williams Date: Tue, 26 Jul 2011 22:23:16 -0600 Subject: [PATCH] Partial fix for ticket #1054: execute callbacks in dispatcher --- .../test/scala/akka/dispatch/FutureSpec.scala | 21 ++-- .../main/scala/akka/dispatch/Dispatcher.scala | 2 +- .../src/main/scala/akka/dispatch/Future.scala | 95 +++++++++++-------- .../scala/akka/dispatch/MessageHandling.scala | 42 ++++---- .../testkit/CallingThreadDispatcher.scala | 4 +- 5 files changed, 89 insertions(+), 75 deletions(-) diff --git a/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala b/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala index f30a331981..c7696b6215 100644 --- a/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala @@ -391,7 +391,7 @@ class FutureSpec extends WordSpec with MustMatchers with Checkers with BeforeAnd val f1 = Future { throw new ThrowableTest("test") } f1.await - intercept[ThrowableTest] { f1.resultOrException } + intercept[ThrowableTest] { f1.get } val latch = new StandardLatch val f2 = Future { latch.tryAwait(5, TimeUnit.SECONDS); "success" } @@ -400,14 +400,17 @@ class FutureSpec extends WordSpec with MustMatchers with Checkers with BeforeAnd val f3 = f2 map (s ⇒ s.toUpperCase) latch.open f2.await - assert(f2.resultOrException === Some("success")) + assert(f2.get === "success") f2 foreach (_ ⇒ throw new ThrowableTest("current thread foreach")) f2 onResult { case _ ⇒ throw new ThrowableTest("current thread receive") } f3.await - assert(f3.resultOrException === Some("SUCCESS")) + assert(f3.get === "SUCCESS") + + // give time for all callbacks to execute + Thread sleep 100 // make sure all futures are completed in dispatcher - assert(Dispatchers.defaultGlobalDispatcher.pendingFutures === 0) + assert(Dispatchers.defaultGlobalDispatcher.pendingTasks === 0) } "shouldBlockUntilResult" in { @@ -519,7 +522,7 @@ class FutureSpec extends WordSpec with MustMatchers with Checkers with BeforeAnd Thread.sleep(100) // make sure all futures are completed in dispatcher - assert(Dispatchers.defaultGlobalDispatcher.pendingFutures === 0) + assert(Dispatchers.defaultGlobalDispatcher.pendingTasks === 0) } "shouldNotAddOrRunCallbacksAfterFailureToBeCompletedBeforeExpiry" in { @@ -726,12 +729,12 @@ class FutureSpec extends WordSpec with MustMatchers with Checkers with BeforeAnd "ticket812FutureDispatchCleanup" in { implicit val dispatcher = new Dispatcher("ticket812FutureDispatchCleanup") - assert(dispatcher.pendingFutures === 0) + assert(dispatcher.pendingTasks === 0) val future = Future({ Thread.sleep(100); "Done" }, 10) intercept[FutureTimeoutException] { future.await } - assert(dispatcher.pendingFutures === 1) - Thread.sleep(100) - assert(dispatcher.pendingFutures === 0) + assert(dispatcher.pendingTasks === 1) + Thread.sleep(200) + assert(dispatcher.pendingTasks === 0) } } } diff --git a/akka-actor/src/main/scala/akka/dispatch/Dispatcher.scala b/akka-actor/src/main/scala/akka/dispatch/Dispatcher.scala index fbd32c580b..d2b880041c 100644 --- a/akka-actor/src/main/scala/akka/dispatch/Dispatcher.scala +++ b/akka-actor/src/main/scala/akka/dispatch/Dispatcher.scala @@ -96,7 +96,7 @@ class Dispatcher( registerForExecution(mbox) } - private[akka] def executeFuture(invocation: FutureInvocation[_]): Unit = if (active.isOn) { + private[akka] def executeTask(invocation: TaskInvocation): Unit = if (active.isOn) { try executorService.get() execute invocation catch { case e: RejectedExecutionException ⇒ diff --git a/akka-actor/src/main/scala/akka/dispatch/Future.scala b/akka-actor/src/main/scala/akka/dispatch/Future.scala index 547034fd0b..4a378c599a 100644 --- a/akka-actor/src/main/scala/akka/dispatch/Future.scala +++ b/akka-actor/src/main/scala/akka/dispatch/Future.scala @@ -238,8 +238,19 @@ object Future { * This method constructs and returns a Future that will eventually hold the result of the execution of the supplied body * The execution is performed by the specified Dispatcher. */ - def apply[T](body: ⇒ T)(implicit dispatcher: MessageDispatcher, timeout: Timeout = implicitly): Future[T] = - dispatcher.dispatchFuture(() ⇒ body, timeout) + def apply[T](body: ⇒ T)(implicit dispatcher: MessageDispatcher, timeout: Timeout = implicitly): Future[T] = { + val promise = new DefaultPromise[T](timeout) + dispatcher dispatchTask { () ⇒ + promise complete { + try { + Right(body) + } catch { + case e ⇒ Left(e) + } + } + } + promise + } def apply[T](body: ⇒ T, timeout: Timeout)(implicit dispatcher: MessageDispatcher): Future[T] = apply(body)(dispatcher, timeout) @@ -293,9 +304,13 @@ object Future { * * The Delimited Continuations compiler plugin must be enabled in order to use this method. */ - def flow[A](body: ⇒ A @cps[Future[Any]])(implicit timeout: Timeout): Future[A] = { + def flow[A](body: ⇒ A @cps[Future[Any]])(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[A] = { val future = Promise[A](timeout) - (reset(future.asInstanceOf[Promise[Any]].completeWithResult(body)): Future[Any]) onException { case e ⇒ future completeWithException e } + //dispatcher dispatchTask { () ⇒ + reify(body) foreachFull (future completeWithResult, future completeWithException) onException { + case e: Exception ⇒ future completeWithException e + } + //} future } } @@ -312,7 +327,7 @@ sealed trait Future[+T] { * execution will fail. The normal result of getting a Future from an ActorRef using ? will return * an untyped Future. */ - def apply[A >: T](): A @cps[Future[Any]] = shift(this flatMap (_: A ⇒ Future[Any])) + def apply[A >: T]()(implicit dispatcher: MessageDispatcher, timeout: Timeout): A @cps[Future[Any]] = shift(this flatMap (_: A ⇒ Future[Any])) /** * Blocks awaiting completion of this Future, then returns the resulting value, @@ -407,7 +422,7 @@ sealed trait Future[+T] { * Future. If the Future has already been completed, this will apply * immediately. */ - def onComplete(func: Future[T] ⇒ Unit): this.type + def onComplete(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type /** * When the future is completed with a valid result, apply the provided @@ -419,7 +434,7 @@ sealed trait Future[+T] { * } * */ - final def onResult(pf: PartialFunction[Any, Unit]): this.type = onComplete { f ⇒ + final def onResult(pf: PartialFunction[Any, Unit])(implicit dispatcher: MessageDispatcher): this.type = onComplete { f ⇒ val optr = f.result if (optr.isDefined) { val r = optr.get @@ -437,7 +452,7 @@ sealed trait Future[+T] { * } * */ - final def onException(pf: PartialFunction[Throwable, Unit]): Future[T] = onComplete { f ⇒ + final def onException(pf: PartialFunction[Throwable, Unit])(implicit dispatcher: MessageDispatcher): Future[T] = onComplete { f ⇒ val opte = f.exception if (opte.isDefined) { val e = opte.get @@ -445,9 +460,9 @@ sealed trait Future[+T] { } } - def onTimeout(func: Future[T] ⇒ Unit): this.type + def onTimeout(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type - def orElse[A >: T](fallback: ⇒ A): Future[A] + def orElse[A >: T](fallback: ⇒ A)(implicit dispatcher: MessageDispatcher): Future[A] /** * Creates a new Future by applying a PartialFunction to the successful @@ -463,7 +478,7 @@ sealed trait Future[+T] { * } yield b + "-" + c * */ - final def collect[A](pf: PartialFunction[Any, A])(implicit timeout: Timeout): Future[A] = value match { + final def collect[A](pf: PartialFunction[Any, A])(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[A] = value match { case Some(Right(r)) ⇒ new KeptPromise[A](try { if (pf isDefinedAt r) @@ -509,7 +524,7 @@ sealed trait Future[+T] { * Future(6 / 2) recover { case e: ArithmeticException => 0 } // result: 3 * */ - final def recover[A >: T](pf: PartialFunction[Throwable, A])(implicit timeout: Timeout): Future[A] = value match { + final def recover[A >: T](pf: PartialFunction[Throwable, A])(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[A] = value match { case Some(Left(e)) ⇒ try { if (pf isDefinedAt e) @@ -556,7 +571,7 @@ sealed trait Future[+T] { * } yield b + "-" + c * */ - final def map[A](f: T ⇒ A)(implicit timeout: Timeout): Future[A] = value match { + final def map[A](f: T ⇒ A)(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[A] = value match { case Some(Right(r)) ⇒ new KeptPromise[A](try { Right(f(r)) @@ -591,7 +606,7 @@ sealed trait Future[+T] { * Creates a new Future[A] which is completed with this Future's result if * that conforms to A's erased type or a ClassCastException otherwise. */ - final def mapTo[A](implicit m: Manifest[A], timeout: Timeout = this.timeout): Future[A] = value match { + final def mapTo[A](implicit m: Manifest[A], dispatcher: MessageDispatcher = implicitly, timeout: Timeout = this.timeout): Future[A] = value match { case Some(Right(t)) ⇒ new KeptPromise(try { Right(BoxedType(m.erasure).cast(t).asInstanceOf[A]) @@ -630,7 +645,7 @@ sealed trait Future[+T] { * } yield b + "-" + c * */ - final def flatMap[A](f: T ⇒ Future[A])(implicit timeout: Timeout): Future[A] = value match { + final def flatMap[A](f: T ⇒ Future[A])(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[A] = value match { case Some(Right(r)) ⇒ try { f(r) @@ -659,23 +674,23 @@ sealed trait Future[+T] { future } - final def foreach(f: T ⇒ Unit): Unit = onComplete { + final def foreach(f: T ⇒ Unit)(implicit dispatcher: MessageDispatcher): Unit = onComplete { _.result match { case Some(v) ⇒ f(v) case None ⇒ } } - final def withFilter(p: T ⇒ Boolean) = new FutureWithFilter[T](this, p) + final def withFilter(p: T ⇒ Boolean)(implicit dispatcher: MessageDispatcher, timeout: Timeout) = new FutureWithFilter[T](this, p) - final class FutureWithFilter[+A](self: Future[A], p: A ⇒ Boolean)(implicit timeout: Timeout) { + final class FutureWithFilter[+A](self: Future[A], p: A ⇒ Boolean)(implicit dispatcher: MessageDispatcher, timeout: Timeout) { def foreach(f: A ⇒ Unit): Unit = self filter p foreach f def map[B](f: A ⇒ B): Future[B] = self filter p map f def flatMap[B](f: A ⇒ Future[B]): Future[B] = self filter p flatMap f def withFilter(q: A ⇒ Boolean): FutureWithFilter[A] = new FutureWithFilter[A](self, x ⇒ p(x) && q(x)) } - final def filter(p: T ⇒ Boolean)(implicit timeout: Timeout): Future[T] = value match { + final def filter(p: T ⇒ Boolean)(implicit dispatcher: MessageDispatcher, timeout: Timeout): Future[T] = value match { case Some(Right(r)) ⇒ try { if (p(r)) @@ -767,26 +782,26 @@ trait Promise[T] extends Future[T] { * Completes this Future with the specified result, if not already completed. * @return this */ - def complete(value: Either[Throwable, T]): this.type + def complete(value: Either[Throwable, T])(implicit dispatcher: MessageDispatcher): this.type /** * Completes this Future with the specified result, if not already completed. * @return this */ - final def completeWithResult(result: T): this.type = complete(Right(result)) + final def completeWithResult(result: T)(implicit dispatcher: MessageDispatcher): this.type = complete(Right(result)) /** * Completes this Future with the specified exception, if not already completed. * @return this */ - final def completeWithException(exception: Throwable): this.type = complete(Left(exception)) + final def completeWithException(exception: Throwable)(implicit dispatcher: MessageDispatcher): this.type = complete(Left(exception)) /** * Completes this Future with the specified other Future, when that Future is completed, * unless this Future has already been completed. * @return this. */ - final def completeWith(other: Future[T]): this.type = { + final def completeWith(other: Future[T])(implicit dispatcher: MessageDispatcher): this.type = { other onComplete { f ⇒ complete(f.value.get) } this } @@ -794,7 +809,7 @@ trait Promise[T] extends Future[T] { final def <<(value: T): Future[T] @cps[Future[Any]] = shift { cont: (Future[T] ⇒ Future[Any]) ⇒ cont(complete(Right(value))) } final def <<(other: Future[T]): Future[T] @cps[Future[Any]] = shift { cont: (Future[T] ⇒ Future[Any]) ⇒ - val fr = new DefaultPromise[Any]() + val fr = new DefaultPromise[Any](this.timeout) this completeWith other onComplete { f ⇒ try { fr completeWith cont(f) @@ -808,7 +823,7 @@ trait Promise[T] extends Future[T] { } final def <<(stream: PromiseStreamOut[T]): Future[T] @cps[Future[Any]] = shift { cont: (Future[T] ⇒ Future[Any]) ⇒ - val fr = Promise[Any]() + val fr = new DefaultPromise[Any](this.timeout) stream.dequeue(this).onComplete { f ⇒ try { fr completeWith cont(f) @@ -892,7 +907,7 @@ class DefaultPromise[T](val timeout: Timeout) extends Promise[T] { } } - def complete(value: Either[Throwable, T]): this.type = { + def complete(value: Either[Throwable, T])(implicit dispatcher: MessageDispatcher): this.type = { _lock.lock val notifyTheseListeners = try { if (_value.isEmpty) { //Only complete if we aren't expired @@ -928,18 +943,20 @@ class DefaultPromise[T](val timeout: Timeout) extends Promise[T] { notifyTheseListeners foreach doNotify }) } else { - try { - val callbacks = Stack[() ⇒ Unit]() // Allocate new aggregator for pending callbacks - Promise.callbacksPendingExecution.set(Some(callbacks)) // Specify the callback aggregator - runCallbacks(notifyTheseListeners, callbacks) // Execute callbacks, if they trigger new callbacks, they are aggregated - } finally { Promise.callbacksPendingExecution.set(None) } // Ensure cleanup + dispatcher dispatchTask { () ⇒ + try { + val callbacks = Stack[() ⇒ Unit]() // Allocate new aggregator for pending callbacks + Promise.callbacksPendingExecution.set(Some(callbacks)) // Specify the callback aggregator + runCallbacks(notifyTheseListeners, callbacks) // Execute callbacks, if they trigger new callbacks, they are aggregated + } finally { Promise.callbacksPendingExecution.set(None) } // Ensure cleanup + } } } this } - def onComplete(func: Future[T] ⇒ Unit): this.type = { + def onComplete(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type = { _lock.lock val notifyNow = try { if (_value.isEmpty) { @@ -952,12 +969,12 @@ class DefaultPromise[T](val timeout: Timeout) extends Promise[T] { _lock.unlock } - if (notifyNow) notifyCompleted(func) + if (notifyNow) dispatcher dispatchTask (() ⇒ notifyCompleted(func)) this } - def onTimeout(func: Future[T] ⇒ Unit): this.type = { + def onTimeout(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type = { if (timeout.duration.isFinite) { _lock.lock val runNow = try { @@ -982,7 +999,7 @@ class DefaultPromise[T](val timeout: Timeout) extends Promise[T] { this } - final def orElse[A >: T](fallback: ⇒ A): Future[A] = + final def orElse[A >: T](fallback: ⇒ A)(implicit dispatcher: MessageDispatcher): Future[A] = if (timeout.duration.isFinite) { value match { case Some(_) ⇒ this @@ -1047,14 +1064,14 @@ object ActorPromise { sealed class KeptPromise[T](suppliedValue: Either[Throwable, T]) extends Promise[T] { val value = Some(suppliedValue) - def complete(value: Either[Throwable, T]): this.type = this - def onComplete(func: Future[T] ⇒ Unit): this.type = { func(this); this } + def complete(value: Either[Throwable, T])(implicit dispatcher: MessageDispatcher): this.type = this + def onComplete(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type = { func(this); this } def await(atMost: Duration): this.type = this def await: this.type = this def isExpired: Boolean = true def timeout: Timeout = Timeout.zero - final def onTimeout(func: Future[T] ⇒ Unit): this.type = this - final def orElse[A >: T](fallback: ⇒ A): Future[A] = this + final def onTimeout(func: Future[T] ⇒ Unit)(implicit dispatcher: MessageDispatcher): this.type = this + final def orElse[A >: T](fallback: ⇒ A)(implicit dispatcher: MessageDispatcher): Future[A] = this } diff --git a/akka-actor/src/main/scala/akka/dispatch/MessageHandling.scala b/akka-actor/src/main/scala/akka/dispatch/MessageHandling.scala index aafd0988ee..7f4437ed4c 100644 --- a/akka-actor/src/main/scala/akka/dispatch/MessageHandling.scala +++ b/akka-actor/src/main/scala/akka/dispatch/MessageHandling.scala @@ -26,17 +26,15 @@ final case class MessageInvocation(val receiver: ActorRef, } } -final case class FutureInvocation[T](future: Promise[T], function: () ⇒ T, cleanup: () ⇒ Unit) extends Runnable { +final case class TaskInvocation(function: () ⇒ Unit, cleanup: () ⇒ Unit) extends Runnable { def run() { - future complete (try { - Right(function()) + try { + function() } catch { - case e ⇒ - EventHandler.error(e, this, e.getMessage) - Left(e) + case e ⇒ EventHandler.error(e, this, e.getMessage) } finally { cleanup() - }) + } } } @@ -56,7 +54,7 @@ trait MessageDispatcher { import MessageDispatcher._ protected val uuids = new ConcurrentSkipListSet[Uuid] - protected val futures = new AtomicLong(0L) + protected val tasks = new AtomicLong(0L) protected val guard = new ReentrantGuard protected val active = new Switch(false) @@ -94,31 +92,27 @@ trait MessageDispatcher { dispatch(invocation) } - private[akka] final def dispatchFuture[T](block: () ⇒ T, timeout: Timeout): Future[T] = { - futures.getAndIncrement() + private[akka] final def dispatchTask(block: () ⇒ Unit): Unit = { + tasks.getAndIncrement() try { - val future = new DefaultPromise[T](timeout) - if (active.isOff) guard withGuard { active.switchOn { start() } } - - executeFuture(FutureInvocation[T](future, block, futureCleanup)) - future + executeTask(TaskInvocation(block, taskCleanup)) } catch { case e ⇒ - futures.decrementAndGet + tasks.decrementAndGet throw e } } - private val futureCleanup: () ⇒ Unit = - () ⇒ if (futures.decrementAndGet() == 0) { + private val taskCleanup: () ⇒ Unit = + () ⇒ if (tasks.decrementAndGet() == 0) { guard withGuard { - if (futures.get == 0 && uuids.isEmpty) { + if (tasks.get == 0 && uuids.isEmpty) { shutdownSchedule match { case UNSCHEDULED ⇒ shutdownSchedule = SCHEDULED @@ -155,7 +149,7 @@ trait MessageDispatcher { if (uuids remove actorRef.uuid) { cleanUpMailboxFor(actorRef) actorRef.mailbox = null - if (uuids.isEmpty && futures.get == 0) { + if (uuids.isEmpty && tasks.get == 0) { shutdownSchedule match { case UNSCHEDULED ⇒ shutdownSchedule = SCHEDULED @@ -196,7 +190,7 @@ trait MessageDispatcher { shutdownSchedule = SCHEDULED Scheduler.scheduleOnce(this, timeoutMs, TimeUnit.MILLISECONDS) case SCHEDULED ⇒ - if (uuids.isEmpty && futures.get == 0) { + if (uuids.isEmpty && tasks.get == 0) { active switchOff { shutdown() // shut down in the dispatcher's references is zero } @@ -229,7 +223,7 @@ trait MessageDispatcher { */ private[akka] def dispatch(invocation: MessageInvocation) - private[akka] def executeFuture(invocation: FutureInvocation[_]) + private[akka] def executeTask(invocation: TaskInvocation) /** * Called one time every time an actor is attached to this dispatcher and this dispatcher was previously shutdown @@ -252,9 +246,9 @@ trait MessageDispatcher { def mailboxIsEmpty(actorRef: ActorRef): Boolean /** - * Returns the amount of futures queued for execution + * Returns the amount of tasks queued for execution */ - def pendingFutures: Long = futures.get + def pendingTasks: Long = tasks.get } /** diff --git a/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala b/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala index 0f0344cc49..fc89966964 100644 --- a/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala +++ b/akka-testkit/src/main/scala/akka/testkit/CallingThreadDispatcher.scala @@ -5,7 +5,7 @@ package akka.testkit import akka.event.EventHandler import akka.actor.ActorRef -import akka.dispatch.{ MessageDispatcher, MessageInvocation, FutureInvocation, Promise, ActorPromise } +import akka.dispatch.{ MessageDispatcher, MessageInvocation, TaskInvocation, Promise, ActorPromise } import java.util.concurrent.locks.ReentrantLock import java.util.LinkedList import java.util.concurrent.RejectedExecutionException @@ -161,7 +161,7 @@ class CallingThreadDispatcher(val name: String = "calling-thread", val warnings: if (execute) runQueue(mbox, queue) } - private[akka] override def executeFuture(invocation: FutureInvocation[_]) { invocation.run } + private[akka] override def executeTask(invocation: TaskInvocation) { invocation.run } /* * This method must be called with this thread's queue, which must already