From 249f14191d56118e6f68f1158c895f4b15dcf6c9 Mon Sep 17 00:00:00 2001 From: Viktor Klang Date: Fri, 12 Nov 2010 12:11:53 +0100 Subject: [PATCH] Adding support for onComplete listeners to Future --- .../src/main/scala/dispatch/Future.scala | 83 ++++++++++++++----- .../src/main/scala/remote/RemoteServer.scala | 69 ++++++++------- 2 files changed, 99 insertions(+), 53 deletions(-) diff --git a/akka-actor/src/main/scala/dispatch/Future.scala b/akka-actor/src/main/scala/dispatch/Future.scala index 68a3ce4399..cfdf0be34b 100644 --- a/akka-actor/src/main/scala/dispatch/Future.scala +++ b/akka-actor/src/main/scala/dispatch/Future.scala @@ -10,6 +10,7 @@ import akka.routing.Dispatcher import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.TimeUnit +import akka.japi.Procedure class FutureTimeoutException(message: String) extends AkkaException(message) @@ -100,6 +101,11 @@ sealed trait Future[T] { def exception: Option[Throwable] + def onComplete(func: Future[T] => Unit): Future[T] + + /* Java API */ + def onComplete(proc: Procedure[Future[T]]): Future[T] = onComplete(f => proc(f)) + def map[O](f: (T) => O): Future[O] = { val wrapped = this new Future[O] { @@ -110,13 +116,13 @@ sealed trait Future[T] { def timeoutInNanos = wrapped.timeoutInNanos def result: Option[O] = { wrapped.result map f } def exception: Option[Throwable] = wrapped.exception + def onComplete(func: Future[O] => Unit): Future[O] = { wrapped.onComplete(_ => func(this)); this } } } } trait CompletableFuture[T] extends Future[T] { def completeWithResult(result: T) - def completeWithException(exception: Throwable) } @@ -133,6 +139,7 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] { private var _completed: Boolean = _ private var _result: Option[T] = None private var _exception: Option[Throwable] = None + private var _listeners: List[Future[T] => Unit] = Nil def await = try { _lock.lock @@ -190,33 +197,67 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] { _lock.unlock } - def completeWithResult(result: T) = try { - _lock.lock - if (!_completed) { - _completed = true - _result = Some(result) - onComplete(result) + def completeWithResult(result: T) { + val notify = try { + _lock.lock + if (!_completed) { + _completed = true + _result = Some(result) + true + } else false + } finally { + _signal.signalAll + _lock.unlock } - } finally { - _signal.signalAll - _lock.unlock + + if (notify) + notifyListeners } - def completeWithException(exception: Throwable) = try { - _lock.lock - if (!_completed) { - _completed = true - _exception = Some(exception) - onCompleteException(exception) + def completeWithException(exception: Throwable) { + val notify = try { + _lock.lock + if (!_completed) { + _completed = true + _exception = Some(exception) + true + } else false + } finally { + _signal.signalAll + _lock.unlock } - } finally { - _signal.signalAll - _lock.unlock + + if (notify) + notifyListeners } - protected def onComplete(result: T) {} + def onComplete(func: Future[T] => Unit): CompletableFuture[T] = { + val notifyNow = try { + _lock.lock + if (!_completed) { + _listeners ::= func + false + } + else + true + } finally { + _lock.unlock + } - protected def onCompleteException(exception: Throwable) {} + if (notifyNow) + notifyListener(func) + + this + } + + private def notifyListeners() { + for(l <- _listeners) + notifyListener(l) + } + + private def notifyListener(func: Future[T] => Unit) { + func(this) + } private def currentTimeInNanos: Long = TIME_UNIT.toNanos(System.currentTimeMillis) } diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala index 62179ed11c..bc12477970 100644 --- a/akka-remote/src/main/scala/remote/RemoteServer.scala +++ b/akka-remote/src/main/scala/remote/RemoteServer.scala @@ -497,7 +497,7 @@ class RemoteServerHandler( } private def handleRemoteMessageProtocol(request: RemoteMessageProtocol, channel: Channel) = { - log.debug("Received RemoteMessageProtocol[\n%s]", request.toString) + log.debug("Received RemoteMessageProtocol[\n%s]".format(request.toString)) request.getActorInfo.getActorType match { case SCALA_ACTOR => dispatchToActor(request, channel) case TYPED_ACTOR => dispatchToTypedActor(request, channel) @@ -538,41 +538,46 @@ class RemoteServerHandler( message, request.getActorInfo.getTimeout, None, - Some(new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout){ - override def onComplete(result: AnyRef) { - log.debug("Returning result from actor invocation [%s]", result) - val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( - Some(actorRef), - Right(request.getUuid), - actorInfo.getId, - actorInfo.getTarget, - actorInfo.getTimeout, - Left(result), - true, - Some(actorRef), - None, - AkkaActorType.ScalaActor, - None) + Some(new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout). + onComplete(f => { + val result = f.result + val exception = f.exception - // FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method - if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) + if (exception.isDefined) { + log.debug("Returning exception from actor invocation [%s]".format(exception.get)) + try { + channel.write(createErrorReplyMessage(exception.get, request, AkkaActorType.ScalaActor)) + } catch { + case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) + } + } + else if (result.isDefined) { + log.debug("Returning result from actor invocation [%s]".format(result.get)) + val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( + Some(actorRef), + Right(request.getUuid), + actorInfo.getId, + actorInfo.getTarget, + actorInfo.getTimeout, + Left(result.get), + true, + Some(actorRef), + None, + AkkaActorType.ScalaActor, + None) - try { - channel.write(messageBuilder.build) - } catch { - case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) + // FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method + if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) + + try { + channel.write(messageBuilder.build) + } catch { + case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) + } } } - - override def onCompleteException(exception: Throwable) { - try { - channel.write(createErrorReplyMessage(exception, request, AkkaActorType.ScalaActor)) - } catch { - case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) - } - } - } - )) + ) + )) } }