Improve performance of DDataShard stashing, #26877

* While waiting for an update to complete, it will now deliver messages to
  other already-started entities immediately, instead of stashing them
  (see the condensed sketch below)
* Unstash one message at a time, instead of unstashAll
* Append messages for the entity that we are waiting for to its
  messageBuffer, instead of stashing them
* Test to confirm the improvements
* Fix a few other missing things
  * receiveStartEntity should process the change before starting the entity
  * lastMessageTimestamp should be touched from overridden deliverTo
  * handle StoreFailure
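
In outline, the change combines two mechanisms from the diff below. This is a
condensed sketch, restating the new aroundReceive and deliverOrBufferMessage;
it omits the empty-id and wrapped StartEntity branches and assumes the
surrounding Shard members (waiting, messageBuffers, extractEntityId,
appendToMessageBuffer, deliverTo) shown in the full diff:

  // Unstash at most one message after each processed message, so the
  // stash is drained gradually instead of all at once via unstashAll.
  override protected[akka] def aroundReceive(rcv: Receive, msg: Any): Unit = {
    super.aroundReceive(rcv, msg)
    if (!waiting)
      unstash() // unstash one message
  }

  // While an update for `waitingForUpdateEvent` is in flight, decide
  // per incoming message whether to buffer, deliver, or stash it.
  private def deliverOrBufferMessage(msg: Any, waitingForUpdateEvent: StateChange): Unit = {
    val (id, payload) = extractEntityId(msg)
    if (id == waitingForUpdateEvent.entityId) {
      // same entity as the pending update: append to its messageBuffer,
      // which is flushed after the update has completed
      appendToMessageBuffer(id, msg, sender())
    } else if (!messageBuffers.contains(id) &&
               context.child(URLEncoder.encode(id, "utf-8")).nonEmpty) {
      // already started and not in progress of passivating: deliver immediately
      deliverTo(id, msg, payload, sender())
    } else {
      // not started yet (or passivating): stash until the update completes
      stash()
    }
  }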
Patrik Nordwall 2019-05-07 08:50:32 +02:00
parent 35e7e07488
commit ce438637bb
4 changed files with 435 additions and 80 deletions


@@ -6,34 +6,34 @@ package akka.cluster.sharding
import java.net.URLEncoder
import akka.actor.{
Actor,
ActorLogging,
ActorRef,
ActorSystem,
DeadLetterSuppression,
Deploy,
NoSerializationVerificationNeeded,
Props,
Stash,
Terminated,
Timers
}
import akka.util.{ ConstantFun, MessageBufferMap }
import scala.concurrent.Future
import scala.concurrent.duration._
import akka.actor.Actor
import akka.actor.ActorLogging
import akka.actor.ActorRef
import akka.actor.ActorSystem
import akka.actor.DeadLetterSuppression
import akka.actor.Deploy
import akka.actor.NoSerializationVerificationNeeded
import akka.actor.Props
import akka.actor.Stash
import akka.actor.Terminated
import akka.actor.Timers
import akka.cluster.Cluster
import akka.cluster.ddata.ORSet
import akka.cluster.ddata.ORSetKey
import akka.cluster.ddata.Replicator._
import akka.cluster.ddata.SelfUniqueAddress
import akka.persistence._
import akka.util.PrettyDuration._
import akka.coordination.lease.scaladsl.{ Lease, LeaseProvider }
import akka.pattern.pipe
import scala.concurrent.duration._
import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
import akka.cluster.sharding.ShardRegion.ShardInitialized
import akka.cluster.sharding.ShardRegion.ShardRegionCommand
import akka.coordination.lease.scaladsl.Lease
import akka.coordination.lease.scaladsl.LeaseProvider
import akka.pattern.pipe
import akka.persistence._
import akka.util.MessageBufferMap
import akka.util.PrettyDuration._
import akka.util.unused
/**
@@ -150,7 +150,7 @@ private[akka] object Shard {
.withDeploy(Deploy.local)
}
private case object PassivateIdleTick extends NoSerializationVerificationNeeded
case object PassivateIdleTick extends NoSerializationVerificationNeeded
}
@@ -174,9 +174,14 @@ private[akka] class Shard(
with ActorLogging
with Timers {
import ShardRegion.{ handOffStopperProps, EntityId, Msg, Passivate, ShardInitialized }
import ShardCoordinator.Internal.{ HandOff, ShardStopped }
import Shard._
import ShardCoordinator.Internal.HandOff
import ShardCoordinator.Internal.ShardStopped
import ShardRegion.EntityId
import ShardRegion.Msg
import ShardRegion.Passivate
import ShardRegion.ShardInitialized
import ShardRegion.handOffStopperProps
import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
import akka.cluster.sharding.ShardRegion.ShardRegionCommand
import settings.tuningParameters._
@@ -188,7 +193,7 @@ private[akka] class Shard(
var passivating = Set.empty[ActorRef]
val messageBuffers = new MessageBufferMap[EntityId]
var handOffStopper: Option[ActorRef] = None
private var handOffStopper: Option[ActorRef] = None
import context.dispatcher
val passivateIdleTask = if (settings.passivateIdleEntityAfter > Duration.Zero) {
@@ -198,14 +203,14 @@ private[akka] class Shard(
None
}
val lease = settings.leaseSettings.map(
private val lease = settings.leaseSettings.map(
ls =>
LeaseProvider(context.system).getLease(
s"${context.system.name}-shard-$typeName-$shardId",
ls.leaseImplementation,
Cluster(context.system).selfAddress.hostPort))
val leaseRetryInterval = settings.leaseSettings match {
private val leaseRetryInterval = settings.leaseSettings match {
case Some(l) => l.leaseRetryInterval
case None => 5.seconds // not used
}
@@ -249,7 +254,7 @@ private[akka] class Shard(
// Don't send back ShardInitialized so that messages are buffered in the ShardRegion
// while awaiting the lease
def awaitingLease(): Receive = {
private def awaitingLease(): Receive = {
case LeaseAcquireResult(true, _) =>
log.debug("Acquired lease")
onLeaseAcquired()
@@ -292,27 +297,32 @@ private[akka] class Shard(
log.error("Shard type [{}] id [{}] lease lost. Reason: {}", typeName, shardId, msg.reason)
// Stop entities ASAP rather than send termination message
context.stop(self)
}
def receiveShardCommand(msg: ShardCommand): Unit = msg match {
private def receiveShardCommand(msg: ShardCommand): Unit = msg match {
// those are only used with remembering entities
case RestartEntity(id) => getOrCreateEntity(id)
case RestartEntities(ids) => restartEntities(ids)
}
def receiveStartEntity(start: ShardRegion.StartEntity): Unit = {
private def receiveStartEntity(start: ShardRegion.StartEntity): Unit = {
val requester = sender()
log.debug("Got a request from [{}] to start entity [{}] in shard [{}]", requester, start.entityId, shardId)
if (passivateIdleTask.isDefined) {
lastMessageTimestamp = lastMessageTimestamp.updated(start.entityId, System.nanoTime())
touchLastMessageTimestamp(start.entityId)
if (state.entities(start.entityId)) {
getOrCreateEntity(start.entityId)
requester ! ShardRegion.StartEntityAck(start.entityId, shardId)
} else {
processChange(EntityStarted(start.entityId)) { evt =>
getOrCreateEntity(start.entityId)
sendMsgBuffer(evt)
requester ! ShardRegion.StartEntityAck(start.entityId, shardId)
}
}
getOrCreateEntity(
start.entityId,
_ =>
processChange(EntityStarted(start.entityId))(_ =>
requester ! ShardRegion.StartEntityAck(start.entityId, shardId)))
}
def receiveStartEntityAck(ack: ShardRegion.StartEntityAck): Unit = {
private def receiveStartEntityAck(ack: ShardRegion.StartEntityAck): Unit = {
if (ack.shardId != shardId && state.entities.contains(ack.entityId)) {
log.debug("Entity [{}] previously owned by shard [{}] started in shard [{}]", ack.entityId, shardId, ack.shardId)
processChange(EntityStopped(ack.entityId)) { _ =>
@@ -322,16 +332,16 @@ private[akka] class Shard(
}
}
def restartEntities(ids: Set[EntityId]): Unit = {
private def restartEntities(ids: Set[EntityId]): Unit = {
context.actorOf(RememberEntityStarter.props(context.parent, ids, settings, sender()))
}
def receiveShardRegionCommand(msg: ShardRegionCommand): Unit = msg match {
private def receiveShardRegionCommand(msg: ShardRegionCommand): Unit = msg match {
case Passivate(stopMessage) => passivate(sender(), stopMessage)
case _ => unhandled(msg)
}
def receiveCoordinatorMessage(msg: CoordinatorMessage): Unit = msg match {
private def receiveCoordinatorMessage(msg: CoordinatorMessage): Unit = msg match {
case HandOff(`shardId`) => handOff(sender())
case HandOff(shard) => log.warning("Shard [{}] can not hand off for another Shard [{}]", shardId, shard)
case _ => unhandled(msg)
@@ -342,7 +352,7 @@ private[akka] class Shard(
case GetShardStats => sender() ! ShardStats(shardId, state.entities.size)
}
def handOff(replyTo: ActorRef): Unit = handOffStopper match {
private def handOff(replyTo: ActorRef): Unit = handOffStopper match {
case Some(_) => log.warning("HandOff shard [{}] received during existing handOff", shardId)
case None =>
log.debug("HandOff shard [{}]", shardId)
@@ -363,7 +373,7 @@ private[akka] class Shard(
}
}
def receiveTerminated(ref: ActorRef): Unit = {
private def receiveTerminated(ref: ActorRef): Unit = {
if (handOffStopper.contains(ref))
context.stop(self)
else if (idByRef.contains(ref) && handOffStopper.isEmpty)
@@ -387,7 +397,7 @@ private[akka] class Shard(
passivating = passivating - ref
}
def passivate(entity: ActorRef, stopMessage: Any): Unit = {
private def passivate(entity: ActorRef, stopMessage: Any): Unit = {
idByRef.get(entity) match {
case Some(id) =>
if (!messageBuffers.contains(id)) {
@@ -401,7 +411,13 @@ private[akka] class Shard(
}
}
def passivateIdleEntities(): Unit = {
def touchLastMessageTimestamp(id: EntityId): Unit = {
if (passivateIdleTask.isDefined) {
lastMessageTimestamp = lastMessageTimestamp.updated(id, System.nanoTime())
}
}
private def passivateIdleEntities(): Unit = {
val deadline = System.nanoTime() - settings.passivateIdleEntityAfter.toNanos
val refsToPassivate = lastMessageTimestamp.collect {
case (entityId, lastMessageTimestamp) if lastMessageTimestamp < deadline => refById(entityId)
@@ -447,29 +463,30 @@ private[akka] class Shard(
// in case it was wrapped, used in Typed
receiveStartEntity(start)
case _ =>
messageBuffers.contains(id) match {
case false => deliverTo(id, msg, payload, snd)
case true if messageBuffers.totalSize >= bufferSize =>
log.debug("Buffer is full, dropping message for entity [{}]", id)
context.system.deadLetters ! msg
case true =>
log.debug("Message for entity [{}] buffered", id)
messageBuffers.append(id, msg, snd)
}
if (messageBuffers.contains(id))
appendToMessageBuffer(id, msg, snd)
else
deliverTo(id, msg, payload, snd)
}
}
}
def deliverTo(id: EntityId, @unused msg: Any, payload: Msg, snd: ActorRef): Unit = {
if (passivateIdleTask.isDefined) {
lastMessageTimestamp = lastMessageTimestamp.updated(id, System.nanoTime())
def appendToMessageBuffer(id: EntityId, msg: Any, snd: ActorRef): Unit = {
if (messageBuffers.totalSize >= bufferSize) {
log.debug("Buffer is full, dropping message for entity [{}]", id)
context.system.deadLetters ! msg
} else {
log.debug("Message for entity [{}] buffered", id)
messageBuffers.append(id, msg, snd)
}
}
def deliverTo(id: EntityId, @unused msg: Any, payload: Msg, snd: ActorRef): Unit = {
touchLastMessageTimestamp(id)
getOrCreateEntity(id).tell(payload, snd)
}
def getOrCreateEntity(id: EntityId, onCreate: ActorRef => Unit = ConstantFun.scalaAnyToUnit): ActorRef = {
def getOrCreateEntity(id: EntityId): ActorRef = {
val name = URLEncoder.encode(id, "utf-8")
context.child(name) match {
case Some(child) => child
@@ -478,11 +495,8 @@ private[akka] class Shard(
val a = context.watch(context.actorOf(entityProps(id), name))
idByRef = idByRef.updated(a, id)
refById = refById.updated(id, a)
if (passivateIdleTask.isDefined) {
lastMessageTimestamp += (id -> System.nanoTime())
}
state = state.copy(state.entities + id)
onCreate(a)
touchLastMessageTimestamp(id)
a
}
}
@@ -510,8 +524,8 @@ private[akka] class RememberEntityStarter(
extends Actor
with ActorLogging {
import context.dispatcher
import RememberEntityStarter.Tick
import context.dispatcher
var waitingForAck = ids
@@ -551,8 +565,9 @@ private[akka] class RememberEntityStarter(
private[akka] trait RememberingShard {
selfType: Shard =>
import ShardRegion.{ EntityId, Msg }
import Shard._
import ShardRegion.EntityId
import ShardRegion.Msg
import akka.pattern.pipe
protected val settings: ClusterShardingSettings
@@ -592,6 +607,7 @@ private[akka] trait RememberingShard {
} else {
if (!passivating.contains(ref)) {
log.debug("Entity [{}] stopped without passivating, will restart after backoff", id)
// note that it's not removed from state here, will be started again via RestartEntity
import context.dispatcher
context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id))
} else processChange(EntityStopped(id))(passivateCompleted)
@@ -604,9 +620,11 @@ private[akka] trait RememberingShard {
val name = URLEncoder.encode(id, "utf-8")
context.child(name) match {
case Some(actor) =>
touchLastMessageTimestamp(id)
actor.tell(payload, snd)
case None =>
if (state.entities.contains(id)) {
// this may happen when entity is stopped without passivation
require(!messageBuffers.contains(id), s"Message buffers contains id [$id].")
getOrCreateEntity(id).tell(payload, snd)
} else {
@@ -740,8 +758,8 @@ private[akka] class DDataShard(
with Stash
with ActorLogging {
import ShardRegion.EntityId
import Shard._
import ShardRegion.EntityId
import settings.tuningParameters._
private val readMajority = ReadMajority(settings.tuningParameters.waitingForStateTimeout, majorityMinCap)
@@ -759,10 +777,12 @@ private[akka] class DDataShard(
// configuration on each node.
private val numberOfKeys = 5
private val stateKeys: Array[ORSetKey[EntityId]] =
Array.tabulate(numberOfKeys)(i => ORSetKey[EntityId](s"shard-${typeName}-${shardId}-$i"))
Array.tabulate(numberOfKeys)(i => ORSetKey[EntityId](s"shard-$typeName-$shardId-$i"))
private var waiting = true
private def key(entityId: EntityId): ORSetKey[EntityId] = {
val i = (math.abs(entityId.hashCode % numberOfKeys))
val i = math.abs(entityId.hashCode % numberOfKeys)
stateKeys(i)
}
@@ -773,11 +793,17 @@ private[akka] class DDataShard(
}
private def getState(): Unit = {
(0 until numberOfKeys).map { i =>
(0 until numberOfKeys).foreach { i =>
replicator ! Get(stateKeys(i), readMajority, Some(i))
}
}
override protected[akka] def aroundReceive(rcv: Receive, msg: Any): Unit = {
super.aroundReceive(rcv, msg)
if (!waiting)
unstash() // unstash one message
}
override def receive = waitingForState(Set.empty)
// This state will stash all commands
@@ -807,24 +833,26 @@ private[akka] class DDataShard(
receiveOne(i)
case _ =>
log.debug("Stashing while waiting for DDataShard initial state")
stash()
}
}
private def recoveryCompleted(): Unit = {
log.debug("DDataShard recovery completed shard [{}] with [{}] entities", shardId, state.entities.size)
waiting = false
context.parent ! ShardInitialized(shardId)
context.become(receiveCommand)
restartRememberedEntities()
unstashAll()
}
override def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = {
waiting = true
context.become(waitingForUpdate(event, handler), discardOld = false)
sendUpdate(event, retryCount = 1)
}
private def sendUpdate(evt: StateChange, retryCount: Int) = {
private def sendUpdate(evt: StateChange, retryCount: Int): Unit = {
replicator ! Update(key(evt.entityId), ORSet.empty[EntityId], writeMajority, Some((evt, retryCount))) { existing =>
evt match {
case EntityStarted(id) => existing :+ id
@@ -837,9 +865,9 @@ private[akka] class DDataShard(
private def waitingForUpdate[E <: StateChange](evt: E, afterUpdateCallback: E => Unit): Receive = {
case UpdateSuccess(_, Some((`evt`, _))) =>
log.debug("The DDataShard state was successfully updated with {}", evt)
waiting = false
context.unbecome()
afterUpdateCallback(evt)
unstashAll()
case UpdateTimeout(_, Some((`evt`, retryCount: Int))) =>
if (retryCount == maxUpdateAttempts) {
@@ -861,16 +889,73 @@ private[akka] class DDataShard(
sendUpdate(evt, retryCount + 1)
}
case StoreFailure(_, Some((`evt`, _))) =>
log.error(
"The DDataShard was unable to update state with event {} due to StoreFailure. " +
"Shard will be restarted after backoff.",
evt)
context.stop(self)
case ModifyFailure(_, error, cause, Some((`evt`, _))) =>
log.error(
cause,
"The DDataShard was unable to update state with error {} and event {}. Shard will be restarted",
error,
evt)
"The DDataShard was unable to update state with event {} due to ModifyFailure. " +
"Shard will be restarted. {}",
evt,
error)
throw cause
// TODO what can this actually be? We're uninitialized in the ShardRegion
case _ => stash()
// below cases should handle same messages as in Shard.receiveCommand
case _: Terminated => stash()
case _: CoordinatorMessage => stash()
case _: ShardCommand => stash()
case _: ShardRegion.StartEntity => stash()
case _: ShardRegion.StartEntityAck => stash()
case _: ShardRegionCommand => stash()
case msg: ShardQuery => receiveShardQuery(msg)
case PassivateIdleTick => stash()
case msg: LeaseLost => receiveLeaseLost(msg)
case msg if extractEntityId.isDefinedAt(msg) => deliverOrBufferMessage(msg, evt)
case msg =>
// shouldn't be any other message types, but just in case
log.debug("Stashing unexpected message [{}] while waiting for DDataShard update of {}", msg.getClass, evt)
stash()
}
/**
* If the message is for the same entity as we are waiting for the update it will be added to
* its messageBuffer, which will be sent after the update has completed.
*
* If the message is for another entity that is already started (and not in progress of passivating)
* it will be delivered immediately.
*
* Otherwise it will be stashed, and processed after the update has been completed.
*/
private def deliverOrBufferMessage(msg: Any, waitingForUpdateEvent: StateChange): Unit = {
val (id, payload) = extractEntityId(msg)
if (id == null || id == "") {
log.warning("Id must not be empty, dropping message [{}]", msg.getClass.getName)
context.system.deadLetters ! msg
} else {
payload match {
case _: ShardRegion.StartEntity =>
// in case it was wrapped, used in Typed
stash()
case _ =>
if (id == waitingForUpdateEvent.entityId) {
appendToMessageBuffer(id, msg, sender())
} else {
val name = URLEncoder.encode(id, "utf-8")
// messageBuffers.contains(id) when passivation is in progress
if (!messageBuffers.contains(id) && context.child(name).nonEmpty) {
deliverTo(id, msg, payload, sender())
} else {
log.debug("Stashing to [{}] while waiting for DDataShard update of {}", id, waitingForUpdateEvent)
stash()
}
}
}
}
}
}
@@ -887,9 +972,10 @@ object EntityRecoveryStrategy {
trait EntityRecoveryStrategy {
import ShardRegion.EntityId
import scala.concurrent.Future
import ShardRegion.EntityId
def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]]
}