=rem #15109: Separate field in Pass for refuseUid

- Fixes #15109
  - also fix GotUid race with InvalidAddress supervision event
(cherry picked from commit 3fe83fa)
This commit is contained in:
Endre Sándor Varga 2014-05-12 13:51:05 +02:00
parent 70b0fc5ab5
commit 0e46db47d9
6 changed files with 180 additions and 44 deletions

View file

@ -164,7 +164,7 @@ private[remote] class OversizedPayloadException(msg: String) extends EndpointExc
private[remote] object ReliableDeliverySupervisor {
case object Ungate
case object AttemptSysMsgRedelivery
final case class GotUid(uid: Int)
final case class GotUid(uid: Int, remoteAddres: Address)
def props(
handleOrActive: Option[AkkaProtocolHandle],
@ -309,7 +309,7 @@ private[remote] class ReliableDeliverySupervisor(
if (resendBuffer.nonAcked.nonEmpty || resendBuffer.nacked.nonEmpty)
context.system.scheduler.scheduleOnce(settings.SysResendTimeout, self, AttemptSysMsgRedelivery)
context.become(idle)
case g @ GotUid(receivedUid)
case g @ GotUid(receivedUid, _)
context.parent ! g
// New system that has the same address as the old - need to start from fresh state
uidConfirmed = true
@ -574,7 +574,7 @@ private[remote] class EndpointWriter(
publishAndThrow(new EndpointAssociationException(s"Association failed with [$remoteAddress]", e), Logging.DebugLevel)
case Handle(inboundHandle)
// Assert handle == None?
context.parent ! ReliableDeliverySupervisor.GotUid(inboundHandle.handshakeInfo.uid)
context.parent ! ReliableDeliverySupervisor.GotUid(inboundHandle.handshakeInfo.uid, remoteAddress)
handle = Some(inboundHandle)
reader = startReadEndpoint(inboundHandle)
eventPublisher.notifyListeners(AssociatedEvent(localAddress, remoteAddress, inbound))

View file

@ -273,7 +273,7 @@ private[remote] object EndpointManager {
*/
def isTombstone: Boolean
}
final case class Pass(endpoint: ActorRef, uid: Option[Int]) extends EndpointPolicy {
final case class Pass(endpoint: ActorRef, uid: Option[Int], refuseUid: Option[Int]) extends EndpointPolicy {
override def isTombstone: Boolean = false
}
final case class Gated(timeOfRelease: Deadline) extends EndpointPolicy {
@ -290,20 +290,20 @@ private[remote] object EndpointManager {
private var addressToReadonly = HashMap[Address, ActorRef]()
private var readonlyToAddress = HashMap[ActorRef, Address]()
def registerWritableEndpoint(address: Address, uid: Option[Int], endpoint: ActorRef): ActorRef = addressToWritable.get(address) match {
case Some(Pass(e, _))
throw new IllegalArgumentException(s"Attempting to overwrite existing endpoint [$e] with [$endpoint]")
case _
addressToWritable += address -> Pass(endpoint, uid)
writableToAddress += endpoint -> address
endpoint
}
def registerWritableEndpointUid(writer: ActorRef, uid: Int): Unit = {
val address = writableToAddress(writer)
def registerWritableEndpoint(address: Address, uid: Option[Int], refuseUid: Option[Int], endpoint: ActorRef): ActorRef =
addressToWritable.get(address) match {
case Some(Pass(ep, _)) addressToWritable += address -> Pass(ep, Some(uid))
case other // the GotUid might have lost the race with some failure
case Some(Pass(e, _, _))
throw new IllegalArgumentException(s"Attempting to overwrite existing endpoint [$e] with [$endpoint]")
case _
addressToWritable += address -> Pass(endpoint, uid, refuseUid)
writableToAddress += endpoint -> address
endpoint
}
def registerWritableEndpointUid(remoteAddress: Address, uid: Int): Unit = {
addressToWritable.get(remoteAddress) match {
case Some(Pass(ep, _, refuseUid)) addressToWritable += remoteAddress -> Pass(ep, Some(uid), refuseUid)
case other // the GotUid might have lost the race with some failure
}
}
@ -329,8 +329,8 @@ private[remote] object EndpointManager {
def writableEndpointWithPolicyFor(address: Address): Option[EndpointPolicy] = addressToWritable.get(address)
def hasWritableEndpointFor(address: Address): Boolean = writableEndpointWithPolicyFor(address) match {
case Some(Pass(_, _)) true
case _ false
case Some(Pass(_, _, _)) true
case _ false
}
def readOnlyEndpointFor(address: Address): Option[ActorRef] = addressToReadonly.get(address)
@ -349,9 +349,9 @@ private[remote] object EndpointManager {
def refuseUid(address: Address): Option[Int] = writableEndpointWithPolicyFor(address) match {
// timeOfRelease is only used for garbage collection. If an address is still probed, we should report the
// known fact that it is quarantined.
case Some(Quarantined(uid, _)) Some(uid)
case Some(Pass(_, uidOption)) uidOption
case _ None
case Some(Quarantined(uid, _)) Some(uid)
case Some(Pass(_, _, refuseUid)) refuseUid
case _ None
}
/**
@ -526,7 +526,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
case Quarantine(address, uidOption)
// Stop writers
endpoints.writableEndpointWithPolicyFor(address) match {
case Some(Pass(endpoint, _))
case Some(Pass(endpoint, _, _))
context.stop(endpoint)
if (uidOption.isEmpty) {
log.warning("Association to [{}] with unknown UID is reported as quarantined, but " +
@ -552,6 +552,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
def createAndRegisterWritingEndpoint(refuseUid: Option[Int]): ActorRef =
endpoints.registerWritableEndpoint(
recipientAddress,
uid = None,
refuseUid,
createEndpoint(
recipientAddress,
@ -563,7 +564,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
refuseUid))
endpoints.writableEndpointWithPolicyFor(recipientAddress) match {
case Some(Pass(endpoint, _))
case Some(Pass(endpoint, _, _))
endpoint ! s
case Some(Gated(timeOfRelease))
if (timeOfRelease.isOverdue()) createAndRegisterWritingEndpoint(refuseUid = None) ! s
@ -587,8 +588,8 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
handleStashedInbound(endpoint)
case EndpointWriter.TookOver(endpoint, handle)
removePendingReader(takingOverFrom = endpoint, withHandle = handle)
case ReliableDeliverySupervisor.GotUid(uid)
endpoints.registerWritableEndpointUid(sender, uid)
case ReliableDeliverySupervisor.GotUid(uid, remoteAddress)
endpoints.registerWritableEndpointUid(remoteAddress, uid)
handleStashedInbound(sender)
case Prune
endpoints.prune()
@ -630,9 +631,9 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
if (endpoints.isQuarantined(handle.remoteAddress, handle.handshakeInfo.uid))
handle.disassociate(AssociationHandle.Quarantined)
else endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match {
case Some(Pass(ep, None))
case Some(Pass(ep, None, _))
stashedInbound += ep -> (stashedInbound.getOrElse(ep, Vector.empty) :+ ia)
case Some(Pass(ep, Some(uid)))
case Some(Pass(ep, Some(uid), _))
if (handle.handshakeInfo.uid == uid) {
pendingReadHandoffs.get(ep) foreach (_.disassociate())
pendingReadHandoffs += ep -> handle
@ -641,10 +642,10 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
context.stop(ep)
endpoints.unregisterEndpoint(ep)
pendingReadHandoffs -= ep
createAndRegisterEndpoint(handle, Some(uid))
createAndRegisterEndpoint(handle, refuseUid = Some(uid))
}
case state
createAndRegisterEndpoint(handle, None)
createAndRegisterEndpoint(handle, refuseUid = endpoints.refuseUid(handle.remoteAddress))
}
}
}
@ -661,7 +662,7 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
writing,
refuseUid = refuseUid)
if (writing)
endpoints.registerWritableEndpoint(handle.remoteAddress, Some(handle.handshakeInfo.uid), endpoint)
endpoints.registerWritableEndpoint(handle.remoteAddress, Some(handle.handshakeInfo.uid), refuseUid, endpoint)
else {
endpoints.registerReadOnlyEndpoint(handle.remoteAddress, endpoint)
endpoints.removePolicy(handle.remoteAddress)

View file

@ -7,7 +7,7 @@ import akka.actor._
import akka.pattern.{ PromiseActorRef, ask, pipe }
import akka.remote.transport.ActorTransportAdapter.AssociateUnderlying
import akka.remote.transport.AkkaPduCodec.Associate
import akka.remote.transport.AssociationHandle.{ ActorHandleEventListener, Disassociated, InboundPayload, HandleEventListener }
import akka.remote.transport.AssociationHandle.{ DisassociateInfo, ActorHandleEventListener, Disassociated, InboundPayload, HandleEventListener }
import akka.remote.transport.ThrottlerManager.{ Listener, Handle, ListenerAndMode, Checkin }
import akka.remote.transport.ThrottlerTransportAdapter._
import akka.remote.transport.Transport._
@ -150,6 +150,12 @@ object ThrottlerTransportAdapter {
@SerialVersionUID(1L)
final case class ForceDisassociate(address: Address)
/**
* Management Command to force dissocation of an address with an explicit error.
*/
@SerialVersionUID(1L)
final case class ForceDisassociateExplicitly(address: Address, reason: DisassociateInfo)
@SerialVersionUID(1L)
case object ForceDisassociateAck {
/**
@ -172,9 +178,10 @@ class ThrottlerTransportAdapter(_wrappedTransport: Transport, _system: ExtendedA
override def managementCommand(cmd: Any): Future[Boolean] = {
import ActorTransportAdapter.AskTimeout
cmd match {
case s: SetThrottle manager ? s map { case SetThrottleAck true }
case f: ForceDisassociate manager ? f map { case ForceDisassociateAck true }
case _ wrappedTransport.managementCommand(cmd)
case s: SetThrottle manager ? s map { case SetThrottleAck true }
case f: ForceDisassociate manager ? f map { case ForceDisassociateAck true }
case f: ForceDisassociateExplicitly manager ? f map { case ForceDisassociateAck true }
case _ wrappedTransport.managementCommand(cmd)
}
}
}
@ -242,6 +249,13 @@ private[transport] class ThrottlerManager(wrappedTransport: Transport) extends A
case _
}
sender() ! ForceDisassociateAck
case ForceDisassociateExplicitly(address, reason)
val naked = nakedAddress(address)
handleTable foreach {
case (`naked`, handle) handle.disassociateWithFailure(reason)
case _
}
sender() ! ForceDisassociateAck
case Checkin(origin, handle)
val naked: Address = nakedAddress(origin)
@ -338,6 +352,8 @@ private[transport] object ThrottledAssociation {
sealed trait ThrottlerData
case object Uninitialized extends ThrottlerData
final case class ExposedHandle(handle: ThrottlerHandle) extends ThrottlerData
final case class FailWith(reason: DisassociateInfo)
}
/**
@ -454,6 +470,9 @@ private[transport] class ThrottledAssociation(
stay()
case Event(Disassociated(info), _)
stop() // not notifying the upstream handler is intentional: we are relying on heartbeating
case Event(FailWith(reason), _)
upstreamListener notify Disassociated(reason)
stop()
}
// This method captures ASSOCIATE packets and extracts the origin address
@ -534,4 +553,8 @@ private[transport] final case class ThrottlerHandle(_wrappedHandle: AssociationH
throttlerActor ! PoisonPill
}
def disassociateWithFailure(reason: DisassociateInfo): Unit = {
throttlerActor ! ThrottledAssociation.FailWith(reason)
}
}