Adding support for onComplete listeners to Future

This commit is contained in:
Viktor Klang 2010-11-12 12:11:53 +01:00
parent 9df923dd16
commit 249f14191d
2 changed files with 99 additions and 53 deletions

View file

@ -10,6 +10,7 @@ import akka.routing.Dispatcher
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import akka.japi.Procedure
class FutureTimeoutException(message: String) extends AkkaException(message) class FutureTimeoutException(message: String) extends AkkaException(message)
@ -100,6 +101,11 @@ sealed trait Future[T] {
def exception: Option[Throwable] def exception: Option[Throwable]
def onComplete(func: Future[T] => Unit): Future[T]
/* Java API */
def onComplete(proc: Procedure[Future[T]]): Future[T] = onComplete(f => proc(f))
def map[O](f: (T) => O): Future[O] = { def map[O](f: (T) => O): Future[O] = {
val wrapped = this val wrapped = this
new Future[O] { new Future[O] {
@ -110,13 +116,13 @@ sealed trait Future[T] {
def timeoutInNanos = wrapped.timeoutInNanos def timeoutInNanos = wrapped.timeoutInNanos
def result: Option[O] = { wrapped.result map f } def result: Option[O] = { wrapped.result map f }
def exception: Option[Throwable] = wrapped.exception def exception: Option[Throwable] = wrapped.exception
def onComplete(func: Future[O] => Unit): Future[O] = { wrapped.onComplete(_ => func(this)); this }
} }
} }
} }
trait CompletableFuture[T] extends Future[T] { trait CompletableFuture[T] extends Future[T] {
def completeWithResult(result: T) def completeWithResult(result: T)
def completeWithException(exception: Throwable) def completeWithException(exception: Throwable)
} }
@ -133,6 +139,7 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] {
private var _completed: Boolean = _ private var _completed: Boolean = _
private var _result: Option[T] = None private var _result: Option[T] = None
private var _exception: Option[Throwable] = None private var _exception: Option[Throwable] = None
private var _listeners: List[Future[T] => Unit] = Nil
def await = try { def await = try {
_lock.lock _lock.lock
@ -190,33 +197,67 @@ class DefaultCompletableFuture[T](timeout: Long) extends CompletableFuture[T] {
_lock.unlock _lock.unlock
} }
def completeWithResult(result: T) = try { def completeWithResult(result: T) {
_lock.lock val notify = try {
if (!_completed) { _lock.lock
_completed = true if (!_completed) {
_result = Some(result) _completed = true
onComplete(result) _result = Some(result)
true
} else false
} finally {
_signal.signalAll
_lock.unlock
} }
} finally {
_signal.signalAll if (notify)
_lock.unlock notifyListeners
} }
def completeWithException(exception: Throwable) = try { def completeWithException(exception: Throwable) {
_lock.lock val notify = try {
if (!_completed) { _lock.lock
_completed = true if (!_completed) {
_exception = Some(exception) _completed = true
onCompleteException(exception) _exception = Some(exception)
true
} else false
} finally {
_signal.signalAll
_lock.unlock
} }
} finally {
_signal.signalAll if (notify)
_lock.unlock notifyListeners
} }
protected def onComplete(result: T) {} def onComplete(func: Future[T] => Unit): CompletableFuture[T] = {
val notifyNow = try {
_lock.lock
if (!_completed) {
_listeners ::= func
false
}
else
true
} finally {
_lock.unlock
}
protected def onCompleteException(exception: Throwable) {} if (notifyNow)
notifyListener(func)
this
}
private def notifyListeners() {
for(l <- _listeners)
notifyListener(l)
}
private def notifyListener(func: Future[T] => Unit) {
func(this)
}
private def currentTimeInNanos: Long = TIME_UNIT.toNanos(System.currentTimeMillis) private def currentTimeInNanos: Long = TIME_UNIT.toNanos(System.currentTimeMillis)
} }

View file

@ -497,7 +497,7 @@ class RemoteServerHandler(
} }
private def handleRemoteMessageProtocol(request: RemoteMessageProtocol, channel: Channel) = { private def handleRemoteMessageProtocol(request: RemoteMessageProtocol, channel: Channel) = {
log.debug("Received RemoteMessageProtocol[\n%s]", request.toString) log.debug("Received RemoteMessageProtocol[\n%s]".format(request.toString))
request.getActorInfo.getActorType match { request.getActorInfo.getActorType match {
case SCALA_ACTOR => dispatchToActor(request, channel) case SCALA_ACTOR => dispatchToActor(request, channel)
case TYPED_ACTOR => dispatchToTypedActor(request, channel) case TYPED_ACTOR => dispatchToTypedActor(request, channel)
@ -538,41 +538,46 @@ class RemoteServerHandler(
message, message,
request.getActorInfo.getTimeout, request.getActorInfo.getTimeout,
None, None,
Some(new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout){ Some(new DefaultCompletableFuture[AnyRef](request.getActorInfo.getTimeout).
override def onComplete(result: AnyRef) { onComplete(f => {
log.debug("Returning result from actor invocation [%s]", result) val result = f.result
val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder( val exception = f.exception
Some(actorRef),
Right(request.getUuid),
actorInfo.getId,
actorInfo.getTarget,
actorInfo.getTimeout,
Left(result),
true,
Some(actorRef),
None,
AkkaActorType.ScalaActor,
None)
// FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method if (exception.isDefined) {
if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid) log.debug("Returning exception from actor invocation [%s]".format(exception.get))
try {
channel.write(createErrorReplyMessage(exception.get, request, AkkaActorType.ScalaActor))
} catch {
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
else if (result.isDefined) {
log.debug("Returning result from actor invocation [%s]".format(result.get))
val messageBuilder = RemoteActorSerialization.createRemoteMessageProtocolBuilder(
Some(actorRef),
Right(request.getUuid),
actorInfo.getId,
actorInfo.getTarget,
actorInfo.getTimeout,
Left(result.get),
true,
Some(actorRef),
None,
AkkaActorType.ScalaActor,
None)
try { // FIXME lift in the supervisor uuid management into toh createRemoteMessageProtocolBuilder method
channel.write(messageBuilder.build) if (request.hasSupervisorUuid) messageBuilder.setSupervisorUuid(request.getSupervisorUuid)
} catch {
case e: Throwable => server.notifyListeners(RemoteServerError(e, server)) try {
channel.write(messageBuilder.build)
} catch {
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
} }
} }
)
override def onCompleteException(exception: Throwable) { ))
try {
channel.write(createErrorReplyMessage(exception, request, AkkaActorType.ScalaActor))
} catch {
case e: Throwable => server.notifyListeners(RemoteServerError(e, server))
}
}
}
))
} }
} }