diff --git a/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala b/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala index e12294a70d..34edbb653f 100644 --- a/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/dispatch/FutureSpec.scala @@ -351,34 +351,93 @@ class FutureSpec extends JUnitSuite { val latch = new StandardLatch val f = Future({ latch.await; 5}) - val f2 = Future({ f() + 5 }) + val f2 = Future({ f.get + 5 }) assert(f2.resultOrException === None) latch.open - assert(f2() === 10) + assert(f2.get === 10) val f3 = Future({ Thread.sleep(100); 5}, 10) intercept[FutureTimeoutException] { - f3() + f3.get } } - @Test def lesslessIsMore { - import akka.actor.Actor.spawn - val dataflowVar, dataflowVar2 = new DefaultCompletableFuture[Int](Long.MaxValue) - val begin, end = new StandardLatch - spawn { - begin.await - dataflowVar2 << dataflowVar - end.open + @Test def futureComposingWithContinuations { + import Future.flow + + val actor = actorOf[TestActor].start + + val x = Future("Hello") + val y = x flatMap (actor !!! _) + + val r = flow(x() + " " + y[String]() + "!") + + assert(r.get === "Hello World!") + + actor.stop + } + + @Test def futureComposingWithContinuationsFailureDivideZero { + import Future.flow + + val x = Future("Hello") + val y = x map (_.length) + + val r = flow(x() + " " + y.map(_ / 0).map(_.toString)(), 100) + + intercept[java.lang.ArithmeticException](r.get) + } + + @Test def futureComposingWithContinuationsFailureCastInt { + import Future.flow + + val actor = actorOf[TestActor].start + + val x = Future(3) + val y = actor !!! "Hello" + + val r = flow(x() + y[Int](), 100) + + intercept[ClassCastException](r.get) + } + + @Test def futureComposingWithContinuationsFailureCastNothing { + import Future.flow + + val actor = actorOf[TestActor].start + + val x = Future("Hello") + val y = actor !!! "Hello" + + val r = flow(x() + y()) + + intercept[ClassCastException](r.get) + } + + @Test def futureCompletingWithContinuations { + import Future.flow + + val x, y, z = new DefaultCompletableFuture[Int](Actor.TIMEOUT) + val ly, lz = new StandardLatch + + val result = flow { + y completeWith x + ly.open // not within continuation + + z << x + lz.open // within continuation, will wait for 'z' to complete + z() + y() } - spawn { - dataflowVar << 5 - } - begin.open - end.await - assert(dataflowVar2() === 5) - assert(dataflowVar.get === 5) + assert(ly.tryAwaitUninterruptible(100, TimeUnit.MILLISECONDS)) + assert(!lz.tryAwaitUninterruptible(100, TimeUnit.MILLISECONDS)) + + x << 5 + + assert(y.get === 5) + assert(z.get === 5) + assert(lz.isOpen) + assert(result.get === 10) } } diff --git a/akka-actor/src/main/scala/akka/dispatch/Future.scala b/akka-actor/src/main/scala/akka/dispatch/Future.scala index c69ca82bad..f66d1ab25b 100644 --- a/akka-actor/src/main/scala/akka/dispatch/Future.scala +++ b/akka-actor/src/main/scala/akka/dispatch/Future.scala @@ -10,6 +10,8 @@ import akka.actor.Actor import akka.routing.Dispatcher import akka.japi.{ Procedure, Function => JFunc } +import scala.util.continuations._ + import java.util.concurrent.locks.ReentrantLock import java.util.concurrent. {ConcurrentLinkedQueue, TimeUnit, Callable} import java.util.concurrent.TimeUnit.{NANOSECONDS => NANOS, MILLISECONDS => MILLIS} @@ -261,22 +263,30 @@ object Future { val fb = fn(a.asInstanceOf[A]) for (r <- fr; b <-fb) yield (r += b) }.map(_.result) + + def flow[A](body: => A @cpsParam[Future[Any],Future[Any]], timeout: Long = Actor.TIMEOUT): Future[A] = { + + val future = new DefaultCompletableFuture[A](timeout) + + reset(future completeWithResult body) onComplete { f => + val ex = f.exception + if (ex.isDefined) future.completeWithException(ex.get) + } + + future + } } sealed trait Future[+T] { - /** - * Returns the result of this future after waiting for it to complete, - * this method will throw any throwable that this Future was completed with - * and will throw a java.util.concurrent.TimeoutException if there is no result - * within the Futures timeout - */ - def apply(): T = this.await.resultOrException.get + def apply[A >: T](): A @cpsParam[Future[Any],Future[Any]] = shift { f: (A => Future[Any]) => + (new DefaultCompletableFuture[Any](timeoutInNanos, NANOS)) completeWith (this flatMap f) + } /** * Java API for apply() */ - def get: T = apply() + def get: T = this.await.resultOrException.get /** * Blocks the current thread until the Future has been completed or the @@ -581,10 +591,10 @@ trait CompletableFuture[T] extends Future[T] { */ final def << (value: T): Future[T] = complete(Right(value)) - /** - * Alias for completeWith(other). - */ - final def << (other : Future[T]): Future[T] = completeWith(other) + final def << (other: Future[T]): T @cpsParam[Future[Any],Future[Any]] = shift { k: (T => Future[Any]) => + this completeWith other flatMap k + } + } /** diff --git a/project/build/AkkaProject.scala b/project/build/AkkaProject.scala index 8073d8831b..db1f97ae80 100644 --- a/project/build/AkkaProject.scala +++ b/project/build/AkkaProject.scala @@ -10,7 +10,7 @@ import sbt._ import sbt.CompileOrder._ import spde._ -class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) { +class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) with AutoCompilerPlugins { // ------------------------------------------------------------------------------------------------------------------- // Compile settings @@ -273,8 +273,10 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) { // akka-actor subproject // ------------------------------------------------------------------------------------------------------------------- - class AkkaActorProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) with OsgiProject { + class AkkaActorProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) with OsgiProject with AutoCompilerPlugins { override def bndExportPackage = super.bndExportPackage ++ Seq("com.eaio.*;version=3.2") + val cont = compilerPlugin("org.scala-lang.plugins" % "continuations" % "2.9.0.RC1") + override def compileOptions = super.compileOptions ++ compileOptions("-P:continuations:enable") } // ------------------------------------------------------------------------------------------------------------------- @@ -436,11 +438,13 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) { // akka-actor-tests subproject // ------------------------------------------------------------------------------------------------------------------- - class AkkaActorTestsProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) { + class AkkaActorTestsProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) with AutoCompilerPlugins { // testing val junit = Dependencies.junit val scalatest = Dependencies.scalatest val multiverse_test = Dependencies.multiverse_test // StandardLatch + val cont = compilerPlugin("org.scala-lang.plugins" % "continuations" % "2.9.0.RC1") + override def compileOptions = super.compileOptions ++ compileOptions("-P:continuations:enable") } // -------------------------------------------------------------------------------------------------------------------