diff --git a/akka-remote/src/main/scala/akka/remote/transport/ThrottlerTransportAdapter.scala b/akka-remote/src/main/scala/akka/remote/transport/ThrottlerTransportAdapter.scala index 55dfc152da..f6c7c59e29 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/ThrottlerTransportAdapter.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/ThrottlerTransportAdapter.scala @@ -4,15 +4,14 @@ package akka.remote.transport import akka.actor._ -import akka.pattern.ask -import akka.pattern.pipe +import akka.pattern.{ PromiseActorRef, ask, pipe } import akka.remote.transport.ActorTransportAdapter.AssociateUnderlying import akka.remote.transport.AkkaPduCodec.Associate import akka.remote.transport.AssociationHandle.{ ActorHandleEventListener, Disassociated, InboundPayload, HandleEventListener } import akka.remote.transport.ThrottlerManager.Checkin import akka.remote.transport.ThrottlerTransportAdapter._ import akka.remote.transport.Transport._ -import akka.util.ByteString +import akka.util.{ Timeout, ByteString } import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference import scala.annotation.tailrec @@ -23,6 +22,7 @@ import scala.math.min import scala.util.{ Success, Failure } import scala.util.control.NonFatal import scala.concurrent.duration._ +import akka.dispatch.{ Unwatch, Watch } class ThrottlerProvider extends TransportAdapterProvider { @@ -271,8 +271,8 @@ private[transport] class ThrottlerManager(wrappedTransport: Transport) extends A import ActorTransportAdapter.AskTimeout if (direction.includes(Direction.Send)) handle.outboundThrottleMode.set(mode) - if (direction.includes(Direction.Receive) && !handle.throttlerActor.isTerminated) - (handle.throttlerActor ? mode).mapTo[SetThrottleAck.type] + if (direction.includes(Direction.Receive)) + askWithDeathCompletion(handle.throttlerActor, mode, SetThrottleAck).mapTo[SetThrottleAck.type] else Future.successful(SetThrottleAck) } @@ -284,6 +284,24 @@ private[transport] class ThrottlerManager(wrappedTransport: Transport) extends A ThrottlerHandle(originalHandle, throttlerActor) } + private def askWithDeathCompletion(target: ActorRef, question: Any, answer: Any)(implicit timeout: Timeout): Future[Any] = { + if (target.isTerminated) Future successful answer + else { + val internalTarget = target.asInstanceOf[InternalActorRef] + val promiseActorRef = PromiseActorRef(context.system.asInstanceOf[ExtendedActorSystem].provider, timeout) + internalTarget.sendSystemMessage(Watch(target, promiseActorRef)) + val future = promiseActorRef.result.future + future onComplete { // remember to unwatch if termination didn't complete + case Success(Terminated(`target`)) ⇒ () + case _ ⇒ internalTarget.sendSystemMessage(Unwatch(target, promiseActorRef)) + } + target.tell(question, promiseActorRef) + future map { + case Terminated(`target`) ⇒ answer + case x ⇒ x + } + } + } } /**