/*
 * Copyright (C) 2009-2019 Lightbend Inc.
 */

package akka.cluster.sharding

import java.net.URLEncoder

import scala.concurrent.Future
import scala.concurrent.duration._

import akka.actor.Actor
import akka.actor.ActorLogging
import akka.actor.ActorRef
import akka.actor.ActorSystem
import akka.actor.DeadLetterSuppression
import akka.actor.Deploy
import akka.actor.NoSerializationVerificationNeeded
import akka.actor.Props
import akka.actor.Stash
import akka.actor.Terminated
import akka.actor.Timers
import akka.cluster.Cluster
import akka.cluster.ddata.ORSet
import akka.cluster.ddata.ORSetKey
import akka.cluster.ddata.Replicator._
import akka.cluster.ddata.SelfUniqueAddress
import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
import akka.cluster.sharding.ShardRegion.ShardInitialized
import akka.cluster.sharding.ShardRegion.ShardRegionCommand
import akka.coordination.lease.scaladsl.Lease
import akka.coordination.lease.scaladsl.LeaseProvider
import akka.pattern.pipe
import akka.persistence._
import akka.util.MessageBufferMap
import akka.util.PrettyDuration._
import akka.util.unused

/**
 * INTERNAL API
 *
 * @see [[ClusterSharding$ ClusterSharding extension]]
 */
private[akka] object Shard {
  import ShardRegion.EntityId

  /**
   * A Shard command
   */
  sealed trait ShardCommand

  /**
   * When remembering entities and the entity stops without issuing a `Passivate`, we
   * restart it after a back off using this message.
   */
  final case class RestartEntity(entity: EntityId) extends ShardCommand

  /**
   * When initialising a shard with remember entities enabled the following message is used
   * to restart batches of entity actors at a time.
   */
  final case class RestartEntities(entity: Set[EntityId]) extends ShardCommand

  /**
   * Represents a state change for the Shard
   */
  sealed trait StateChange extends ClusterShardingSerializable {
    val entityId: EntityId
  }

  /**
   * A query for information about the shard
   */
  sealed trait ShardQuery

  /**
   * `State` change for starting an entity in this `Shard`
   */
  @SerialVersionUID(1L) final case class EntityStarted(entityId: EntityId) extends StateChange

  /**
   * `State` change for an entity which has terminated.
   */
  @SerialVersionUID(1L) final case class EntityStopped(entityId: EntityId) extends StateChange

  @SerialVersionUID(1L) case object GetCurrentShardState extends ShardQuery

  @SerialVersionUID(1L) final case class CurrentShardState(shardId: ShardRegion.ShardId, entityIds: Set[EntityId])

  @SerialVersionUID(1L) case object GetShardStats extends ShardQuery with ClusterShardingSerializable

  @SerialVersionUID(1L) final case class ShardStats(shardId: ShardRegion.ShardId, entityCount: Int)
      extends ClusterShardingSerializable

  final case class LeaseAcquireResult(acquired: Boolean, reason: Option[Throwable]) extends DeadLetterSuppression
  final case class LeaseLost(reason: Option[Throwable]) extends DeadLetterSuppression

  final case object LeaseRetry extends DeadLetterSuppression
  private val LeaseRetryTimer = "lease-retry"

  object State {
    val Empty = State()
  }

  /**
   * Persistent state of the Shard.
   */
  @SerialVersionUID(1L) final case class State private[akka] (entities: Set[EntityId] = Set.empty)
      extends ClusterShardingSerializable
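
  // For illustration: the remembered-entities `State` is a plain set of entity ids that the
  // `EntityStarted`/`EntityStopped` events fold into, e.g. (assuming ids "a" and "b"):
  //   State(Set("a"))       after EntityStarted("a")
  //   State(Set("a", "b"))  after EntityStarted("b")
  //   State(Set("b"))       after EntityStopped("a")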

  /**
   * Factory method for the [[akka.actor.Props]] of the [[Shard]] actor.
   * If `settings.rememberEntities` is enabled the `PersistentShard`
   * subclass is used, otherwise `Shard`.
   */
  def props(
      typeName: String,
      shardId: ShardRegion.ShardId,
      entityProps: String => Props,
      settings: ClusterShardingSettings,
      extractEntityId: ShardRegion.ExtractEntityId,
      extractShardId: ShardRegion.ExtractShardId,
      handOffStopMessage: Any,
      replicator: ActorRef,
      majorityMinCap: Int): Props = {
    if (settings.rememberEntities && settings.stateStoreMode == ClusterShardingSettings.StateStoreModeDData) {
      Props(
        new DDataShard(
          typeName,
          shardId,
          entityProps,
          settings,
          extractEntityId,
          extractShardId,
          handOffStopMessage,
          replicator,
          majorityMinCap)).withDeploy(Deploy.local)
    } else if (settings.rememberEntities && settings.stateStoreMode == ClusterShardingSettings.StateStoreModePersistence)
      Props(
        new PersistentShard(
          typeName,
          shardId,
          entityProps,
          settings,
          extractEntityId,
          extractShardId,
          handOffStopMessage)).withDeploy(Deploy.local)
    else
      Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage))
        .withDeploy(Deploy.local)
  }

  case object PassivateIdleTick extends NoSerializationVerificationNeeded

}

/**
 * INTERNAL API
 *
 * This actor creates children entity actors on demand that it is told to be
 * responsible for.
 *
 * @see [[ClusterSharding$ ClusterSharding extension]]
 */
private[akka] class Shard(
    typeName: String,
    shardId: ShardRegion.ShardId,
    entityProps: String => Props,
    settings: ClusterShardingSettings,
    extractEntityId: ShardRegion.ExtractEntityId,
    @unused extractShardId: ShardRegion.ExtractShardId,
    handOffStopMessage: Any)
    extends Actor
    with ActorLogging
    with Timers {

  import Shard._
  import ShardCoordinator.Internal.HandOff
  import ShardCoordinator.Internal.ShardStopped
  import ShardRegion.EntityId
  import ShardRegion.Msg
  import ShardRegion.Passivate
  import ShardRegion.ShardInitialized
  import ShardRegion.handOffStopperProps
  import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
  import akka.cluster.sharding.ShardRegion.ShardRegionCommand
  import settings.tuningParameters._

  var state = State.Empty
  var idByRef = Map.empty[ActorRef, EntityId]
  var refById = Map.empty[EntityId, ActorRef]
  var lastMessageTimestamp = Map.empty[EntityId, Long]
  var passivating = Set.empty[ActorRef]
  val messageBuffers = new MessageBufferMap[EntityId]

  private var handOffStopper: Option[ActorRef] = None

  import context.dispatcher
  val passivateIdleTask = if (settings.passivateIdleEntityAfter > Duration.Zero) {
    val idleInterval = settings.passivateIdleEntityAfter / 2
    Some(context.system.scheduler.schedule(idleInterval, idleInterval, self, PassivateIdleTick))
  } else {
    None
  }

  private val lease = settings.leaseSettings.map(
    ls =>
      LeaseProvider(context.system).getLease(
        s"${context.system.name}-shard-$typeName-$shardId",
        ls.leaseImplementation,
        Cluster(context.system).selfAddress.hostPort))

  private val leaseRetryInterval = settings.leaseSettings match {
    case Some(l) => l.leaseRetryInterval
    case None    => 5.seconds // not used
  }

  override def preStart(): Unit = {
    acquireLeaseIfNeeded()
  }

  /**
   * Will call onLeaseAcquired when completed, also when lease isn't used
   */
  def acquireLeaseIfNeeded(): Unit = {
    lease match {
      case Some(l) =>
        tryGetLease(l)
        context.become(awaitingLease())
      case None =>
        onLeaseAcquired()
    }
  }

  // Override to execute logic once the lease has been acquired
  // Will be called on the actor thread
  def onLeaseAcquired(): Unit = {
    log.debug("Shard initialized")
    context.parent ! ShardInitialized(shardId)
    context.become(receiveCommand)
  }
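
  // For illustration: when `settings.leaseSettings` is defined, the lease name is derived from the
  // actor system name, shard type and shard id as above, e.g. for a (hypothetical) system named
  // "ClusterSystem", type "Counter" and shard id "7" it would be "ClusterSystem-shard-Counter-7".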

  private def tryGetLease(l: Lease) = {
    log.info("Acquiring lease {}", l.settings)
    pipe(l.acquire(reason => self ! LeaseLost(reason)).map(r => LeaseAcquireResult(r, None)).recover {
      case t => LeaseAcquireResult(acquired = false, Some(t))
    }).to(self)
  }

  def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit =
    handler(event)

  def receive = receiveCommand

  // Don't send back ShardInitialized so that messages are buffered in the ShardRegion
  // while awaiting the lease
  private def awaitingLease(): Receive = {
    case LeaseAcquireResult(true, _) =>
      log.debug("Acquired lease")
      onLeaseAcquired()
    case LeaseAcquireResult(false, None) =>
      log.error(
        "Failed to get lease for shard type [{}] id [{}]. Retry in {}",
        typeName,
        shardId,
        leaseRetryInterval.pretty)
      timers.startSingleTimer(LeaseRetryTimer, LeaseRetry, leaseRetryInterval)
    case LeaseAcquireResult(false, Some(t)) =>
      log.error(
        t,
        "Failed to get lease for shard type [{}] id [{}]. Retry in {}",
        typeName,
        shardId,
        leaseRetryInterval)
      timers.startSingleTimer(LeaseRetryTimer, LeaseRetry, leaseRetryInterval)
    case LeaseRetry =>
      tryGetLease(lease.get)
    case ll: LeaseLost =>
      receiveLeaseLost(ll)
  }

  def receiveCommand: Receive = {
    case Terminated(ref)                         => receiveTerminated(ref)
    case msg: CoordinatorMessage                 => receiveCoordinatorMessage(msg)
    case msg: ShardCommand                       => receiveShardCommand(msg)
    case msg: ShardRegion.StartEntity            => receiveStartEntity(msg)
    case msg: ShardRegion.StartEntityAck         => receiveStartEntityAck(msg)
    case msg: ShardRegionCommand                 => receiveShardRegionCommand(msg)
    case msg: ShardQuery                         => receiveShardQuery(msg)
    case PassivateIdleTick                       => passivateIdleEntities()
    case msg: LeaseLost                          => receiveLeaseLost(msg)
    case msg if extractEntityId.isDefinedAt(msg) => deliverMessage(msg, sender())
  }

  def receiveLeaseLost(msg: LeaseLost): Unit = {
    // The shard region will re-create this when it receives a message for this shard
    log.error("Shard type [{}] id [{}] lease lost. Reason: {}", typeName, shardId, msg.reason)
    // Stop entities ASAP rather than send termination message
    context.stop(self)
  }

  private def receiveShardCommand(msg: ShardCommand): Unit = msg match {
    // those are only used with remembering entities
    case RestartEntity(id)    => getOrCreateEntity(id)
    case RestartEntities(ids) => restartEntities(ids)
  }
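
  // For illustration: `extractEntityId` is the user-supplied `ShardRegion.ExtractEntityId`, i.e. a
  // `PartialFunction[Msg, (EntityId, Msg)]`. A minimal sketch, assuming a hypothetical envelope type:
  //   final case class EntityEnvelope(id: Long, payload: Any)
  //   val extractEntityId: ShardRegion.ExtractEntityId = {
  //     case EntityEnvelope(id, payload) => (id.toString, payload)
  //   }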

  private def receiveStartEntity(start: ShardRegion.StartEntity): Unit = {
    val requester = sender()
    log.debug("Got a request from [{}] to start entity [{}] in shard [{}]", requester, start.entityId, shardId)
    touchLastMessageTimestamp(start.entityId)

    if (state.entities(start.entityId)) {
      getOrCreateEntity(start.entityId)
      requester ! ShardRegion.StartEntityAck(start.entityId, shardId)
    } else {
      processChange(EntityStarted(start.entityId)) { evt =>
        getOrCreateEntity(start.entityId)
        sendMsgBuffer(evt)
        requester ! ShardRegion.StartEntityAck(start.entityId, shardId)
      }
    }
  }

  private def receiveStartEntityAck(ack: ShardRegion.StartEntityAck): Unit = {
    if (ack.shardId != shardId && state.entities.contains(ack.entityId)) {
      log.debug("Entity [{}] previously owned by shard [{}] started in shard [{}]", ack.entityId, shardId, ack.shardId)
      processChange(EntityStopped(ack.entityId)) { _ =>
        state = state.copy(state.entities - ack.entityId)
        messageBuffers.remove(ack.entityId)
      }
    }
  }

  private def restartEntities(ids: Set[EntityId]): Unit = {
    context.actorOf(RememberEntityStarter.props(context.parent, ids, settings, sender()))
  }

  private def receiveShardRegionCommand(msg: ShardRegionCommand): Unit = msg match {
    case Passivate(stopMessage) => passivate(sender(), stopMessage)
    case _                      => unhandled(msg)
  }

  private def receiveCoordinatorMessage(msg: CoordinatorMessage): Unit = msg match {
    case HandOff(`shardId`) => handOff(sender())
    case HandOff(shard)     => log.warning("Shard [{}] can not hand off for another Shard [{}]", shardId, shard)
    case _                  => unhandled(msg)
  }

  def receiveShardQuery(msg: ShardQuery): Unit = msg match {
    case GetCurrentShardState => sender() ! CurrentShardState(shardId, refById.keySet)
    case GetShardStats        => sender() ! ShardStats(shardId, state.entities.size)
  }

  private def handOff(replyTo: ActorRef): Unit = handOffStopper match {
    case Some(_) => log.warning("HandOff shard [{}] received during existing handOff", shardId)
    case None =>
      log.debug("HandOff shard [{}]", shardId)

      if (state.entities.nonEmpty) {
        val entityHandOffTimeout = (settings.tuningParameters.handOffTimeout - 5.seconds).max(1.seconds)
        handOffStopper = Some(
          context.watch(context.actorOf(
            handOffStopperProps(shardId, replyTo, idByRef.keySet, handOffStopMessage, entityHandOffTimeout))))

        //During hand off we only care about watching for termination of the hand off stopper
        context.become {
          case Terminated(ref) => receiveTerminated(ref)
        }
      } else {
        replyTo ! ShardStopped(shardId)
        context.stop(self)
      }
  }

  private def receiveTerminated(ref: ActorRef): Unit = {
    if (handOffStopper.contains(ref))
      context.stop(self)
    else if (idByRef.contains(ref) && handOffStopper.isEmpty)
      entityTerminated(ref)
  }

  def entityTerminated(ref: ActorRef): Unit = {
    val id = idByRef(ref)
    idByRef -= ref
    refById -= id
    if (passivateIdleTask.isDefined) {
      lastMessageTimestamp -= id
    }
    if (messageBuffers.getOrEmpty(id).nonEmpty) {
      log.debug("Starting entity [{}] again, there are buffered messages for it", id)
      sendMsgBuffer(EntityStarted(id))
    } else {
      processChange(EntityStopped(id))(passivateCompleted)
    }

    passivating = passivating - ref
  }

  private def passivate(entity: ActorRef, stopMessage: Any): Unit = {
    idByRef.get(entity) match {
      case Some(id) =>
        if (!messageBuffers.contains(id)) {
          passivating = passivating + entity
          messageBuffers.add(id)
          entity ! stopMessage
        } else {
          log.debug("Passivation already in progress for {}. Not sending stopMessage back to entity.", entity)
        }
      case None => log.debug("Unknown entity {}. Not sending stopMessage back to entity.", entity)
    }
  }
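
  // For illustration: passivation is initiated by the entity itself, typically when it has been idle,
  // by sending `Passivate` to its parent shard, e.g. (assuming `PoisonPill` as the stop message):
  //   context.parent ! ShardRegion.Passivate(stopMessage = PoisonPill)
  // The shard then buffers incoming messages for that entity, forwards the stop message, and finishes
  // in `passivateCompleted` once the entity has terminated.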

  def touchLastMessageTimestamp(id: EntityId): Unit = {
    if (passivateIdleTask.isDefined) {
      lastMessageTimestamp = lastMessageTimestamp.updated(id, System.nanoTime())
    }
  }

  private def passivateIdleEntities(): Unit = {
    val deadline = System.nanoTime() - settings.passivateIdleEntityAfter.toNanos
    val refsToPassivate = lastMessageTimestamp.collect {
      case (entityId, lastMessageTimestamp) if lastMessageTimestamp < deadline => refById(entityId)
    }
    if (refsToPassivate.nonEmpty) {
      log.debug("Passivating [{}] idle entities", refsToPassivate.size)
      refsToPassivate.foreach(passivate(_, handOffStopMessage))
    }
  }

  // EntityStopped handler
  def passivateCompleted(event: EntityStopped): Unit = {
    val hasBufferedMessages = messageBuffers.getOrEmpty(event.entityId).nonEmpty
    state = state.copy(state.entities - event.entityId)
    if (hasBufferedMessages) {
      log.debug(
        "Entity stopped after passivation [{}], but will be started again due to buffered messages.",
        event.entityId)
      processChange(EntityStarted(event.entityId))(sendMsgBuffer)
    } else {
      log.debug("Entity stopped after passivation [{}]", event.entityId)
      messageBuffers.remove(event.entityId)
    }
  }

  // EntityStarted handler
  def sendMsgBuffer(event: EntityStarted): Unit = {
    //Get the buffered messages and remove the buffer
    val messages = messageBuffers.getOrEmpty(event.entityId)
    messageBuffers.remove(event.entityId)

    if (messages.nonEmpty) {
      log.debug("Sending message buffer for entity [{}] ([{}] messages)", event.entityId, messages.size)
      getOrCreateEntity(event.entityId)

      //Now there is no deliveryBuffer we can try to redeliver
      // and as the child exists, the message will be directly forwarded
      messages.foreach {
        case (msg, snd) => deliverMessage(msg, snd)
      }
    }
  }

  def deliverMessage(msg: Any, snd: ActorRef): Unit = {
    val (id, payload) = extractEntityId(msg)
    if (id == null || id == "") {
      log.warning("Id must not be empty, dropping message [{}]", msg.getClass.getName)
      context.system.deadLetters ! msg
    } else {
      payload match {
        case start: ShardRegion.StartEntity =>
          // in case it was wrapped, used in Typed
          receiveStartEntity(start)
        case _ =>
          if (messageBuffers.contains(id)) appendToMessageBuffer(id, msg, snd)
          else deliverTo(id, msg, payload, snd)
      }
    }
  }
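
  // For illustration: the `bufferSize` limit used in `appendToMessageBuffer` below comes from the
  // sharding tuning parameters and is expected to map to the `akka.cluster.sharding.buffer-size`
  // setting (assumption, verify against reference.conf). Once the total number of buffered messages
  // across all entities reaches that limit, further messages for passivating or not-yet-started
  // entities go to dead letters instead of being buffered.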

  def appendToMessageBuffer(id: EntityId, msg: Any, snd: ActorRef): Unit = {
    if (messageBuffers.totalSize >= bufferSize) {
      log.debug("Buffer is full, dropping message for entity [{}]", id)
      context.system.deadLetters ! msg
    } else {
      log.debug("Message for entity [{}] buffered", id)
      messageBuffers.append(id, msg, snd)
    }
  }

  def deliverTo(id: EntityId, @unused msg: Any, payload: Msg, snd: ActorRef): Unit = {
    touchLastMessageTimestamp(id)
    getOrCreateEntity(id).tell(payload, snd)
  }

  def getOrCreateEntity(id: EntityId): ActorRef = {
    val name = URLEncoder.encode(id, "utf-8")
    context.child(name) match {
      case Some(child) => child
      case None =>
        log.debug("Starting entity [{}] in shard [{}]", id, shardId)
        val a = context.watch(context.actorOf(entityProps(id), name))
        idByRef = idByRef.updated(a, id)
        refById = refById.updated(id, a)
        state = state.copy(state.entities + id)
        touchLastMessageTimestamp(id)
        a
    }
  }

  override def postStop(): Unit = {
    passivateIdleTask.foreach(_.cancel())
  }
}

private[akka] object RememberEntityStarter {
  def props(region: ActorRef, ids: Set[ShardRegion.EntityId], settings: ClusterShardingSettings, requestor: ActorRef) =
    Props(new RememberEntityStarter(region, ids, settings, requestor))

  private case object Tick extends NoSerializationVerificationNeeded
}

/**
 * INTERNAL API: Actor responsible for starting entities when rememberEntities is enabled
 */
private[akka] class RememberEntityStarter(
    region: ActorRef,
    ids: Set[ShardRegion.EntityId],
    settings: ClusterShardingSettings,
    requestor: ActorRef)
    extends Actor
    with ActorLogging {

  import RememberEntityStarter.Tick
  import context.dispatcher

  var waitingForAck = ids

  sendStart(ids)

  val tickTask = {
    val resendInterval = settings.tuningParameters.retryInterval
    context.system.scheduler.schedule(resendInterval, resendInterval, self, Tick)
  }

  def sendStart(ids: Set[ShardRegion.EntityId]): Unit = {
    // these go through the region rather than directly to the shard
    // so that shard mapping changes are picked up
    ids.foreach(id => region ! ShardRegion.StartEntity(id))
  }

  override def receive: Receive = {
    case ack: ShardRegion.StartEntityAck =>
      waitingForAck -= ack.entityId
      // inform whoever requested the start that it happened
      requestor ! ack
      if (waitingForAck.isEmpty) context.stop(self)

    case Tick =>
      sendStart(waitingForAck)
  }

  override def postStop(): Unit = {
    tickTask.cancel()
  }
}

/**
 * INTERNAL API: Common things for PersistentShard and DDataShard
 */
private[akka] trait RememberingShard { selfType: Shard =>
  import Shard._
  import ShardRegion.EntityId
  import ShardRegion.Msg
  import akka.pattern.pipe

  protected val settings: ClusterShardingSettings

  protected val rememberedEntitiesRecoveryStrategy: EntityRecoveryStrategy = {
    import settings.tuningParameters._
    entityRecoveryStrategy match {
      case "all" => EntityRecoveryStrategy.allStrategy()
      case "constant" =>
        EntityRecoveryStrategy.constantStrategy(
          context.system,
          entityRecoveryConstantRateStrategyFrequency,
          entityRecoveryConstantRateStrategyNumberOfEntities)
    }
  }

  protected def restartRememberedEntities(): Unit = {
    rememberedEntitiesRecoveryStrategy.recoverEntities(state.entities).foreach { scheduledRecovery =>
      import context.dispatcher
      scheduledRecovery.filter(_.nonEmpty).map(RestartEntities).pipeTo(self)
    }
  }
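
  // For illustration: the strategy choice is configuration driven; the keys below are expected to be
  // the relevant ones (assumption, verify names and defaults against reference.conf):
  //   akka.cluster.sharding.entity-recovery-strategy = "constant"   // or "all"
  //   akka.cluster.sharding.entity-recovery-constant-rate-strategy {
  //     frequency = 100 ms
  //     number-of-entities = 5
  //   }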

  override def entityTerminated(ref: ActorRef): Unit = {
    import settings.tuningParameters._
    val id = idByRef(ref)
    idByRef -= ref
    refById -= id
    if (passivateIdleTask.isDefined) {
      lastMessageTimestamp -= id
    }
    if (messageBuffers.getOrEmpty(id).nonEmpty) {
      //Note; because we're not persisting the EntityStopped, we don't need
      // to persist the EntityStarted either.
      log.debug("Starting entity [{}] again, there are buffered messages for it", id)
      sendMsgBuffer(EntityStarted(id))
    } else {
      if (!passivating.contains(ref)) {
        log.debug("Entity [{}] stopped without passivating, will restart after backoff", id)
        // note that it's not removed from state here, will be started again via RestartEntity
        import context.dispatcher
        context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id))
      } else processChange(EntityStopped(id))(passivateCompleted)
    }

    passivating = passivating - ref
  }

  override def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = {
    val name = URLEncoder.encode(id, "utf-8")
    context.child(name) match {
      case Some(actor) =>
        touchLastMessageTimestamp(id)
        actor.tell(payload, snd)

      case None =>
        if (state.entities.contains(id)) {
          // this may happen when entity is stopped without passivation
          require(!messageBuffers.contains(id), s"Message buffers contains id [$id].")
          getOrCreateEntity(id).tell(payload, snd)
        } else {
          //Note; we only do this if remembering, otherwise the buffer is an overhead
          messageBuffers.append(id, msg, snd)
          processChange(EntityStarted(id))(sendMsgBuffer)
        }
    }
  }
}

/**
 * INTERNAL API
 *
 * This actor creates children entity actors on demand that it is told to be
 * responsible for. It is used when `rememberEntities` is enabled and
 * `state-store-mode=persistence`.
 *
 * @see [[ClusterSharding$ ClusterSharding extension]]
 */
private[akka] class PersistentShard(
    typeName: String,
    shardId: ShardRegion.ShardId,
    entityProps: String => Props,
    override val settings: ClusterShardingSettings,
    extractEntityId: ShardRegion.ExtractEntityId,
    extractShardId: ShardRegion.ExtractShardId,
    handOffStopMessage: Any)
    extends Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)
    with RememberingShard
    with PersistentActor
    with ActorLogging {

  import Shard._
  import settings.tuningParameters._

  override def preStart(): Unit = {
    // override to not acquire the lease on start up, acquire after persistent recovery
  }

  override def persistenceId = s"/sharding/${typeName}Shard/$shardId"

  override def journalPluginId: String = settings.journalPluginId

  override def snapshotPluginId: String = settings.snapshotPluginId

  override def receive = receiveCommand

  override def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = {
    saveSnapshotWhenNeeded()
    persist(event)(handler)
  }

  def saveSnapshotWhenNeeded(): Unit = {
    if (lastSequenceNr % snapshotAfter == 0 && lastSequenceNr != 0) {
      log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr)
      saveSnapshot(state)
    }
  }

  override def receiveRecover: Receive = {
    case EntityStarted(id)                 => state = state.copy(state.entities + id)
    case EntityStopped(id)                 => state = state.copy(state.entities - id)
    case SnapshotOffer(_, snapshot: State) => state = snapshot
    case RecoveryCompleted =>
      acquireLeaseIfNeeded() // onLeaseAcquired called when completed
      log.debug("PersistentShard recovery completed shard [{}] with [{}] entities", shardId, state.entities.size)
  }
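
  // For illustration: the remembered-entity set is rebuilt on restart by replaying
  // EntityStarted/EntityStopped events on top of the latest `State` snapshot, and
  // `saveSnapshotWhenNeeded` caps replay length by snapshotting every `snapshotAfter` events.
  // Assuming (unverified) defaults of snapshot-after = 1000 and keep-nr-of-batches = 2, roughly:
  //   events 1..999  -> replayed from the journal on recovery
  //   event  1000    -> snapshot saved; older events and snapshots are later deleted in batches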

  override def onLeaseAcquired(): Unit = {
    log.debug("Shard initialized")
    context.parent ! ShardInitialized(shardId)
    context.become(receiveCommand)
    restartRememberedEntities()
    unstashAll()
  }

  override def receiveCommand: Receive =
    ({
      case e: SaveSnapshotSuccess =>
        log.debug("PersistentShard snapshot saved successfully")
        internalDeleteMessagesBeforeSnapshot(e, keepNrOfBatches, snapshotAfter)

      case SaveSnapshotFailure(_, reason) =>
        log.warning("PersistentShard snapshot failure: [{}]", reason.getMessage)

      case DeleteMessagesSuccess(toSequenceNr) =>
        val deleteTo = toSequenceNr - 1
        val deleteFrom = math.max(0, deleteTo - (keepNrOfBatches * snapshotAfter))
        log.debug(
          "PersistentShard messages to [{}] deleted successfully. Deleting snapshots from [{}] to [{}]",
          toSequenceNr,
          deleteFrom,
          deleteTo)
        deleteSnapshots(SnapshotSelectionCriteria(minSequenceNr = deleteFrom, maxSequenceNr = deleteTo))

      case DeleteMessagesFailure(reason, toSequenceNr) =>
        log.warning("PersistentShard messages to [{}] deletion failure: [{}]", toSequenceNr, reason.getMessage)

      case DeleteSnapshotsSuccess(m) =>
        log.debug("PersistentShard snapshots matching [{}] deleted successfully", m)

      case DeleteSnapshotsFailure(m, reason) =>
        log.warning("PersistentShard snapshots matching [{}] deletion failure: [{}]", m, reason.getMessage)

    }: Receive).orElse(super.receiveCommand)

}

/**
 * INTERNAL API
 *
 * This actor creates children entity actors on demand that it is told to be
 * responsible for. It is used when `rememberEntities` is enabled and
 * `state-store-mode=ddata`.
 *
 * @see [[ClusterSharding$ ClusterSharding extension]]
 */
private[akka] class DDataShard(
    typeName: String,
    shardId: ShardRegion.ShardId,
    entityProps: String => Props,
    override val settings: ClusterShardingSettings,
    extractEntityId: ShardRegion.ExtractEntityId,
    extractShardId: ShardRegion.ExtractShardId,
    handOffStopMessage: Any,
    replicator: ActorRef,
    majorityMinCap: Int)
    extends Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)
    with RememberingShard
    with Stash
    with ActorLogging {

  import Shard._
  import ShardRegion.EntityId
  import settings.tuningParameters._

  private val readMajority = ReadMajority(settings.tuningParameters.waitingForStateTimeout, majorityMinCap)
  private val writeMajority = WriteMajority(settings.tuningParameters.updatingStateTimeout, majorityMinCap)
  private val maxUpdateAttempts = 3

  implicit private val node = Cluster(context.system)
  implicit private val selfUniqueAddress = SelfUniqueAddress(node.selfUniqueAddress)

  // The default maximum-frame-size is 256 KiB with Artery.
  // With 36-character entity identifiers (e.g. UUID.randomUUID), splitting the elements
  // over 5 keys lets us support 10000 entities per shard.
  // The Gossip message size of 5 ORSets with 2000 ids each is around 200 KiB.
  // This is by intention not configurable because it's important to have the same
  // configuration on each node.
  private val numberOfKeys = 5

  private val stateKeys: Array[ORSetKey[EntityId]] =
    Array.tabulate(numberOfKeys)(i => ORSetKey[EntityId](s"shard-$typeName-$shardId-$i"))

  private var waiting = true

  private def key(entityId: EntityId): ORSetKey[EntityId] = {
    val i = math.abs(entityId.hashCode % numberOfKeys)
    stateKeys(i)
  }

  override def onLeaseAcquired(): Unit = {
    log.info("Lease Acquired. Getting state from DData")
    getState()
    context.become(waitingForState(Set.empty))
  }
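
  // For illustration: entity ids are spread over the `numberOfKeys` ORSet keys purely by hash code,
  // e.g. (assuming typeName "Counter" and shardId "3"):
  //   key("entity-42") == stateKeys(math.abs("entity-42".hashCode % 5))
  //   // i.e. one of ORSetKey("shard-Counter-3-0") ... ORSetKey("shard-Counter-3-4")
  // so `getState()` below must read all keys with ReadMajority before recovery can complete.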

  private def getState(): Unit = {
    (0 until numberOfKeys).foreach { i =>
      replicator ! Get(stateKeys(i), readMajority, Some(i))
    }
  }

  override protected[akka] def aroundReceive(rcv: Receive, msg: Any): Unit = {
    super.aroundReceive(rcv, msg)
    if (!waiting)
      unstash() // unstash one message
  }

  override def receive = waitingForState(Set.empty)

  // This state will stash all commands
  private def waitingForState(gotKeys: Set[Int]): Receive = {
    def receiveOne(i: Int): Unit = {
      val newGotKeys = gotKeys + i
      if (newGotKeys.size == numberOfKeys) {
        recoveryCompleted()
      } else
        context.become(waitingForState(newGotKeys))
    }

    {
      case g @ GetSuccess(_, Some(i: Int)) =>
        val key = stateKeys(i)
        state = state.copy(entities = state.entities.union(g.get(key).elements))
        receiveOne(i)

      case GetFailure(_, _) =>
        log.error(
          "The DDataShard was unable to get an initial state within 'waiting-for-state-timeout': {} millis",
          waitingForStateTimeout.toMillis)
        // parent ShardRegion supervisor will notice that it terminated and will start it again, after backoff
        context.stop(self)

      case NotFound(_, Some(i: Int)) =>
        receiveOne(i)

      case _ =>
        log.debug("Stashing while waiting for DDataShard initial state")
        stash()
    }
  }

  private def recoveryCompleted(): Unit = {
    log.debug("DDataShard recovery completed shard [{}] with [{}] entities", shardId, state.entities.size)
    waiting = false
    context.parent ! ShardInitialized(shardId)
    context.become(receiveCommand)
    restartRememberedEntities()
  }

  override def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = {
    waiting = true
    context.become(waitingForUpdate(event, handler), discardOld = false)
    sendUpdate(event, retryCount = 1)
  }

  private def sendUpdate(evt: StateChange, retryCount: Int): Unit = {
    replicator ! Update(key(evt.entityId), ORSet.empty[EntityId], writeMajority, Some((evt, retryCount))) { existing =>
      evt match {
        case EntityStarted(id) => existing :+ id
        case EntityStopped(id) => existing.remove(id)
      }
    }
  }
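
  // For illustration: each state change is written with WriteMajority and attempted up to
  // `maxUpdateAttempts` (3) times on UpdateTimeout before the shard gives up and stops itself.
  // For a passivated entity "a" the replicated value evolves roughly as:
  //   ORSet("a", "b")  --Update(EntityStopped("a"))-->  ORSet("b")
  // While an update is in flight the shard is in `waitingForUpdate` (below) and stashes or buffers
  // most other messages, so state changes are applied one at a time.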

  // this state will stash all messages until it receives UpdateSuccess
  private def waitingForUpdate[E <: StateChange](evt: E, afterUpdateCallback: E => Unit): Receive = {
    case UpdateSuccess(_, Some((`evt`, _))) =>
      log.debug("The DDataShard state was successfully updated with {}", evt)
      waiting = false
      context.unbecome()
      afterUpdateCallback(evt)

    case UpdateTimeout(_, Some((`evt`, retryCount: Int))) =>
      if (retryCount == maxUpdateAttempts) {
        // parent ShardRegion supervisor will notice that it terminated and will start it again, after backoff
        log.error(
          "The DDataShard was unable to update state after {} attempts, within 'updating-state-timeout'={} millis, event={}. " +
          "Shard will be restarted after backoff.",
          maxUpdateAttempts,
          updatingStateTimeout.toMillis,
          evt)
        context.stop(self)
      } else {
        log.warning(
          "The DDataShard was unable to update state, attempt {} of {}, within 'updating-state-timeout'={} millis, event={}",
          retryCount,
          maxUpdateAttempts,
          updatingStateTimeout.toMillis,
          evt)
        sendUpdate(evt, retryCount + 1)
      }

    case StoreFailure(_, Some((`evt`, _))) =>
      log.error(
        "The DDataShard was unable to update state with event {} due to StoreFailure. " +
        "Shard will be restarted after backoff.",
        evt)
      context.stop(self)

    case ModifyFailure(_, error, cause, Some((`evt`, _))) =>
      log.error(
        cause,
        "The DDataShard was unable to update state with event {} due to ModifyFailure. " +
        "Shard will be restarted. {}",
        evt,
        error)
      throw cause

    // below cases should handle same messages as in Shard.receiveCommand
    case _: Terminated                           => stash()
    case _: CoordinatorMessage                   => stash()
    case _: ShardCommand                         => stash()
    case _: ShardRegion.StartEntity              => stash()
    case _: ShardRegion.StartEntityAck           => stash()
    case _: ShardRegionCommand                   => stash()
    case msg: ShardQuery                         => receiveShardQuery(msg)
    case PassivateIdleTick                       => stash()
    case msg: LeaseLost                          => receiveLeaseLost(msg)
    case msg if extractEntityId.isDefinedAt(msg) => deliverOrBufferMessage(msg, evt)
    case msg =>
      // shouldn't be any other message types, but just in case
      log.debug("Stashing unexpected message [{}] while waiting for DDataShard update of {}", msg.getClass, evt)
      stash()
  }

  /**
   * If the message is for the same entity as we are waiting for the update it will be added to
   * its messageBuffer, which will be sent after the update has completed.
   *
   * If the message is for another entity that is already started (and not in progress of passivating)
   * it will be delivered immediately.
   *
   * Otherwise it will be stashed, and processed after the update has been completed.
   */
  private def deliverOrBufferMessage(msg: Any, waitingForUpdateEvent: StateChange): Unit = {
    val (id, payload) = extractEntityId(msg)
    if (id == null || id == "") {
      log.warning("Id must not be empty, dropping message [{}]", msg.getClass.getName)
      context.system.deadLetters ! msg
    } else {
      payload match {
        case _: ShardRegion.StartEntity =>
          // in case it was wrapped, used in Typed
          stash()
        case _ =>
          if (id == waitingForUpdateEvent.entityId) {
            appendToMessageBuffer(id, msg, sender())
          } else {
            val name = URLEncoder.encode(id, "utf-8")
            // messageBuffers.contains(id) when passivation is in progress
            if (!messageBuffers.contains(id) && context.child(name).nonEmpty) {
              deliverTo(id, msg, payload, sender())
            } else {
              log.debug("Stashing to [{}] while waiting for DDataShard update of {}", id, waitingForUpdateEvent)
              stash()
            }
          }
      }
    }
  }

}

object EntityRecoveryStrategy {
  def allStrategy(): EntityRecoveryStrategy = new AllAtOnceEntityRecoveryStrategy()

  def constantStrategy(
      actorSystem: ActorSystem,
      frequency: FiniteDuration,
      numberOfEntities: Int): EntityRecoveryStrategy =
    new ConstantRateEntityRecoveryStrategy(actorSystem, frequency, numberOfEntities)
}

trait EntityRecoveryStrategy {
  import scala.concurrent.Future
  import ShardRegion.EntityId

  def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]]
}

final class AllAtOnceEntityRecoveryStrategy extends EntityRecoveryStrategy {
  import ShardRegion.EntityId

  override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] =
    if (entities.isEmpty) Set.empty else Set(Future.successful(entities))
}

final class ConstantRateEntityRecoveryStrategy(
    actorSystem: ActorSystem,
    frequency: FiniteDuration,
    numberOfEntities: Int)
    extends EntityRecoveryStrategy {
  import ShardRegion.EntityId
  import actorSystem.dispatcher
  import akka.pattern.after

  override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] =
    entities
      .grouped(numberOfEntities)
      .foldLeft((frequency, Set[Future[Set[EntityId]]]())) {
        case ((interval, scheduledEntityIds), entityIds) =>
          (interval + frequency, scheduledEntityIds + scheduleEntities(interval, entityIds))
      }
      ._2

  private def scheduleEntities(interval: FiniteDuration, entityIds: Set[EntityId]) =
    after(interval, actorSystem.scheduler)(Future.successful[Set[EntityId]](entityIds))
}
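
// For illustration: ConstantRateEntityRecoveryStrategy spreads the restart of remembered entities over
// time. For example, with frequency = 100.millis and numberOfEntities = 5, recovering 12 entity ids
// yields three futures completing at roughly 100 ms, 200 ms and 300 ms with batches of 5, 5 and 2 ids;
// each completed, non-empty batch is piped to the shard as a RestartEntities message
// (see restartRememberedEntities above).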