diff --git a/akka-actor/src/main/scala/akka/dispatch/Future.scala b/akka-actor/src/main/scala/akka/dispatch/Future.scala index 2fc4026ce4..c741d86c7b 100644 --- a/akka-actor/src/main/scala/akka/dispatch/Future.scala +++ b/akka-actor/src/main/scala/akka/dispatch/Future.scala @@ -9,8 +9,10 @@ import akka.actor.Actor.spawn import akka.routing.Dispatcher import java.util.concurrent.locks.ReentrantLock -import java.util.concurrent.TimeUnit import akka.japi.Procedure +import java.util.concurrent. {ConcurrentLinkedQueue, TimeUnit} +import java.util.concurrent.atomic. {AtomicInteger} +import akka.actor.Actor class FutureTimeoutException(message: String) extends AkkaException(message) @@ -65,6 +67,32 @@ object Futures { * Returns Future.resultOrException of the first completed of the 2 Futures provided (blocking!) */ def awaitEither[T](f1: Future[T], f2: Future[T]): Option[T] = awaitOne(List(f1,f2)).asInstanceOf[Future[T]].resultOrException + + def fold[R,T](zero: R, timeout: Long = Actor.TIMEOUT)(futures: Traversable[Future[T]])(foldFun: (R, T) => R): Future[R] = { + val result = new DefaultCompletableFuture[R](timeout) + val results = new ConcurrentLinkedQueue[T]() + val waitingFor = new AtomicInteger(futures.size) + + val aggregate = (f: Future[T]) => if (!result.isCompleted) { //TODO: This is an optimization, is it premature? + if (f.exception.isDefined) + result completeWithException f.exception.get + else { + results add f.result.get + if (waitingFor.decrementAndGet == 0) { //Only one thread can get here + try { + val r = scala.collection.JavaConversions.asScalaIterable(results).foldLeft(zero)(foldFun) + results.clear //Do not retain the values since someone can hold onto the Future for a long time + result completeWithResult r + } catch { + case e: Exception => result completeWithException e + } + } + } + } + + futures foreach { _ onComplete aggregate } + result + } } sealed trait Future[T] { @@ -199,37 +227,41 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] { } def completeWithResult(result: T) { - val notify = try { + val notifyTheseListeners = try { _lock.lock if (!_completed) { _completed = true _result = Some(result) - true - } else false + val all = _listeners + _listeners = Nil + all + } else Nil } finally { _signal.signalAll _lock.unlock } - if (notify) - notifyListeners + if (notifyTheseListeners.nonEmpty) + notifyTheseListeners foreach notify } def completeWithException(exception: Throwable) { - val notify = try { + val notifyTheseListeners = try { _lock.lock if (!_completed) { _completed = true _exception = Some(exception) - true - } else false + val all = _listeners + _listeners = Nil + all + } else Nil } finally { _signal.signalAll _lock.unlock } - if (notify) - notifyListeners + if (notifyTheseListeners.nonEmpty) + notifyTheseListeners foreach notify } def onComplete(func: Future[T] => Unit): CompletableFuture[T] = { @@ -246,17 +278,12 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] { } if (notifyNow) - notifyListener(func) + notify(func) this } - private def notifyListeners() { - for(l <- _listeners) - notifyListener(l) - } - - private def notifyListener(func: Future[T] => Unit) { + private def notify(func: Future[T] => Unit) { func(this) }