From 722b68e7cc61cb0d05de1112ae8b4c4e5ec89db6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Fri, 3 Apr 2020 14:57:49 +0200 Subject: [PATCH] Refactor remember entities in shards (#28776) * DData and Persistence based remember entities refactored * Order methods in the order of init in the shard. * Some bad isolation between test cases causing problems * Test coverage for remember entities store failures * WithLogCapturing where applicable * MiMa filters * Timeouts from config for persistent remember entities * Single method for deliver, less utf-8 encoding * Include detail on write failure * Don't send message to dead letter if it is actually handled in BackOffSupervisor * Back off supervisor log format plus use warning for hitting max restarts * actor/message based spi * Missing assert that node had joined cluster --- .../internal/BackoffOnStopSupervisor.scala | 15 +- .../remember-entities-refactor.excludes | 39 + .../src/main/resources/reference.conf | 2 - .../cluster/sharding/ClusterSharding.scala | 4 +- .../sharding/ClusterShardingSettings.scala | 6 +- .../scala/akka/cluster/sharding/Shard.scala | 1078 ++++++----------- .../akka/cluster/sharding/ShardRegion.scala | 40 +- .../CustomStateStoreModeProvider.scala | 45 + .../internal/DDataRememberEntitiesStore.scala | 217 ++++ .../internal/EntityRecoveryStrategy.scala | 81 ++ .../EventSourcedRememberEntities.scala | 158 +++ .../internal/RememberEntitiesStarter.scala | 74 ++ .../internal/RememberEntitiesStore.scala | 45 + .../ClusterShardingMessageSerializer.scala | 6 +- .../sharding/ClusterShardingFailureSpec.scala | 22 +- .../MultiNodeClusterShardingSpec.scala | 2 +- .../AllAtOnceEntityRecoveryStrategySpec.scala | 1 + .../ClusterShardingInternalsSpec.scala | 5 +- .../sharding/ClusterShardingLeaseSpec.scala | 8 +- .../ConcurrentStartupShardingSpec.scala | 6 +- ...nstantRateEntityRecoveryStrategySpec.scala | 1 + .../CoordinatedShutdownShardingSpec.scala | 1 + .../sharding/GetShardTypeNamesSpec.scala | 6 +- .../InactiveEntityPassivationSpec.scala | 7 +- .../JoinConfigCompatCheckShardingSpec.scala | 6 +- .../sharding/PersistentShardSpec.scala | 55 - .../cluster/sharding/ProxyShardingSpec.scala | 8 +- .../RememberEntitiesFailureSpec.scala | 290 +++++ .../sharding/RememberEntitiesSpec.scala | 100 ++ ...emoveInternalClusterShardingDataSpec.scala | 8 +- .../cluster/sharding/ShardRegionSpec.scala | 6 +- .../akka/cluster/sharding/ShardSpec.scala | 123 -- .../cluster/sharding/ShardWithLeaseSpec.scala | 158 +++ .../cluster/sharding/SupervisionSpec.scala | 23 +- ...ClusterShardingMessageSerializerSpec.scala | 8 +- 35 files changed, 1725 insertions(+), 929 deletions(-) create mode 100644 akka-cluster-sharding/src/main/mima-filters/2.6.4.backwards.excludes/remember-entities-refactor.excludes create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/CustomStateStoreModeProvider.scala create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/DDataRememberEntitiesStore.scala create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EventSourcedRememberEntities.scala create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala create mode 100644 akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStore.scala delete mode 100644 
akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala create mode 100644 akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesFailureSpec.scala create mode 100644 akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesSpec.scala delete mode 100644 akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardSpec.scala create mode 100644 akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardWithLeaseSpec.scala diff --git a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala index 51ebcc18ab..a739098802 100644 --- a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala +++ b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala @@ -8,6 +8,7 @@ import akka.actor.SupervisorStrategy.{ Directive, Escalate } import akka.actor.{ Actor, ActorLogging, OneForOneStrategy, Props, SupervisorStrategy, Terminated } import akka.annotation.InternalApi import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } +import akka.util.PrettyDuration import scala.concurrent.duration.FiniteDuration @@ -51,6 +52,7 @@ import scala.concurrent.duration.FiniteDuration case Terminated(ref) if child.contains(ref) => child = None if (finalStopMessageReceived) { + log.debug("Child terminated after final stop message, stopping supervisor") context.stop(self) } else { val maxNrOfRetries = strategy match { @@ -61,13 +63,14 @@ import scala.concurrent.duration.FiniteDuration if (maxNrOfRetries == -1 || nextRestartCount <= maxNrOfRetries) { val restartDelay = calculateDelay(restartCount, minBackoff, maxBackoff, randomFactor) + log.debug("Supervised child terminated, restarting after [{}] back off", PrettyDuration.format(restartDelay)) context.system.scheduler.scheduleOnce(restartDelay, self, StartChild) restartCount = nextRestartCount } else { - log.debug( - s"Terminating on restart #{} which exceeds max allowed restarts ({})", - nextRestartCount, - maxNrOfRetries) + log.warning( + "Supervised child exceeded max allowed number of restarts [{}] (restarted [{}] times), stopping supervisor", + maxNrOfRetries, + nextRestartCount) context.stop(self) } } @@ -86,11 +89,13 @@ import scala.concurrent.duration.FiniteDuration case None => replyWhileStopped match { case Some(r) => sender() ! 
r - case None => context.system.deadLetters.forward(msg) + case _ => } finalStopMessage match { case Some(fsm) if fsm(msg) => context.stop(self) case _ => + // only send to dead letters if not replied nor final-stopped + if (replyWhileStopped.isEmpty) context.system.deadLetters.forward(msg) } } } diff --git a/akka-cluster-sharding/src/main/mima-filters/2.6.4.backwards.excludes/remember-entities-refactor.excludes b/akka-cluster-sharding/src/main/mima-filters/2.6.4.backwards.excludes/remember-entities-refactor.excludes new file mode 100644 index 0000000000..3ffbbc5a86 --- /dev/null +++ b/akka-cluster-sharding/src/main/mima-filters/2.6.4.backwards.excludes/remember-entities-refactor.excludes @@ -0,0 +1,39 @@ +# all these are internals +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.RememberEntityStarter") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.DDataShard") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$StateChange") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.PersistentShard") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$State$") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$EntityStarted$") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$EntityStopped") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.RememberEntityStarter$Tick$") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$EntityStarted") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.RememberingShard") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.RememberEntityStarter$") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$State") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.Shard$EntityStopped$") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.props") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.state") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.state_=") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.idByRef") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.idByRef_=") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.refById") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.refById_=") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.lastMessageTimestamp") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.lastMessageTimestamp_=") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.passivating") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.passivating_=") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.messageBuffers") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.passivateIdleTask") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.onLeaseAcquired") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.processChange") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.cluster.sharding.Shard.passivateCompleted") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.cluster.sharding.Shard.sendMsgBuffer") 
+ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.deliverMessage") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.cluster.sharding.Shard.this") + +# not marked internal but for not intended as public (no public API use case) +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.EntityRecoveryStrategy") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.EntityRecoveryStrategy$") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.AllAtOnceEntityRecoveryStrategy") +ProblemFilters.exclude[MissingClassProblem]("akka.cluster.sharding.ConstantRateEntityRecoveryStrategy") \ No newline at end of file diff --git a/akka-cluster-sharding/src/main/resources/reference.conf b/akka-cluster-sharding/src/main/resources/reference.conf index 0c8bd53cb2..7f3027869f 100644 --- a/akka-cluster-sharding/src/main/resources/reference.conf +++ b/akka-cluster-sharding/src/main/resources/reference.conf @@ -119,11 +119,9 @@ akka.cluster.sharding { # and for a shard to get its state when remembered entities is enabled # The read from ddata is a ReadMajority, for small clusters (< majority-min-cap) every node needs to respond # so is more likely to time out if there are nodes restarting e.g. when there is a rolling re-deploy happening - # Only used when state-store-mode=ddata waiting-for-state-timeout = 2 s # Timeout of waiting for update the distributed state (update will be retried if the timeout happened) - # Only used when state-store-mode=ddata updating-state-timeout = 5 s # Timeout to wait for querying all shards for a given `ShardRegion`. diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterSharding.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterSharding.scala index b5ad320ea1..9a06557150 100755 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterSharding.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterSharding.scala @@ -716,7 +716,9 @@ private[akka] class ClusterShardingGuardian extends Actor { } private def replicator(settings: ClusterShardingSettings): ActorRef = { - if (settings.stateStoreMode == ClusterShardingSettings.StateStoreModeDData) { + if (settings.stateStoreMode == ClusterShardingSettings.StateStoreModeDData || + // FIXME for now coordinator still uses the replicator + settings.stateStoreMode == ClusterShardingSettings.StateStoreModeCustom) { // one Replicator per role replicatorByRole.get(settings.role) match { case Some(ref) => ref diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterShardingSettings.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterShardingSettings.scala index 8d33f444fc..f9d994d06e 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterShardingSettings.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ClusterShardingSettings.scala @@ -20,6 +20,7 @@ object ClusterShardingSettings { val StateStoreModePersistence = "persistence" val StateStoreModeDData = "ddata" + val StateStoreModeCustom = "custom" /** * Create settings from the default configuration @@ -300,10 +301,9 @@ final class ClusterShardingSettings( tuningParameters, coordinatorSingletonSettings) - import ClusterShardingSettings.StateStoreModeDData - import ClusterShardingSettings.StateStoreModePersistence + import ClusterShardingSettings.{ StateStoreModeCustom, StateStoreModeDData, StateStoreModePersistence } require( - stateStoreMode == 
StateStoreModePersistence || stateStoreMode == StateStoreModeDData, + stateStoreMode == StateStoreModePersistence || stateStoreMode == StateStoreModeDData || stateStoreMode == StateStoreModeCustom, s"Unknown 'state-store-mode' [$stateStoreMode], valid values are '$StateStoreModeDData' or '$StateStoreModePersistence'") /** If true, this node should run the shard region, otherwise just a shard proxy should started on this node. */ diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala index 9ab9d50e29..255f0c454f 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala @@ -6,13 +6,9 @@ package akka.cluster.sharding import java.net.URLEncoder -import scala.concurrent.Future -import scala.concurrent.duration._ - import akka.actor.Actor import akka.actor.ActorLogging import akka.actor.ActorRef -import akka.actor.ActorSystem import akka.actor.DeadLetterSuppression import akka.actor.Deploy import akka.actor.NoSerializationVerificationNeeded @@ -22,21 +18,22 @@ import akka.actor.Terminated import akka.actor.Timers import akka.annotation.InternalStableApi import akka.cluster.Cluster -import akka.cluster.ddata.ORSet -import akka.cluster.ddata.ORSetKey -import akka.cluster.ddata.Replicator._ -import akka.cluster.ddata.SelfUniqueAddress -import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage -import akka.cluster.sharding.ShardRegion.ShardInitialized -import akka.cluster.sharding.ShardRegion.ShardRegionCommand +import akka.cluster.sharding.internal.EntityRecoveryStrategy +import akka.cluster.sharding.internal.RememberEntitiesShardStore +import akka.cluster.sharding.internal.RememberEntitiesShardStore.GetEntities +import akka.cluster.sharding.internal.RememberEntitiesShardStoreProvider +import akka.cluster.sharding.internal.RememberEntityStarter import akka.coordination.lease.scaladsl.Lease import akka.coordination.lease.scaladsl.LeaseProvider import akka.pattern.pipe -import akka.persistence._ import akka.util.MessageBufferMap +import akka.util.OptionVal import akka.util.PrettyDuration._ import akka.util.unused +import scala.collection.immutable.Set +import scala.concurrent.duration._ + /** * INTERNAL API * @@ -48,42 +45,25 @@ private[akka] object Shard { /** * A Shard command */ - sealed trait ShardCommand + sealed trait RememberEntityCommand /** * When remembering entities and the entity stops without issuing a `Passivate`, we * restart it after a back off using this message. */ - final case class RestartEntity(entity: EntityId) extends ShardCommand + final case class RestartEntity(entity: EntityId) extends RememberEntityCommand /** * When initialising a shard with remember entities enabled the following message is used * to restart batches of entity actors at a time. 
*/ - final case class RestartEntities(entity: Set[EntityId]) extends ShardCommand - - /** - * A case class which represents a state change for the Shard - */ - sealed trait StateChange extends ClusterShardingSerializable { - val entityId: EntityId - } + final case class RestartEntities(entity: Set[EntityId]) extends RememberEntityCommand /** * A query for information about the shard */ sealed trait ShardQuery - /** - * `State` change for starting an entity in this `Shard` - */ - @SerialVersionUID(1L) final case class EntityStarted(entityId: EntityId) extends StateChange - - /** - * `State` change for an entity which has terminated. - */ - @SerialVersionUID(1L) final case class EntityStopped(entityId: EntityId) extends StateChange - @SerialVersionUID(1L) case object GetCurrentShardState extends ShardQuery @SerialVersionUID(1L) final case class CurrentShardState(shardId: ShardRegion.ShardId, entityIds: Set[EntityId]) @@ -99,21 +79,6 @@ private[akka] object Shard { final case object LeaseRetry extends DeadLetterSuppression private val LeaseRetryTimer = "lease-retry" - object State { - val Empty = State() - } - - /** - * Persistent state of the Shard. - */ - @SerialVersionUID(1L) final case class State private[akka] (entities: Set[EntityId] = Set.empty) - extends ClusterShardingSerializable - - /** - * Factory method for the [[akka.actor.Props]] of the [[Shard]] actor. - * If `settings.rememberEntities` is enabled the `PersistentShard` - * subclass is used, otherwise `Shard`. - */ def props( typeName: String, shardId: ShardRegion.ShardId, @@ -122,37 +87,30 @@ private[akka] object Shard { extractEntityId: ShardRegion.ExtractEntityId, extractShardId: ShardRegion.ExtractShardId, handOffStopMessage: Any, - replicator: ActorRef, - majorityMinCap: Int): Props = { - (if (settings.rememberEntities && settings.stateStoreMode == ClusterShardingSettings.StateStoreModeDData) { - Props( - new DDataShard( - typeName, - shardId, - entityProps, - settings, - extractEntityId, - extractShardId, - handOffStopMessage, - replicator, - majorityMinCap)) - } else if (settings.rememberEntities && settings.stateStoreMode == ClusterShardingSettings.StateStoreModePersistence) - Props( - new PersistentShard( - typeName, - shardId, - entityProps, - settings, - extractEntityId, - extractShardId, - handOffStopMessage)) - else { - Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)) - }).withDeploy(Deploy.local) - } + rememberEntitiesProvider: Option[RememberEntitiesShardStoreProvider]): Props = + Props( + new Shard( + typeName, + shardId, + entityProps, + settings, + extractEntityId, + extractShardId, + handOffStopMessage, + rememberEntitiesProvider)).withDeploy(Deploy.local) case object PassivateIdleTick extends NoSerializationVerificationNeeded + private final case class RememberedEntityIds(ids: Set[EntityId]) + private final case class RememberEntityStoreCrashed(store: ActorRef) + private final case object AsyncWriteDone + + private val RememberEntityTimeoutKey = "RememberEntityTimeout" + case class RememberEntityTimeout(operation: RememberEntitiesShardStore.Command) + + // FIXME Leaving this on while we are working on the remember entities refactor + // should it go in settings perhaps, useful for tricky sharding bugs? 
+ final val VerboseDebug = true } /** @@ -163,6 +121,7 @@ private[akka] object Shard { * * @see [[ClusterSharding$ ClusterSharding extension]] */ +// FIXME I broke bin comp here @InternalStableApi private[akka] class Shard( typeName: String, @@ -171,16 +130,17 @@ private[akka] class Shard( settings: ClusterShardingSettings, extractEntityId: ShardRegion.ExtractEntityId, @unused extractShardId: ShardRegion.ExtractShardId, - handOffStopMessage: Any) + handOffStopMessage: Any, + rememberEntitiesProvider: Option[RememberEntitiesShardStoreProvider]) extends Actor with ActorLogging + with Stash with Timers { import Shard._ import ShardCoordinator.Internal.HandOff import ShardCoordinator.Internal.ShardStopped import ShardRegion.EntityId - import ShardRegion.Msg import ShardRegion.Passivate import ShardRegion.ShardInitialized import ShardRegion.handOffStopperProps @@ -188,17 +148,42 @@ private[akka] class Shard( import akka.cluster.sharding.ShardRegion.ShardRegionCommand import settings.tuningParameters._ - var state = State.Empty - var idByRef = Map.empty[ActorRef, EntityId] - var refById = Map.empty[EntityId, ActorRef] - var lastMessageTimestamp = Map.empty[EntityId, Long] - var passivating = Set.empty[ActorRef] - val messageBuffers = new MessageBufferMap[EntityId] + private val rememberEntitiesStore: Option[ActorRef] = + rememberEntitiesProvider.map { provider => + val store = context.actorOf(provider.shardStoreProps(shardId).withDeploy(Deploy.local), "RememberEntitiesStore") + context.watchWith(store, RememberEntityStoreCrashed(store)) + store + } + + private val rememberedEntitiesRecoveryStrategy: EntityRecoveryStrategy = { + import settings.tuningParameters._ + entityRecoveryStrategy match { + case "all" => EntityRecoveryStrategy.allStrategy() + case "constant" => + EntityRecoveryStrategy.constantStrategy( + context.system, + entityRecoveryConstantRateStrategyFrequency, + entityRecoveryConstantRateStrategyNumberOfEntities) + } + } + + // this will contain entity ids even if not yet started (or stopped without graceful stop) + private var entityIds: Set[EntityId] = Set.empty + + private var idByRef = Map.empty[ActorRef, EntityId] + private var refById = Map.empty[EntityId, ActorRef] + + private var lastMessageTimestamp = Map.empty[EntityId, Long] + + // that an entity is passivating it is added to the passivating set and its id is added to the message buffers + // to buffer any messages coming in during passivation + private var passivating = Set.empty[ActorRef] + private val messageBuffers = new MessageBufferMap[EntityId] private var handOffStopper: Option[ActorRef] = None import context.dispatcher - val passivateIdleTask = if (settings.shouldPassivateIdleEntities) { + private val passivateIdleTask = if (settings.shouldPassivateIdleEntities) { val idleInterval = settings.passivateIdleEntityAfter / 2 Some(context.system.scheduler.scheduleWithFixedDelay(idleInterval, idleInterval, self, PassivateIdleTick)) } else { @@ -217,49 +202,31 @@ private[akka] class Shard( case None => 5.seconds // not used } + def receive: Receive = { + case _ => throw new IllegalStateException("Default receive never expected to actually be used") + } + override def preStart(): Unit = { acquireLeaseIfNeeded() } - /** - * Will call onLeaseAcquired when completed, also when lease isn't used - */ + // ===== lease handling initialization ===== def acquireLeaseIfNeeded(): Unit = { lease match { case Some(l) => tryGetLease(l) context.become(awaitingLease()) case None => - onLeaseAcquired() + tryLoadRememberedEntities() } } - 
// Override to execute logic once the lease has been acquired - // Will be called on the actor thread - def onLeaseAcquired(): Unit = { - log.debug("Shard initialized") - context.parent ! ShardInitialized(shardId) - context.become(receiveCommand) - } - - private def tryGetLease(l: Lease) = { - log.info("Acquiring lease {}", l.settings) - pipe(l.acquire(reason => self ! LeaseLost(reason)).map(r => LeaseAcquireResult(r, None)).recover { - case t => LeaseAcquireResult(acquired = false, Some(t)) - }).to(self) - } - - def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = - handler(event) - - def receive: Receive = receiveCommand - // Don't send back ShardInitialized so that messages are buffered in the ShardRegion // while awaiting the lease private def awaitingLease(): Receive = { case LeaseAcquireResult(true, _) => - log.debug("Acquired lease") - onLeaseAcquired() + log.debug("Lease acquired") + tryLoadRememberedEntities() case LeaseAcquireResult(false, None) => log.error( "Failed to get lease for shard type [{}] id [{}]. Retry in {}", @@ -279,31 +246,167 @@ tryGetLease(lease.get) case ll: LeaseLost => receiveLeaseLost(ll) + case _ => + stash() } - def receiveCommand: Receive = { + private def tryGetLease(l: Lease): Unit = { + log.info("Acquiring lease {}", l.settings) + pipe(l.acquire(reason => self ! LeaseLost(reason)).map(r => LeaseAcquireResult(r, None)).recover { + case t => LeaseAcquireResult(acquired = false, Some(t)) + }).to(self) + } + + // ===== remember entities initialization ===== + def tryLoadRememberedEntities(): Unit = { + rememberEntitiesStore match { + case Some(store) => + log.debug("Waiting for load of entity ids using [{}] to complete", store) + store ! RememberEntitiesShardStore.GetEntities + timers.startSingleTimer( + RememberEntityTimeoutKey, + RememberEntityTimeout(GetEntities), + settings.tuningParameters.waitingForStateTimeout) + context.become(awaitingRememberedEntities()) + case None => + onEntitiesRemembered(Set.empty) + } + } + + def awaitingRememberedEntities(): Receive = { + case RememberEntitiesShardStore.RememberedEntities(entityIds) => + timers.cancel(RememberEntityTimeoutKey) + onEntitiesRemembered(entityIds) + case RememberEntityTimeout(GetEntities) => + loadingEntityIdsFailed() + case msg => + if (VerboseDebug) + log.debug("Got msg of type [{}] from [{}] while waiting for remember entities", msg.getClass, sender()) + stash() + } + + def loadingEntityIdsFailed(): Unit = { + log.error( + "Failed to load initial entity ids from remember entities store within [{}], stopping shard for backoff and restart", + settings.tuningParameters.waitingForStateTimeout.pretty) + // parent ShardRegion supervisor will notice that it terminated and will start it again, after backoff + context.stop(self) + } + + def onEntitiesRemembered(ids: Set[EntityId]): Unit = { + log.debug("Shard initialized") + if (ids.nonEmpty) { + entityIds = ids + restartRememberedEntities(ids) + } else {} + context.parent ! 
ShardInitialized(shardId) + context.become(idle) + unstashAll() + } + + def restartRememberedEntities(ids: Set[EntityId]): Unit = { + log.debug( + "Shard starting [{}] remembered entities using strategy [{}]", + ids.size, + rememberedEntitiesRecoveryStrategy) + rememberedEntitiesRecoveryStrategy.recoverEntities(ids).foreach { scheduledRecovery => + import context.dispatcher + scheduledRecovery.filter(_.nonEmpty).map(RestartEntities).pipeTo(self) + } + } + + def restartEntities(ids: Set[EntityId]): Unit = { + log.debug("Restarting set of [{}] entities", ids.size) + context.actorOf(RememberEntityStarter.props(context.parent, ids, settings, sender())) + } + + // ===== shard up and running ===== + + def idle: Receive = { case Terminated(ref) => receiveTerminated(ref) case msg: CoordinatorMessage => receiveCoordinatorMessage(msg) - case msg: ShardCommand => receiveShardCommand(msg) + case msg: RememberEntityCommand => receiveRememberEntityCommand(msg) case msg: ShardRegion.StartEntity => receiveStartEntity(msg) case msg: ShardRegion.StartEntityAck => receiveStartEntityAck(msg) case msg: ShardRegionCommand => receiveShardRegionCommand(msg) case msg: ShardQuery => receiveShardQuery(msg) case PassivateIdleTick => passivateIdleEntities() case msg: LeaseLost => receiveLeaseLost(msg) - case msg if extractEntityId.isDefinedAt(msg) => deliverMessage(msg, sender()) + case msg: RememberEntityStoreCrashed => rememberEntityStoreCrashed(msg) + case msg if extractEntityId.isDefinedAt(msg) => deliverMessage(msg, sender(), OptionVal.None) + } + + def waitForAsyncWrite(entityId: EntityId, command: RememberEntitiesShardStore.Command)( + whenDone: EntityId => Unit): Unit = { + rememberEntitiesStore match { + case None => + whenDone(entityId) + + case Some(store) => + if (VerboseDebug) log.debug("Update of [{}] [{}] triggered", entityId, command) + store ! 
command + timers.startSingleTimer( + RememberEntityTimeoutKey, + RememberEntityTimeout(command), + // FIXME this timeout needs to match the timeout used in the ddata shard write since that tries 3 times + // and this could always fail before ddata store completes retrying writes + settings.tuningParameters.updatingStateTimeout) + + context.become { + case RememberEntitiesShardStore.UpdateDone(entityId) => + if (VerboseDebug) log.debug("Update of [{}] {} done", entityId, command) + timers.cancel(RememberEntityTimeoutKey) + whenDone(entityId) + context.become(idle) + unstashAll() + case RememberEntityTimeout(`command`) => + throw new RuntimeException( + s"Async write for entityId $entityId timed out after ${settings.tuningParameters.updatingStateTimeout.pretty}") + + // below cases should handle same messages as in idle + case _: Terminated => stash() + case _: CoordinatorMessage => stash() + case _: RememberEntityCommand => stash() + case _: ShardRegion.StartEntity => stash() + case _: ShardRegion.StartEntityAck => stash() + case _: ShardRegionCommand => stash() + case msg: ShardQuery => receiveShardQuery(msg) + case PassivateIdleTick => stash() + case msg: LeaseLost => receiveLeaseLost(msg) + case msg: RememberEntityStoreCrashed => rememberEntityStoreCrashed(msg) + case msg if extractEntityId.isDefinedAt(msg) => deliverMessage(msg, sender(), OptionVal.Some(entityId)) + case msg => + // shouldn't be any other message types, but just in case + log.debug( + "Stashing unexpected message [{}] while waiting for remember entities update of {}", + msg.getClass, + entityId) + stash() + } + } } def receiveLeaseLost(msg: LeaseLost): Unit = { // The shard region will re-create this when it receives a message for this shard - log.error("Shard type [{}] id [{}] lease lost. Reason: {}", typeName, shardId, msg.reason) + log.error( + "Shard type [{}] id [{}] lease lost, stopping shard and killing [{}] entities.{}", + typeName, + shardId, + entityIds.size, + msg.reason match { + case Some(reason) => s" Reason for losing lease: $reason" + case None => "" + }) // Stop entities ASAP rather than send termination message context.stop(self) } - private def receiveShardCommand(msg: ShardCommand): Unit = msg match { - // those are only used with remembering entities - case RestartEntity(id) => getOrCreateEntity(id) + private def receiveRememberEntityCommand(msg: RememberEntityCommand): Unit = msg match { + // these are only used with remembering entities upon start + case RestartEntity(id) => + // starting because it was remembered as started on shard startup (note that a message starting + // it up could already have arrived and in that case it will already be started) + getOrCreateEntity(id) case RestartEntities(ids) => restartEntities(ids) } @@ -312,32 +415,29 @@ private[akka] class Shard( log.debug("Got a request from [{}] to start entity [{}] in shard [{}]", requester, start.entityId, shardId) touchLastMessageTimestamp(start.entityId) - if (state.entities(start.entityId)) { + if (entityIds(start.entityId)) { getOrCreateEntity(start.entityId) requester ! ShardRegion.StartEntityAck(start.entityId, shardId) } else { - processChange(EntityStarted(start.entityId)) { evt => - getOrCreateEntity(start.entityId) - sendMsgBuffer(evt) - requester ! ShardRegion.StartEntityAck(start.entityId, shardId) + waitForAsyncWrite(start.entityId, RememberEntitiesShardStore.AddEntity(start.entityId)) { id => + getOrCreateEntity(id) + sendMsgBuffer(id) + requester ! 
ShardRegion.StartEntityAck(id, shardId) } } } private def receiveStartEntityAck(ack: ShardRegion.StartEntityAck): Unit = { - if (ack.shardId != shardId && state.entities.contains(ack.entityId)) { + if (ack.shardId != shardId && entityIds(ack.entityId)) { log.debug("Entity [{}] previously owned by shard [{}] started in shard [{}]", ack.entityId, shardId, ack.shardId) - processChange(EntityStopped(ack.entityId)) { _ => - state = state.copy(state.entities - ack.entityId) - messageBuffers.remove(ack.entityId) + + waitForAsyncWrite(ack.entityId, RememberEntitiesShardStore.RemoveEntity(ack.entityId)) { id => + entityIds = entityIds - id + messageBuffers.remove(id) } } } - private def restartEntities(ids: Set[EntityId]): Unit = { - context.actorOf(RememberEntityStarter.props(context.parent, ids, settings, sender())) - } - private def receiveShardRegionCommand(msg: ShardRegionCommand): Unit = msg match { case Passivate(stopMessage) => passivate(sender(), stopMessage) case _ => unhandled(msg) @@ -351,7 +451,7 @@ private[akka] class Shard( def receiveShardQuery(msg: ShardQuery): Unit = msg match { case GetCurrentShardState => sender() ! CurrentShardState(shardId, refById.keySet) - case GetShardStats => sender() ! ShardStats(shardId, state.entities.size) + case GetShardStats => sender() ! ShardStats(shardId, entityIds.size) } private def handOff(replyTo: ActorRef): Unit = handOffStopper match { @@ -385,6 +485,7 @@ private[akka] class Shard( @InternalStableApi def entityTerminated(ref: ActorRef): Unit = { + import settings.tuningParameters._ val id = idByRef(ref) idByRef -= ref refById -= id @@ -392,10 +493,20 @@ private[akka] class Shard( lastMessageTimestamp -= id } if (messageBuffers.getOrEmpty(id).nonEmpty) { + // Note; because we're not persisting the EntityStopped, we don't need + // to persist the EntityStarted either. log.debug("Starting entity [{}] again, there are buffered messages for it", id) - sendMsgBuffer(EntityStarted(id)) + sendMsgBuffer(id) } else { - processChange(EntityStopped(id))(passivateCompleted) + if (!passivating.contains(ref)) { + log.debug("Entity [{}] stopped without passivating, will restart after backoff", id) + // note that it's not removed from state here, will be started again via RestartEntity + import context.dispatcher + context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id)) + } else { + // FIXME optional wait for completion as optimization where stops are not critical + waitForAsyncWrite(id, RememberEntitiesShardStore.RemoveEntity(id))(passivateCompleted) + } } passivating = passivating - ref @@ -405,13 +516,15 @@ private[akka] class Shard( idByRef.get(entity) match { case Some(id) => if (!messageBuffers.contains(id)) { + if (VerboseDebug) + log.debug("Passivation started for {}", entity) passivating = passivating + entity messageBuffers.add(id) entity ! stopMessage } else { - log.debug("Passivation already in progress for {}. Not sending stopMessage back to entity.", entity) + log.debug("Passivation already in progress for {}. Not sending stopMessage back to entity", entity) } - case None => log.debug("Unknown entity {}. Not sending stopMessage back to entity.", entity) + case None => log.debug("Unknown entity {}. 
Not sending stopMessage back to entity", entity) } } @@ -432,40 +545,24 @@ } } - // EntityStopped handler - def passivateCompleted(event: EntityStopped): Unit = { - val hasBufferedMessages = messageBuffers.getOrEmpty(event.entityId).nonEmpty - state = state.copy(state.entities - event.entityId) + // After entity stopped + def passivateCompleted(entityId: EntityId): Unit = { + val hasBufferedMessages = messageBuffers.getOrEmpty(entityId).nonEmpty + entityIds = entityIds - entityId if (hasBufferedMessages) { - log.debug( - "Entity stopped after passivation [{}], but will be started again due to buffered messages.", - event.entityId) - processChange(EntityStarted(event.entityId))(sendMsgBuffer) + log.debug("Entity stopped after passivation [{}], but will be started again due to buffered messages.", entityId) + waitForAsyncWrite(entityId, RememberEntitiesShardStore.AddEntity(entityId))(sendMsgBuffer) } else { - log.debug("Entity stopped after passivation [{}]", event.entityId) - messageBuffers.remove(event.entityId) - } - - } - - // EntityStarted handler - def sendMsgBuffer(event: EntityStarted): Unit = { - //Get the buffered messages and remove the buffer - val messages = messageBuffers.getOrEmpty(event.entityId) - messageBuffers.remove(event.entityId) - - if (messages.nonEmpty) { - log.debug("Sending message buffer for entity [{}] ([{}] messages)", event.entityId, messages.size) - getOrCreateEntity(event.entityId) - //Now there is no deliveryBuffer we can try to redeliver - // and as the child exists, the message will be directly forwarded - messages.foreach { - case (msg, snd) => deliverMessage(msg, snd) - } + log.debug("Entity stopped after passivation [{}]", entityId) + dropBufferFor(entityId) } } - def deliverMessage(msg: Any, snd: ActorRef): Unit = { + /** + * @param entityIdWaitingForWrite an id for a remember entities write in progress; if non-empty, messages for that id + * will be buffered + */ + private def deliverMessage(msg: Any, snd: ActorRef, entityIdWaitingForWrite: OptionVal[EntityId]): Unit = { val (id, payload) = extractEntityId(msg) if (id == null || id == "") { log.warning("Id must not be empty, dropping message [{}]", msg.getClass.getName) @@ -473,553 +570,136 @@ } else { payload match { case start: ShardRegion.StartEntity => - // in case it was wrapped, used in Typed - receiveStartEntity(start) + // we can only start a new entity if we are not currently waiting for another write + if (entityIdWaitingForWrite.isEmpty) receiveStartEntity(start) + // write in progress, see waitForAsyncWrite for unstash + else stash() case _ => - if (messageBuffers.contains(id)) + if (messageBuffers.contains(id) || entityIdWaitingForWrite.contains(id)) { + // either: + // 1. entity is passivating, buffer until passivation complete (id in message buffers) + // 2. 
we are waiting for storing entity start or stop with remember entities to complete + // and want to buffer until write completes + if (VerboseDebug) { + if (entityIdWaitingForWrite.contains(id)) + log.debug("Buffering message [{}] to [{}] because of write in progress for it", msg.getClass, id) + else + log.debug("Buffering message [{}] to [{}] because passivation in progress for it", msg.getClass, id) + } appendToMessageBuffer(id, msg, snd) - else - deliverTo(id, msg, payload, snd) - } - } - } - - def appendToMessageBuffer(id: EntityId, msg: Any, snd: ActorRef): Unit = { - if (messageBuffers.totalSize >= bufferSize) { - log.debug("Buffer is full, dropping message for entity [{}]", id) - context.system.deadLetters ! msg - } else { - log.debug("Message for entity [{}] buffered", id) - messageBuffers.append(id, msg, snd) - } - } - - def deliverTo(id: EntityId, @unused msg: Any, payload: Msg, snd: ActorRef): Unit = { - touchLastMessageTimestamp(id) - getOrCreateEntity(id).tell(payload, snd) - } - - @InternalStableApi - def getOrCreateEntity(id: EntityId): ActorRef = { - val name = URLEncoder.encode(id, "utf-8") - context.child(name) match { - case Some(child) => child - case None => - log.debug("Starting entity [{}] in shard [{}]", id, shardId) - val a = context.watch(context.actorOf(entityProps(id), name)) - idByRef = idByRef.updated(a, id) - refById = refById.updated(id, a) - state = state.copy(state.entities + id) - touchLastMessageTimestamp(id) - a - } - } - - override def postStop(): Unit = { - passivateIdleTask.foreach(_.cancel()) - } -} - -private[akka] object RememberEntityStarter { - def props(region: ActorRef, ids: Set[ShardRegion.EntityId], settings: ClusterShardingSettings, requestor: ActorRef) = - Props(new RememberEntityStarter(region, ids, settings, requestor)) - - private case object Tick extends NoSerializationVerificationNeeded -} - -/** - * INTERNAL API: Actor responsible for starting entities when rememberEntities is enabled - */ -private[akka] class RememberEntityStarter( - region: ActorRef, - ids: Set[ShardRegion.EntityId], - settings: ClusterShardingSettings, - requestor: ActorRef) - extends Actor - with ActorLogging { - - import RememberEntityStarter.Tick - import context.dispatcher - - var waitingForAck = ids - - sendStart(ids) - - val tickTask = { - val resendInterval = settings.tuningParameters.retryInterval - context.system.scheduler.scheduleWithFixedDelay(resendInterval, resendInterval, self, Tick) - } - - def sendStart(ids: Set[ShardRegion.EntityId]): Unit = { - // these go through the region rather the directly to the shard - // so that shard mapping changes are picked up - ids.foreach(id => region ! ShardRegion.StartEntity(id)) - } - - override def receive: Receive = { - case ack: ShardRegion.StartEntityAck => - waitingForAck -= ack.entityId - // inform whoever requested the start that it happened - requestor ! 
ack - if (waitingForAck.isEmpty) context.stop(self) - - case Tick => - sendStart(waitingForAck) - - } - - override def postStop(): Unit = { - tickTask.cancel() - } -} - -/** - * INTERNAL API: Common things for PersistentShard and DDataShard - */ -private[akka] trait RememberingShard { - selfType: Shard => - - import Shard._ - import ShardRegion.EntityId - import ShardRegion.Msg - import akka.pattern.pipe - - protected val settings: ClusterShardingSettings - - protected val rememberedEntitiesRecoveryStrategy: EntityRecoveryStrategy = { - import settings.tuningParameters._ - entityRecoveryStrategy match { - case "all" => EntityRecoveryStrategy.allStrategy() - case "constant" => - EntityRecoveryStrategy.constantStrategy( - context.system, - entityRecoveryConstantRateStrategyFrequency, - entityRecoveryConstantRateStrategyNumberOfEntities) - } - } - - protected def restartRememberedEntities(): Unit = { - rememberedEntitiesRecoveryStrategy.recoverEntities(state.entities).foreach { scheduledRecovery => - import context.dispatcher - scheduledRecovery.filter(_.nonEmpty).map(RestartEntities).pipeTo(self) - } - } - - override def entityTerminated(ref: ActorRef): Unit = { - import settings.tuningParameters._ - val id = idByRef(ref) - idByRef -= ref - refById -= id - if (passivateIdleTask.isDefined) { - lastMessageTimestamp -= id - } - if (messageBuffers.getOrEmpty(id).nonEmpty) { - //Note; because we're not persisting the EntityStopped, we don't need - // to persist the EntityStarted either. - log.debug("Starting entity [{}] again, there are buffered messages for it", id) - sendMsgBuffer(EntityStarted(id)) - } else { - if (!passivating.contains(ref)) { - log.debug("Entity [{}] stopped without passivating, will restart after backoff", id) - // note that it's not removed from state here, will be started again via RestartEntity - import context.dispatcher - context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id)) - } else processChange(EntityStopped(id))(passivateCompleted) - } - - passivating = passivating - ref - } - - override def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = { - val name = URLEncoder.encode(id, "utf-8") - context.child(name) match { - case Some(actor) => - touchLastMessageTimestamp(id) - actor.tell(payload, snd) - case None => - if (state.entities.contains(id)) { - // this may happen when entity is stopped without passivation - require(!messageBuffers.contains(id), s"Message buffers contains id [$id].") - getOrCreateEntity(id).tell(payload, snd) - } else { - //Note; we only do this if remembering, otherwise the buffer is an overhead - messageBuffers.append(id, msg, snd) - processChange(EntityStarted(id))(sendMsgBuffer) - } - } - } -} - -/** - * INTERNAL API - * - * This actor creates children entity actors on demand that it is told to be - * responsible for. It is used when `rememberEntities` is enabled and - * `state-store-mode=persistence`. 
- * - * @see [[ClusterSharding$ ClusterSharding extension]] - */ -private[akka] class PersistentShard( - typeName: String, - shardId: ShardRegion.ShardId, - entityProps: String => Props, - override val settings: ClusterShardingSettings, - extractEntityId: ShardRegion.ExtractEntityId, - extractShardId: ShardRegion.ExtractShardId, - handOffStopMessage: Any) - extends Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage) - with RememberingShard - with PersistentActor - with ActorLogging { - - import Shard._ - import settings.tuningParameters._ - - override def preStart(): Unit = { - // override to not acquire the lease on start up, acquire after persistent recovery - } - - override def persistenceId = s"/sharding/${typeName}Shard/$shardId" - - override def journalPluginId: String = settings.journalPluginId - - override def snapshotPluginId: String = settings.snapshotPluginId - - override def receive = receiveCommand - - override def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = { - saveSnapshotWhenNeeded() - persist(event)(handler) - } - - def saveSnapshotWhenNeeded(): Unit = { - if (lastSequenceNr % snapshotAfter == 0 && lastSequenceNr != 0) { - log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr) - saveSnapshot(state) - } - } - - override def receiveRecover: Receive = { - case EntityStarted(id) => state = state.copy(state.entities + id) - case EntityStopped(id) => state = state.copy(state.entities - id) - case SnapshotOffer(_, snapshot: State) => state = snapshot - case RecoveryCompleted => - acquireLeaseIfNeeded() // onLeaseAcquired called when completed - log.debug("PersistentShard recovery completed shard [{}] with [{}] entities", shardId, state.entities.size) - } - - override def onLeaseAcquired(): Unit = { - log.debug("Shard initialized") - context.parent ! ShardInitialized(shardId) - context.become(receiveCommand) - restartRememberedEntities() - unstashAll() - } - - override def receiveCommand: Receive = - ({ - case e: SaveSnapshotSuccess => - log.debug("PersistentShard snapshot saved successfully") - internalDeleteMessagesBeforeSnapshot(e, keepNrOfBatches, snapshotAfter) - - case SaveSnapshotFailure(_, reason) => - log.warning("PersistentShard snapshot failure: [{}]", reason.getMessage) - - case DeleteMessagesSuccess(toSequenceNr) => - val deleteTo = toSequenceNr - 1 - val deleteFrom = math.max(0, deleteTo - (keepNrOfBatches * snapshotAfter)) - log.debug( - "PersistentShard messages to [{}] deleted successfully. Deleting snapshots from [{}] to [{}]", - toSequenceNr, - deleteFrom, - deleteTo) - deleteSnapshots(SnapshotSelectionCriteria(minSequenceNr = deleteFrom, maxSequenceNr = deleteTo)) - - case DeleteMessagesFailure(reason, toSequenceNr) => - log.warning("PersistentShard messages to [{}] deletion failure: [{}]", toSequenceNr, reason.getMessage) - - case DeleteSnapshotsSuccess(m) => - log.debug("PersistentShard snapshots matching [{}] deleted successfully", m) - - case DeleteSnapshotsFailure(m, reason) => - log.warning("PersistentShard snapshots matching [{}] deletion failure: [{}]", m, reason.getMessage) - - }: Receive).orElse(super.receiveCommand) - -} - -/** - * INTERNAL API - * - * This actor creates children entity actors on demand that it is told to be - * responsible for. It is used when `rememberEntities` is enabled and - * `state-store-mode=ddata`. 
- * - * @see [[ClusterSharding$ ClusterSharding extension]] - */ -private[akka] class DDataShard( - typeName: String, - shardId: ShardRegion.ShardId, - entityProps: String => Props, - override val settings: ClusterShardingSettings, - extractEntityId: ShardRegion.ExtractEntityId, - extractShardId: ShardRegion.ExtractShardId, - handOffStopMessage: Any, - replicator: ActorRef, - majorityMinCap: Int) - extends Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage) - with RememberingShard - with Stash - with ActorLogging { - - import Shard._ - import ShardRegion.EntityId - import settings.tuningParameters._ - - private val readMajority = ReadMajority(settings.tuningParameters.waitingForStateTimeout, majorityMinCap) - private val writeMajority = WriteMajority(settings.tuningParameters.updatingStateTimeout, majorityMinCap) - private val maxUpdateAttempts = 3 - - implicit private val node = Cluster(context.system) - implicit private val selfUniqueAddress = SelfUniqueAddress(node.selfUniqueAddress) - - // The default maximum-frame-size is 256 KiB with Artery. - // When using entity identifiers with 36 character strings (e.g. UUID.randomUUID). - // By splitting the elements over 5 keys we can support 10000 entities per shard. - // The Gossip message size of 5 ORSet with 2000 ids is around 200 KiB. - // This is by intention not configurable because it's important to have the same - // configuration on each node. - private val numberOfKeys = 5 - private val stateKeys: Array[ORSetKey[EntityId]] = - Array.tabulate(numberOfKeys)(i => ORSetKey[EntityId](s"shard-$typeName-$shardId-$i")) - - private var waiting = true - - private def key(entityId: EntityId): ORSetKey[EntityId] = { - val i = math.abs(entityId.hashCode % numberOfKeys) - stateKeys(i) - } - - override def onLeaseAcquired(): Unit = { - log.info("Lease Acquired. Getting state from DData") - getState() - context.become(waitingForState(Set.empty)) - } - - private def getState(): Unit = { - (0 until numberOfKeys).foreach { i => - replicator ! Get(stateKeys(i), readMajority, Some(i)) - } - } - - override protected[akka] def aroundReceive(rcv: Receive, msg: Any): Unit = { - super.aroundReceive(rcv, msg) - if (!waiting) - unstash() // unstash one message - } - - override def receive = waitingForState(Set.empty) - - // This state will stash all commands - private def waitingForState(gotKeys: Set[Int]): Receive = { - def receiveOne(i: Int): Unit = { - val newGotKeys = gotKeys + i - if (newGotKeys.size == numberOfKeys) { - recoveryCompleted() - } else - context.become(waitingForState(newGotKeys)) - } - - { - case g @ GetSuccess(_, Some(i: Int)) => - val key = stateKeys(i) - state = state.copy(entities = state.entities.union(g.get(key).elements)) - receiveOne(i) - - case GetFailure(_, _) => - log.error( - "The DDataShard was unable to get an initial state within 'waiting-for-state-timeout': {} millis", - waitingForStateTimeout.toMillis) - // parent ShardRegion supervisor will notice that it terminated and will start it again, after backoff - context.stop(self) - - case NotFound(_, Some(i: Int)) => - receiveOne(i) - - case _ => - log.debug("Stashing while waiting for DDataShard initial state") - stash() - } - } - - private def recoveryCompleted(): Unit = { - log.debug("DDataShard recovery completed shard [{}] with [{}] entities", shardId, state.entities.size) - waiting = false - context.parent ! 
ShardInitialized(shardId) - context.become(receiveCommand) - restartRememberedEntities() - } - - override def processChange[E <: StateChange](event: E)(handler: E => Unit): Unit = { - waiting = true - context.become(waitingForUpdate(event, handler), discardOld = false) - sendUpdate(event, retryCount = 1) - } - - private def sendUpdate(evt: StateChange, retryCount: Int): Unit = { - replicator ! Update(key(evt.entityId), ORSet.empty[EntityId], writeMajority, Some((evt, retryCount))) { existing => - evt match { - case EntityStarted(id) => existing :+ id - case EntityStopped(id) => existing.remove(id) - } - } - } - - // this state will stash all messages until it receives UpdateSuccess - private def waitingForUpdate[E <: StateChange](evt: E, afterUpdateCallback: E => Unit): Receive = { - case UpdateSuccess(_, Some((`evt`, _))) => - log.debug("The DDataShard state was successfully updated with {}", evt) - waiting = false - context.unbecome() - afterUpdateCallback(evt) - - case UpdateTimeout(_, Some((`evt`, retryCount: Int))) => - if (retryCount == maxUpdateAttempts) { - // parent ShardRegion supervisor will notice that it terminated and will start it again, after backoff - log.error( - "The DDataShard was unable to update state after {} attempts, within 'updating-state-timeout'={} millis, event={}. " + - "Shard will be restarted after backoff.", - maxUpdateAttempts, - updatingStateTimeout.toMillis, - evt) - context.stop(self) - } else { - log.warning( - "The DDataShard was unable to update state, attempt {} of {}, within 'updating-state-timeout'={} millis, event={}", - retryCount, - maxUpdateAttempts, - updatingStateTimeout.toMillis, - evt) - sendUpdate(evt, retryCount + 1) - } - - case StoreFailure(_, Some((`evt`, _))) => - log.error( - "The DDataShard was unable to update state with event {} due to StoreFailure. " + - "Shard will be restarted after backoff.", - evt) - context.stop(self) - - case ModifyFailure(_, error, cause, Some((`evt`, _))) => - log.error( - cause, - "The DDataShard was unable to update state with event {} due to ModifyFailure. " + - "Shard will be restarted. {}", - evt, - error) - throw cause - - // below cases should handle same messages as in Shard.receiveCommand - case _: Terminated => stash() - case _: CoordinatorMessage => stash() - case _: ShardCommand => stash() - case _: ShardRegion.StartEntity => stash() - case _: ShardRegion.StartEntityAck => stash() - case _: ShardRegionCommand => stash() - case msg: ShardQuery => receiveShardQuery(msg) - case PassivateIdleTick => stash() - case msg: LeaseLost => receiveLeaseLost(msg) - case msg if extractEntityId.isDefinedAt(msg) => deliverOrBufferMessage(msg, evt) - case msg => - // shouldn't be any other message types, but just in case - log.debug("Stashing unexpected message [{}] while waiting for DDataShard update of {}", msg.getClass, evt) - stash() - } - - /** - * If the message is for the same entity as we are waiting for the update it will be added to - * its messageBuffer, which will be sent after the update has completed. - * - * If the message is for another entity that is already started (and not in progress of passivating) - * it will be delivered immediately. - * - * Otherwise it will be stashed, and processed after the update has been completed. 
- */ - private def deliverOrBufferMessage(msg: Any, waitingForUpdateEvent: StateChange): Unit = { - val (id, payload) = extractEntityId(msg) - if (id == null || id == "") { - log.warning("Id must not be empty, dropping message [{}]", msg.getClass.getName) - context.system.deadLetters ! msg - } else { - payload match { - case _: ShardRegion.StartEntity => - // in case it was wrapped, used in Typed - stash() - case _ => - if (id == waitingForUpdateEvent.entityId) { - appendToMessageBuffer(id, msg, sender()) } else { - val name = URLEncoder.encode(id, "utf-8") - // messageBuffers.contains(id) when passivation is in progress - if (!messageBuffers.contains(id) && context.child(name).nonEmpty) { - deliverTo(id, msg, payload, sender()) - } else { - log.debug("Stashing to [{}] while waiting for DDataShard update of {}", id, waitingForUpdateEvent) - stash() + // With remember entities enabled we may be in the process of saving that we are starting up the entity + // and in that case we need to buffer messages until that completes + refById.get(id) match { + case Some(actor) => + // not using remember entities or write is in progress for other entity and this entity is running already + // go ahead and deliver + if (VerboseDebug) log.debug("Delivering message of type [{}] to [{}]", payload.getClass, id) + touchLastMessageTimestamp(id) + actor.tell(payload, snd) + case None => + if (entityIds.contains(id)) { + // No entity actor running but id is in set of entities, this can happen in two scenarios: + // 1. we are starting up and the entity id is remembered but not yet started + // 2. the entity is stopped but not passivated, and should be restarted + + // FIXME won't this potentially lead to not remembered for case 2? + if (VerboseDebug) + log.debug( + "Delivering message of type [{}] to [{}] (starting because known but not running)", + payload.getClass, + id) + val actor = getOrCreateEntity(id) + touchLastMessageTimestamp(id) + actor.tell(payload, snd) + + } else { + if (entityIdWaitingForWrite.isEmpty) { + // No actor and id is unknown, start actor and deliver message when started + // Note; we only do this if remembering, otherwise the buffer is an overhead + if (VerboseDebug) + log.debug("Buffering message [{}] to [{}] and starting actor", payload.getClass, id) + appendToMessageBuffer(id, msg, snd) + waitForAsyncWrite(id, RememberEntitiesShardStore.AddEntity(id))(sendMsgBuffer) + } else { + // we'd need to start the entity but a start/stop write is already in progress + // see waitForAsyncWrite for unstash + if (VerboseDebug) + log.debug( + "Stashing message [{}] to [{}] because of write in progress for [{}]", + payload.getClass, + id, + entityIdWaitingForWrite.get) + stash() + } + } } } } } } -} + @InternalStableApi + def getOrCreateEntity(id: EntityId): ActorRef = { + refById.get(id) match { + case Some(child) => child + case None => + val name = URLEncoder.encode(id, "utf-8") + val a = context.watch(context.actorOf(entityProps(id), name)) + log.debug("Started entity [{}] with entity id [{}] in shard [{}]", a, id, shardId) + idByRef = idByRef.updated(a, id) + refById = refById.updated(id, a) + entityIds = entityIds + id + touchLastMessageTimestamp(id) + a + } + } -object EntityRecoveryStrategy { - def allStrategy(): EntityRecoveryStrategy = new AllAtOnceEntityRecoveryStrategy() + // ===== buffering while busy saving a start or stop when remembering entities ===== + def appendToMessageBuffer(id: EntityId, msg: Any, snd: ActorRef): Unit = { + if (messageBuffers.totalSize >= bufferSize) { + if 
(log.isDebugEnabled) + log.debug("Buffer is full, dropping message of type [{}] for entity [{}]", msg.getClass.getName, id) + context.system.deadLetters ! msg + } else { + if (log.isDebugEnabled) + log.debug("Message of type [{}] for entity [{}] buffered", msg.getClass.getName, id) + messageBuffers.append(id, msg, snd) + } + } - def constantStrategy( - actorSystem: ActorSystem, - frequency: FiniteDuration, - numberOfEntities: Int): EntityRecoveryStrategy = - new ConstantRateEntityRecoveryStrategy(actorSystem, frequency, numberOfEntities) -} + // After entity started + def sendMsgBuffer(entityId: EntityId): Unit = { + //Get the buffered messages and remove the buffer + val messages = messageBuffers.getOrEmpty(entityId) + messageBuffers.remove(entityId) -trait EntityRecoveryStrategy { - - import scala.concurrent.Future - - import ShardRegion.EntityId - - def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] -} - -final class AllAtOnceEntityRecoveryStrategy extends EntityRecoveryStrategy { - - import ShardRegion.EntityId - - override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = - if (entities.isEmpty) Set.empty else Set(Future.successful(entities)) -} - -final class ConstantRateEntityRecoveryStrategy( - actorSystem: ActorSystem, - frequency: FiniteDuration, - numberOfEntities: Int) - extends EntityRecoveryStrategy { - - import ShardRegion.EntityId - import actorSystem.dispatcher - import akka.pattern.after - - override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = - entities - .grouped(numberOfEntities) - .foldLeft((frequency, Set[Future[Set[EntityId]]]())) { - case ((interval, scheduledEntityIds), entityIds) => - (interval + frequency, scheduledEntityIds + scheduleEntities(interval, entityIds)) + if (messages.nonEmpty) { + getOrCreateEntity(entityId) + log.debug("Sending message buffer for entity [{}] ([{}] messages)", entityId, messages.size) + //Now there is no deliveryBuffer we can try to redeliver + // and as the child exists, the message will be directly forwarded + messages.foreach { + case (msg, snd) => deliverMessage(msg, snd, OptionVal.None) } - ._2 + touchLastMessageTimestamp(entityId) + } + } - private def scheduleEntities(interval: FiniteDuration, entityIds: Set[EntityId]): Future[Set[EntityId]] = - after(interval, actorSystem.scheduler)(Future.successful[Set[EntityId]](entityIds)) + def dropBufferFor(entityId: EntityId): Unit = { + if (log.isDebugEnabled) { + val messages = messageBuffers.getOrEmpty(entityId) + if (messages.nonEmpty) + log.debug("Dropping [{}] buffered messages", entityId, messages.size) + } + messageBuffers.remove(entityId) + } + + private def rememberEntityStoreCrashed(msg: RememberEntityStoreCrashed): Unit = { + throw new RuntimeException(s"Remember entities store [${msg.store}] crashed") + } + + override def postStop(): Unit = { + passivateIdleTask.foreach(_.cancel()) + } } diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala index 8756aa27af..4b855cc71b 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala @@ -9,21 +9,33 @@ import java.net.URLEncoder import akka.Done import akka.actor._ import akka.annotation.InternalApi +import akka.cluster.Cluster import akka.cluster.ClusterEvent._ +import akka.cluster.ClusterSettings import 
akka.cluster.ClusterSettings.DataCenter +import akka.cluster.Member +import akka.cluster.MemberStatus import akka.cluster.sharding.Shard.ShardStats -import akka.cluster.{ Cluster, ClusterSettings, Member, MemberStatus } +import akka.cluster.sharding.internal.CustomStateStoreModeProvider +import akka.cluster.sharding.internal.DDataRememberEntitiesShardStoreProvider +import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStoreProvider +import akka.cluster.sharding.internal.RememberEntitiesShardStoreProvider import akka.event.Logging -import akka.pattern.{ ask, pipe } -import akka.util.{ MessageBufferMap, PrettyDuration, Timeout } +import akka.pattern.ask +import akka.pattern.pipe +import akka.util.MessageBufferMap +import akka.util.PrettyDuration +import akka.util.Timeout import scala.annotation.tailrec import scala.collection.immutable +import scala.concurrent.Future +import scala.concurrent.Promise import scala.concurrent.duration._ -import scala.concurrent.{ Future, Promise } import scala.reflect.ClassTag import scala.runtime.AbstractFunction1 -import scala.util.{ Failure, Success } +import scala.util.Failure +import scala.util.Success /** * @see [[ClusterSharding$ ClusterSharding extension]] @@ -499,6 +511,7 @@ object ShardRegion { stopMessage: Any, handoffTimeout: FiniteDuration): Props = Props(new HandOffStopper(shard, replyTo, entities, stopMessage, handoffTimeout)).withDeploy(Deploy.local) + } /** @@ -557,6 +570,19 @@ private[akka] class ShardRegion( val initRegistrationDelay: FiniteDuration = 100.millis.max(retryInterval / 2 / 2 / 2) var nextRegistrationDelay: FiniteDuration = initRegistrationDelay + val shardRememberEntitiesStoreProvider: Option[RememberEntitiesShardStoreProvider] = + if (!settings.rememberEntities) None + else + // this construction will move upwards when we get to refactoring the coordinator + Some(settings.stateStoreMode match { + case ClusterShardingSettings.StateStoreModeDData => + new DDataRememberEntitiesShardStoreProvider(typeName, settings, replicator, majorityMinCap) + case ClusterShardingSettings.StateStoreModePersistence => + new EventSourcedRememberEntitiesStoreProvider(typeName, settings) + case ClusterShardingSettings.StateStoreModeCustom => + new CustomStateStoreModeProvider(typeName, context.system, settings) + }) + // for CoordinatedShutdown val gracefulShutdownProgress = Promise[Done]() CoordinatedShutdown(context.system) @@ -1101,6 +1127,7 @@ private[akka] class ShardRegion( log.debug(ShardingLogMarker.shardStarted(typeName, id), "{}: Starting shard [{}] in region", typeName, id) val name = URLEncoder.encode(id, "utf-8") + val shard = context.watch( context.actorOf( Shard @@ -1112,8 +1139,7 @@ private[akka] class ShardRegion( extractEntityId, extractShardId, handOffStopMessage, - replicator, - majorityMinCap) + shardRememberEntitiesStoreProvider) .withDispatcher(context.props.dispatcher), name)) shardsByRef = shardsByRef.updated(shard, id) diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/CustomStateStoreModeProvider.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/CustomStateStoreModeProvider.scala new file mode 100644 index 0000000000..5a2c71fef9 --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/CustomStateStoreModeProvider.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
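The store provider above is only created when remember entities is enabled, and which implementation is used follows the configured state store mode. A minimal sketch of the corresponding settings, assuming the standard "akka.cluster.sharding" config keys (system and example names are invented; the "custom" mode introduced here is test-only):

  import akka.actor.ActorSystem
  import akka.cluster.sharding.ClusterShardingSettings
  import com.typesafe.config.ConfigFactory

  object RememberEntitiesSettingsExample extends App {
    val config = ConfigFactory.parseString("""
      akka.actor.provider = cluster
      akka.remote.artery.canonical.port = 0
      akka.cluster.sharding.remember-entities = on
      # ddata or persistence; "custom" is only meant for tests
      akka.cluster.sharding.state-store-mode = ddata
      """).withFallback(ConfigFactory.load())

    val system = ActorSystem("Example", config)
    val settings = ClusterShardingSettings(system)
    println(settings.rememberEntities) // true
    println(settings.stateStoreMode)   // ddata
    system.terminate()
  }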
+ */ + +package akka.cluster.sharding.internal +import akka.actor.ActorSystem +import akka.actor.ExtendedActorSystem +import akka.actor.Props +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.ShardRegion.ShardId +import akka.event.Logging + +/** + * INTERNAL API + * + * Only intended for testing, not an extension point. + */ +private[akka] final class CustomStateStoreModeProvider( + typeName: String, + system: ActorSystem, + settings: ClusterShardingSettings) + extends RememberEntitiesShardStoreProvider { + + private val log = Logging(system, getClass) + log.warning("Using custom remember entities store for [{}], not intended for production use.", typeName) + val customStore = if (system.settings.config.hasPath("akka.cluster.sharding.custom-store")) { + val customClassName = system.settings.config.getString("akka.cluster.sharding.custom-store") + + val store = system + .asInstanceOf[ExtendedActorSystem] + .dynamicAccess + .createInstanceFor[RememberEntitiesShardStoreProvider]( + customClassName, + Vector((classOf[ClusterShardingSettings], settings), (classOf[String], typeName))) + log.debug("Will use custom remember entities store provider [{}]", store) + store.get + + } else { + log.error("Missing custom store class configuration for CustomStateStoreModeProvider") + throw new RuntimeException("Missing custom store class configuration") + } + + override def shardStoreProps(shardId: ShardId): Props = customStore.shardStoreProps(shardId) + +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/DDataRememberEntitiesStore.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/DDataRememberEntitiesStore.scala new file mode 100644 index 0000000000..1f385e591d --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/DDataRememberEntitiesStore.scala @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
+ */ + +package akka.cluster.sharding.internal + +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.ActorRef +import akka.actor.Props +import akka.annotation.InternalApi +import akka.cluster.Cluster +import akka.cluster.ddata.ORSet +import akka.cluster.ddata.ORSetKey +import akka.cluster.ddata.Replicator.Get +import akka.cluster.ddata.Replicator.GetDataDeleted +import akka.cluster.ddata.Replicator.GetFailure +import akka.cluster.ddata.Replicator.GetSuccess +import akka.cluster.ddata.Replicator.ModifyFailure +import akka.cluster.ddata.Replicator.NotFound +import akka.cluster.ddata.Replicator.ReadMajority +import akka.cluster.ddata.Replicator.StoreFailure +import akka.cluster.ddata.Replicator.Update +import akka.cluster.ddata.Replicator.UpdateDataDeleted +import akka.cluster.ddata.Replicator.UpdateSuccess +import akka.cluster.ddata.Replicator.UpdateTimeout +import akka.cluster.ddata.Replicator.WriteMajority +import akka.cluster.ddata.SelfUniqueAddress +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.ShardRegion.ShardId +import akka.util.PrettyDuration._ + +import scala.concurrent.ExecutionContext + +/** + * INTERNAL API + */ +@InternalApi +private[akka] final class DDataRememberEntitiesShardStoreProvider( + typeName: String, + settings: ClusterShardingSettings, + replicator: ActorRef, + majorityMinCap: Int) + extends RememberEntitiesShardStoreProvider { + + override def shardStoreProps(shardId: ShardId): Props = + DDataRememberEntitiesStore.props(shardId, typeName, settings, replicator, majorityMinCap) + +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object DDataRememberEntitiesStore { + + def props( + shardId: ShardId, + typeName: String, + settings: ClusterShardingSettings, + replicator: ActorRef, + majorityMinCap: Int): Props = + Props(new DDataRememberEntitiesStore(shardId, typeName, settings, replicator, majorityMinCap)) + + // The default maximum-frame-size is 256 KiB with Artery. + // When using entity identifiers with 36 character strings (e.g. UUID.randomUUID). + // By splitting the elements over 5 keys we can support 10000 entities per shard. + // The Gossip message size of 5 ORSet with 2000 ids is around 200 KiB. + // This is by intention not configurable because it's important to have the same + // configuration on each node. 
+ private val numberOfKeys = 5 + + private def stateKeys(typeName: String, shardId: ShardId): Array[ORSetKey[EntityId]] = + Array.tabulate(numberOfKeys)(i => ORSetKey[EntityId](s"shard-$typeName-$shardId-$i")) + +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] final class DDataRememberEntitiesStore( + shardId: ShardId, + typeName: String, + settings: ClusterShardingSettings, + replicator: ActorRef, + majorityMinCap: Int) + extends Actor + with ActorLogging { + + import DDataRememberEntitiesStore._ + + implicit val ec: ExecutionContext = context.dispatcher + implicit val node: Cluster = Cluster(context.system) + implicit val selfUniqueAddress: SelfUniqueAddress = SelfUniqueAddress(node.selfUniqueAddress) + + private val readMajority = ReadMajority(settings.tuningParameters.waitingForStateTimeout, majorityMinCap) + // Note that the timeout is actually updatingStateTimeout x 3 since we do 3 retries + private val writeMajority = WriteMajority(settings.tuningParameters.updatingStateTimeout, majorityMinCap) + private val maxUpdateAttempts = 3 + private val keys = stateKeys(typeName, shardId) + + if (log.isDebugEnabled) { + log.debug( + "Starting up DDataRememberEntitiesStore, write timeout: [{}], read timeout: [{}], majority min cap: [{}]", + settings.tuningParameters.waitingForStateTimeout.pretty, + settings.tuningParameters.updatingStateTimeout.pretty, + majorityMinCap) + } + // FIXME potential optimization: start loading entity ids immediately on start instead of waiting for request + // (then throw away after request has been seen) + + private def key(entityId: EntityId): ORSetKey[EntityId] = { + val i = math.abs(entityId.hashCode % numberOfKeys) + keys(i) + } + + override def receive: Receive = idle + + def idle: Receive = { + case update: RememberEntitiesShardStore.UpdateEntityCommand => onUpdate(update) + case RememberEntitiesShardStore.GetEntities => onGetEntities() + } + + def waitingForAllEntityIds(requestor: ActorRef, gotKeys: Set[Int], ids: Set[EntityId]): Receive = { + def receiveOne(i: Int, idsForKey: Set[EntityId]): Unit = { + val newGotKeys = gotKeys + i + val newIds = ids.union(idsForKey) + if (newGotKeys.size == numberOfKeys) { + requestor ! RememberEntitiesShardStore.RememberedEntities(newIds) + context.become(idle) + } else { + context.become(waitingForAllEntityIds(requestor, newGotKeys, newIds)) + } + } + + { + case g @ GetSuccess(_, Some(i: Int)) => + val key = keys(i) + val ids = g.get(key).elements + receiveOne(i, ids) + case NotFound(_, Some(i: Int)) => + receiveOne(i, Set.empty) + case GetFailure(key, _) => + log.error( + "Unable to get an initial state within 'waiting-for-state-timeout': [{}] using [{}] (key [{}])", + readMajority.timeout.pretty, + readMajority, + key) + context.stop(self) + case GetDataDeleted(_, _) => + log.error("Unable to get an initial state because it was deleted") + context.stop(self) + } + } + + private def onUpdate(update: RememberEntitiesShardStore.UpdateEntityCommand): Unit = { + val keyForEntity = key(update.entityId) + val sendUpdate = () => + replicator ! 
Update(keyForEntity, ORSet.empty[EntityId], writeMajority, Some(update)) { existing => + update match { + case RememberEntitiesShardStore.AddEntity(id) => existing :+ id + case RememberEntitiesShardStore.RemoveEntity(id) => existing.remove(id) + } + } + + sendUpdate() + context.become(waitingForUpdate(sender(), update, keyForEntity, maxUpdateAttempts, sendUpdate)) + } + + private def waitingForUpdate( + requestor: ActorRef, + update: RememberEntitiesShardStore.UpdateEntityCommand, + keyForEntity: ORSetKey[EntityId], + retriesLeft: Int, + retry: () => Unit): Receive = { + case UpdateSuccess(`keyForEntity`, Some(`update`)) => + log.debug("The DDataShard state was successfully updated for [{}]", update.entityId) + requestor ! RememberEntitiesShardStore.UpdateDone(update.entityId) + context.become(idle) + + case UpdateTimeout(`keyForEntity`, Some(`update`)) => + if (retriesLeft > 0) { + log.debug("Retrying update because of write timeout, tries left [{}]", retriesLeft) + retry() + } else { + log.error( + "Unable to update state, within 'updating-state-timeout'= [{}], gave up after [{}] retries", + writeMajority.timeout.pretty, + maxUpdateAttempts) + // will trigger shard restart + context.stop(self) + } + case StoreFailure(`keyForEntity`, Some(`update`)) => + log.error("Unable to update state, due to store failure") + // will trigger shard restart + context.stop(self) + case ModifyFailure(`keyForEntity`, error, cause, Some(`update`)) => + log.error(cause, "Unable to update state, due to modify failure: {}", error) + // will trigger shard restart + context.stop(self) + case UpdateDataDeleted(`keyForEntity`, Some(`update`)) => + log.error("Unable to update state, due to delete") + // will trigger shard restart + context.stop(self) + } + + private def onGetEntities(): Unit = { + (0 until numberOfKeys).toSet[Int].foreach { i => + val key = keys(i) + replicator ! Get(key, readMajority, Some(i)) + } + context.become(waitingForAllEntityIds(sender(), Set.empty, Set.empty)) + } + +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala new file mode 100644 index 0000000000..b00a756bbc --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
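One consequence of the retry handling above, worth keeping in mind when tuning: a single remembered-entity write can keep the shard waiting for up to updating-state-timeout per attempt, and the store retries up to three times before stopping itself (which in turn restarts the shard). A rough sketch of that worst-case budget, assuming the settings are simply loaded from the classpath:

  import akka.actor.ActorSystem
  import akka.cluster.sharding.ClusterShardingSettings

  object WriteTimeoutBudgetExample extends App {
    val system = ActorSystem("Example")
    val tuning = ClusterShardingSettings(system).tuningParameters

    val maxUpdateAttempts = 3 // mirrors the constant in DDataRememberEntitiesStore
    val worstCase = tuning.updatingStateTimeout * maxUpdateAttempts
    println(s"worst case wait for one remembered-entity write: $worstCase")
    system.terminate()
  }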
+ */ + +package akka.cluster.sharding.internal + +import akka.actor.ActorSystem +import akka.annotation.InternalApi +import akka.cluster.sharding.ShardRegion +import akka.util.PrettyDuration + +import scala.collection.immutable.Set +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object EntityRecoveryStrategy { + def allStrategy(): EntityRecoveryStrategy = new AllAtOnceEntityRecoveryStrategy() + + def constantStrategy( + actorSystem: ActorSystem, + frequency: FiniteDuration, + numberOfEntities: Int): EntityRecoveryStrategy = + new ConstantRateEntityRecoveryStrategy(actorSystem, frequency, numberOfEntities) +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] trait EntityRecoveryStrategy { + + import ShardRegion.EntityId + + import scala.concurrent.Future + + def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] final class AllAtOnceEntityRecoveryStrategy extends EntityRecoveryStrategy { + + import ShardRegion.EntityId + + override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = + if (entities.isEmpty) Set.empty else Set(Future.successful(entities)) + + override def toString: EntityId = "AllAtOnceEntityRecoveryStrategy" +} + +final class ConstantRateEntityRecoveryStrategy( + actorSystem: ActorSystem, + frequency: FiniteDuration, + numberOfEntities: Int) + extends EntityRecoveryStrategy { + + import ShardRegion.EntityId + import actorSystem.dispatcher + import akka.pattern.after + + override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = + entities + .grouped(numberOfEntities) + .foldLeft((frequency, Set[Future[Set[EntityId]]]())) { + case ((interval, scheduledEntityIds), entityIds) => + (interval + frequency, scheduledEntityIds + scheduleEntities(interval, entityIds)) + } + ._2 + + private def scheduleEntities(interval: FiniteDuration, entityIds: Set[EntityId]): Future[Set[EntityId]] = + after(interval, actorSystem.scheduler)(Future.successful[Set[EntityId]](entityIds)) + + override def toString: EntityId = + s"ConstantRateEntityRecoveryStrategy(${PrettyDuration.format(frequency)}, $numberOfEntities)" +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EventSourcedRememberEntities.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EventSourcedRememberEntities.scala new file mode 100644 index 0000000000..6478631d75 --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EventSourcedRememberEntities.scala @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
+ */ + +package akka.cluster.sharding.internal + +import akka.actor.ActorLogging +import akka.actor.Props +import akka.annotation.InternalApi +import akka.cluster.sharding.ClusterShardingSerializable +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.ShardRegion +import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.ShardRegion.ShardId +import akka.persistence.DeleteMessagesFailure +import akka.persistence.DeleteMessagesSuccess +import akka.persistence.DeleteSnapshotsFailure +import akka.persistence.DeleteSnapshotsSuccess +import akka.persistence.PersistentActor +import akka.persistence.RecoveryCompleted +import akka.persistence.SaveSnapshotFailure +import akka.persistence.SaveSnapshotSuccess +import akka.persistence.SnapshotOffer +import akka.persistence.SnapshotSelectionCriteria + +/** + * INTERNAL API + */ +@InternalApi +private[akka] final class EventSourcedRememberEntitiesStoreProvider(typeName: String, settings: ClusterShardingSettings) + extends RememberEntitiesShardStoreProvider { + + override def shardStoreProps(shardId: ShardId): Props = + EventSourcedRememberEntitiesStore.props(typeName, shardId, settings) + +} + +/** + * INTERNAL API + */ +private[akka] object EventSourcedRememberEntitiesStore { + + /** + * A case class which represents a state change for the Shard + */ + sealed trait StateChange extends ClusterShardingSerializable { + val entityId: EntityId + } + + /** + * Persistent state of the Shard. + */ + final case class State private[akka] (entities: Set[EntityId] = Set.empty) extends ClusterShardingSerializable + + /** + * `State` change for starting an entity in this `Shard` + */ + final case class EntityStarted(entityId: EntityId) extends StateChange + + case object StartedAck + + /** + * `State` change for an entity which has terminated. + */ + final case class EntityStopped(entityId: EntityId) extends StateChange + + def props(typeName: String, shardId: ShardRegion.ShardId, settings: ClusterShardingSettings): Props = + Props(new EventSourcedRememberEntitiesStore(typeName, shardId, settings)) +} + +/** + * INTERNAL API + * + * Persistent actor keeping the state for Akka Persistence backed remember entities (enabled through `state-store-mode=persistence`). + * + * @see [[ClusterSharding$ ClusterSharding extension]] + */ +private[akka] final class EventSourcedRememberEntitiesStore( + typeName: String, + shardId: ShardRegion.ShardId, + settings: ClusterShardingSettings) + extends PersistentActor + with ActorLogging { + + import EventSourcedRememberEntitiesStore._ + import settings.tuningParameters._ + + log.debug("Starting up EventSourcedRememberEntitiesStore") + private var state = State() + override def persistenceId = s"/sharding/${typeName}Shard/$shardId" + override def journalPluginId: String = settings.journalPluginId + override def snapshotPluginId: String = settings.snapshotPluginId + + override def receiveRecover: Receive = { + case EntityStarted(id) => state = state.copy(state.entities + id) + case EntityStopped(id) => state = state.copy(state.entities - id) + case SnapshotOffer(_, snapshot: State) => state = snapshot + case RecoveryCompleted => + log.debug("Recovery completed for shard [{}] with [{}] entities", shardId, state.entities.size) + } + + override def receiveCommand: Receive = { + case RememberEntitiesShardStore.AddEntity(id) => + persist(EntityStarted(id)) { started => + sender() ! 
RememberEntitiesShardStore.UpdateDone(id) + state.copy(state.entities + started.entityId) + saveSnapshotWhenNeeded() + } + case RememberEntitiesShardStore.RemoveEntity(id) => + persist(EntityStopped(id)) { stopped => + sender() ! RememberEntitiesShardStore.UpdateDone(id) + state.copy(state.entities - stopped.entityId) + saveSnapshotWhenNeeded() + } + case RememberEntitiesShardStore.GetEntities => + sender() ! RememberEntitiesShardStore.RememberedEntities(state.entities) + + case e: SaveSnapshotSuccess => + log.debug("Snapshot saved successfully") + internalDeleteMessagesBeforeSnapshot(e, keepNrOfBatches, snapshotAfter) + + case SaveSnapshotFailure(_, reason) => + log.warning("Snapshot failure: [{}]", reason.getMessage) + + case DeleteMessagesSuccess(toSequenceNr) => + val deleteTo = toSequenceNr - 1 + val deleteFrom = math.max(0, deleteTo - (keepNrOfBatches * snapshotAfter)) + log.debug( + "Messages to [{}] deleted successfully. Deleting snapshots from [{}] to [{}]", + toSequenceNr, + deleteFrom, + deleteTo) + deleteSnapshots(SnapshotSelectionCriteria(minSequenceNr = deleteFrom, maxSequenceNr = deleteTo)) + + case DeleteMessagesFailure(reason, toSequenceNr) => + log.warning("Messages to [{}] deletion failure: [{}]", toSequenceNr, reason.getMessage) + + case DeleteSnapshotsSuccess(m) => + log.debug("Snapshots matching [{}] deleted successfully", m) + + case DeleteSnapshotsFailure(m, reason) => + log.warning("Snapshots matching [{}] deletion failure: [{}]", m, reason.getMessage) + + } + + override def postStop(): Unit = { + super.postStop() + log.debug("Store stopping") + } + + def saveSnapshotWhenNeeded(): Unit = { + if (lastSequenceNr % snapshotAfter == 0 && lastSequenceNr != 0) { + log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr) + saveSnapshot(state) + } + } + +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala new file mode 100644 index 0000000000..2fa470a35f --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
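To make the cleanup handling above concrete, a rough worked example of the DeleteMessagesSuccess branch, with invented tuning values (snapshot-after = 1000, keep-nr-of-batches = 2): if the journal confirms deletion of events up to sequence number 5000, the store then drops snapshots in the range 2999 to 4999.

  object SnapshotCleanupExample extends App {
    // Invented tuning values for illustration only
    val snapshotAfter = 1000
    val keepNrOfBatches = 2

    // Suppose DeleteMessagesSuccess(toSequenceNr = 5000) arrives:
    val toSequenceNr = 5000L
    val deleteTo = toSequenceNr - 1
    val deleteFrom = math.max(0, deleteTo - (keepNrOfBatches * snapshotAfter))

    // Same computation as the DeleteMessagesSuccess handler above
    println(s"deleting snapshots with sequenceNr in [$deleteFrom, $deleteTo]") // [2999, 4999]
  }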
+ */ + +package akka.cluster.sharding.internal + +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.ActorRef +import akka.actor.NoSerializationVerificationNeeded +import akka.actor.Props +import akka.annotation.InternalApi +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.ShardRegion + +import scala.collection.immutable.Set + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object RememberEntityStarter { + def props(region: ActorRef, ids: Set[ShardRegion.EntityId], settings: ClusterShardingSettings, requestor: ActorRef) = + Props(new RememberEntityStarter(region, ids, settings, requestor)) + + private case object Tick extends NoSerializationVerificationNeeded +} + +/** + * INTERNAL API: Actor responsible for starting entities when rememberEntities is enabled + */ +@InternalApi +private[akka] final class RememberEntityStarter( + region: ActorRef, + ids: Set[ShardRegion.EntityId], + settings: ClusterShardingSettings, + requestor: ActorRef) + extends Actor + with ActorLogging { + + import RememberEntityStarter.Tick + import context.dispatcher + + var waitingForAck = ids + + sendStart(ids) + + val tickTask = { + val resendInterval = settings.tuningParameters.retryInterval + context.system.scheduler.scheduleWithFixedDelay(resendInterval, resendInterval, self, Tick) + } + + def sendStart(ids: Set[ShardRegion.EntityId]): Unit = { + // these go through the region rather the directly to the shard + // so that shard mapping changes are picked up + ids.foreach(id => region ! ShardRegion.StartEntity(id)) + } + + override def receive: Receive = { + case ack: ShardRegion.StartEntityAck => + waitingForAck -= ack.entityId + // inform whoever requested the start that it happened + requestor ! ack + if (waitingForAck.isEmpty) context.stop(self) + + case Tick => + sendStart(waitingForAck) + + } + + override def postStop(): Unit = { + tickTask.cancel() + } +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStore.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStore.scala new file mode 100644 index 0000000000..a9e3e122c2 --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStore.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. + */ + +package akka.cluster.sharding.internal + +import akka.actor.Props +import akka.annotation.InternalApi +import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.ShardRegion.ShardId + +/** + * INTERNAL API + * + * Created once from the shard region, called once per started shard to create the remember entities shard store + */ +@InternalApi +private[akka] trait RememberEntitiesShardStoreProvider { + def shardStoreProps(shardId: ShardId): Props +} + +/** + * INTERNAL API + * + * Could potentially become an open SPI in the future. + * + * Implementations are responsible for each of the methods failing the returned future after a timeout. 
+ */ +@InternalApi +private[akka] object RememberEntitiesShardStore { + // SPI protocol for a remember entities store + sealed trait Command + + sealed trait UpdateEntityCommand extends Command { + def entityId: EntityId + } + final case class AddEntity(entityId: EntityId) extends UpdateEntityCommand + final case class RemoveEntity(entityId: EntityId) extends UpdateEntityCommand + // responses for UpdateEntity add and remove + final case class UpdateDone(entityId: EntityId) + + case object GetEntities extends Command + final case class RememberedEntities(entities: Set[EntityId]) + +} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala index e85d6116a0..10af80a1a5 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala @@ -40,7 +40,11 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy import ShardCoordinator.Internal._ import Shard.{ CurrentShardState, GetCurrentShardState } import Shard.{ GetShardStats, ShardStats } - import Shard.{ State => EntityState, EntityStarted, EntityStopped } + import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStore.{ + State => EntityState, + EntityStarted, + EntityStopped + } private final val BufferSize = 1024 * 4 diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingFailureSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingFailureSpec.scala index eb4753e34e..e8a665b69f 100644 --- a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingFailureSpec.scala +++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingFailureSpec.scala @@ -6,6 +6,7 @@ package akka.cluster.sharding import akka.actor._ import akka.cluster.sharding.ShardRegion.Passivate +import akka.cluster.sharding.ShardRegion.StartEntity import akka.remote.testconductor.RoleName import akka.remote.transport.ThrottlerTransportAdapter.Direction import akka.serialization.jackson.CborSerializable @@ -20,13 +21,20 @@ object ClusterShardingFailureSpec { case class Add(id: String, i: Int) extends CborSerializable case class Value(id: String, n: Int) extends CborSerializable - class Entity extends Actor { + class Entity extends Actor with ActorLogging { + log.debug("Starting") var n = 0 def receive = { - case Get(id) => sender() ! Value(id, n) - case Add(_, i) => n += i + case Get(id) => + log.debug("Got get request from {}", sender()) + sender() ! 
Value(id, n) + case Add(_, i) => + n += i + log.debug("Got add request from {}", sender()) } + + override def postStop(): Unit = log.debug("Stopping") } val extractEntityId: ShardRegion.ExtractEntityId = { @@ -35,8 +43,9 @@ object ClusterShardingFailureSpec { } val extractShardId: ShardRegion.ExtractShardId = { - case Get(id) => id.charAt(0).toString - case Add(id, _) => id.charAt(0).toString + case Get(id) => id.charAt(0).toString + case Add(id, _) => id.charAt(0).toString + case StartEntity(id) => id } } @@ -44,11 +53,14 @@ abstract class ClusterShardingFailureSpecConfig(override val mode: String) extends MultiNodeClusterShardingConfig( mode, additionalConfig = s""" + akka.loglevel=DEBUG akka.cluster.roles = ["backend"] akka.cluster.sharding { coordinator-failure-backoff = 3s shard-failure-backoff = 3s } + # don't leak ddata state across runs + akka.cluster.sharding.distributed-data.durable.keys = [] akka.persistence.journal.leveldb-shared.store.native = off # using Java serialization for these messages because test is sending them # to other nodes, which isn't normal usage. diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala index dceebb54d5..b8559b4668 100644 --- a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala +++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala @@ -130,7 +130,7 @@ abstract class MultiNodeClusterShardingSpec(val config: MultiNodeClusterSharding if (assertNodeUp) { within(max) { awaitAssert { - cluster.state.isMemberUp(node(from).address) + cluster.state.isMemberUp(node(from).address) should ===(true) } } } diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala index 48767fba92..c3bc9c804b 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala @@ -5,6 +5,7 @@ package akka.cluster.sharding import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.internal.EntityRecoveryStrategy import akka.testkit.AkkaSpec class AllAtOnceEntityRecoveryStrategySpec extends AkkaSpec { diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingInternalsSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingInternalsSpec.scala index e79f4ba394..992647ce92 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingInternalsSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingInternalsSpec.scala @@ -9,6 +9,7 @@ import akka.cluster.ClusterSettings.DataCenter import akka.cluster.sharding.ShardCoordinator.Internal.ShardStopped import akka.cluster.sharding.ShardCoordinator.LeastShardAllocationStrategy import akka.cluster.sharding.ShardRegion.{ ExtractEntityId, ExtractShardId, HandOffStopper, Msg } +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, TestProbe } import scala.concurrent.duration._ @@ -30,7 +31,9 @@ class ClusterShardingInternalsSpec extends AkkaSpec(""" |akka.actor.provider = cluster |akka.remote.classic.netty.tcp.port = 0 
|akka.remote.artery.canonical.port = 0 - |""".stripMargin) { + |akka.loglevel = DEBUG + |akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + |""".stripMargin) with WithLogCapturing { import ClusterShardingInternalsSpec._ case class StartingProxy( diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala index 97bc907eda..5212d386b0 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala @@ -6,6 +6,7 @@ package akka.cluster.sharding import akka.actor.Props import akka.cluster.{ Cluster, MemberStatus, TestLease, TestLeaseExt } import akka.testkit.TestActors.EchoActor +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, ImplicitSender } import com.typesafe.config.{ Config, ConfigFactory } @@ -17,7 +18,7 @@ import scala.util.control.NoStackTrace object ClusterShardingLeaseSpec { val config = ConfigFactory.parseString(""" akka.loglevel = DEBUG - #akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 @@ -59,7 +60,8 @@ class DDataClusterShardingLeaseSpec extends ClusterShardingLeaseSpec(ClusterShar class ClusterShardingLeaseSpec(config: Config, rememberEntities: Boolean) extends AkkaSpec(config.withFallback(ClusterShardingLeaseSpec.config)) - with ImplicitSender { + with ImplicitSender + with WithLogCapturing { import ClusterShardingLeaseSpec._ def this() = this(ConfigFactory.empty(), false) @@ -129,7 +131,7 @@ class ClusterShardingLeaseSpec(config: Config, rememberEntities: Boolean) awaitAssert({ region ! 
4 expectMsg(4) - }, max = 5.seconds) + }, max = 10.seconds) } } } diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConcurrentStartupShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConcurrentStartupShardingSpec.scala index 85d64dfe0a..90a0337592 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConcurrentStartupShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConcurrentStartupShardingSpec.scala @@ -5,7 +5,6 @@ package akka.cluster.sharding import scala.concurrent.duration._ - import akka.actor.Actor import akka.actor.ActorRef import akka.actor.Props @@ -14,12 +13,15 @@ import akka.cluster.MemberStatus import akka.testkit.AkkaSpec import akka.testkit.DeadLettersFilter import akka.testkit.TestEvent.Mute +import akka.testkit.WithLogCapturing object ConcurrentStartupShardingSpec { val config = """ akka.actor.provider = "cluster" + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 akka.log-dead-letters = off @@ -57,7 +59,7 @@ object ConcurrentStartupShardingSpec { } } -class ConcurrentStartupShardingSpec extends AkkaSpec(ConcurrentStartupShardingSpec.config) { +class ConcurrentStartupShardingSpec extends AkkaSpec(ConcurrentStartupShardingSpec.config) with WithLogCapturing { import ConcurrentStartupShardingSpec._ // mute logging of deadLetters diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala index bc4987576b..0d8ac63417 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala @@ -5,6 +5,7 @@ package akka.cluster.sharding import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.internal.EntityRecoveryStrategy import akka.testkit.{ AkkaSpec, TimingTest } import scala.concurrent.{ Await, Future } diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/CoordinatedShutdownShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/CoordinatedShutdownShardingSpec.scala index 3d52cd236a..4a3cff57be 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/CoordinatedShutdownShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/CoordinatedShutdownShardingSpec.scala @@ -21,6 +21,7 @@ import akka.util.ccompat._ object CoordinatedShutdownShardingSpec { val config = """ + akka.loglevel = DEBUG akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/GetShardTypeNamesSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/GetShardTypeNamesSpec.scala index 52e84c1e8a..084d193d33 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/GetShardTypeNamesSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/GetShardTypeNamesSpec.scala @@ -8,11 +8,13 @@ import akka.actor.Props import akka.cluster.Cluster import akka.testkit.AkkaSpec import akka.testkit.TestActors.EchoActor +import akka.testkit.WithLogCapturing object GetShardTypeNamesSpec { val config = """ - 
akka.loglevel = INFO + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 @@ -27,7 +29,7 @@ object GetShardTypeNamesSpec { } } -class GetShardTypeNamesSpec extends AkkaSpec(GetShardTypeNamesSpec.config) { +class GetShardTypeNamesSpec extends AkkaSpec(GetShardTypeNamesSpec.config) with WithLogCapturing { import GetShardTypeNamesSpec._ "GetShardTypeNames" must { diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/InactiveEntityPassivationSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/InactiveEntityPassivationSpec.scala index 05deedf91c..3260a1d6fb 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/InactiveEntityPassivationSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/InactiveEntityPassivationSpec.scala @@ -5,10 +5,10 @@ package akka.cluster.sharding import scala.concurrent.duration._ - import akka.actor.{ Actor, ActorRef, Props } import akka.cluster.Cluster import akka.cluster.sharding.InactiveEntityPassivationSpec.Entity.GotIt +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, TestProbe } import com.typesafe.config.ConfigFactory import com.typesafe.config.Config @@ -16,7 +16,8 @@ import com.typesafe.config.Config object InactiveEntityPassivationSpec { val config = ConfigFactory.parseString(""" - akka.loglevel = INFO + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 @@ -54,7 +55,7 @@ object InactiveEntityPassivationSpec { } } -abstract class AbstractInactiveEntityPassivationSpec(c: Config) extends AkkaSpec(c) { +abstract class AbstractInactiveEntityPassivationSpec(c: Config) extends AkkaSpec(c) with WithLogCapturing { import InactiveEntityPassivationSpec._ private val smallTolerance = 300.millis diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala index 6dc62ee66d..fb9eb40dd9 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala @@ -6,12 +6,14 @@ package akka.cluster.sharding import akka.actor.ActorSystem import akka.cluster.{ Cluster, ClusterReadView } +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, LongRunningTest } import com.typesafe.config.{ Config, ConfigFactory } + import scala.concurrent.duration._ import scala.collection.{ immutable => im } -class JoinConfigCompatCheckShardingSpec extends AkkaSpec() { +class JoinConfigCompatCheckShardingSpec extends AkkaSpec() with WithLogCapturing { def initCluster(system: ActorSystem): ClusterReadView = { val cluster = Cluster(system) @@ -24,6 +26,8 @@ class JoinConfigCompatCheckShardingSpec extends AkkaSpec() { val baseConfig: Config = ConfigFactory.parseString(""" akka.actor.provider = "cluster" + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.coordinated-shutdown.terminate-actor-system = on akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala 
b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala deleted file mode 100644 index 592c017e90..0000000000 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (C) 2018-2020 Lightbend Inc. - */ - -package akka.cluster.sharding - -import akka.actor.{ Actor, PoisonPill, Props } -import akka.cluster.sharding.PersistentShardSpec.EntityActor -import akka.cluster.sharding.Shard.{ GetShardStats, ShardStats } -import akka.cluster.sharding.ShardRegion.{ StartEntity, StartEntityAck } -import akka.testkit.{ AkkaSpec, ImplicitSender } -import com.typesafe.config.ConfigFactory -import org.scalatest.wordspec.AnyWordSpecLike - -object PersistentShardSpec { - class EntityActor extends Actor { - override def receive: Receive = { - case _ => - } - } - - val config = ConfigFactory.parseString(""" - akka.persistence.journal.plugin = "akka.persistence.journal.inmem" - """.stripMargin) -} - -class PersistentShardSpec extends AkkaSpec(PersistentShardSpec.config) with AnyWordSpecLike with ImplicitSender { - - "Persistent Shard" must { - - "remember entities started with StartEntity" in { - val props = - Props(new PersistentShard("cats", "shard-1", _ => Props(new EntityActor), ClusterShardingSettings(system), { - case _ => ("entity-1", "msg") - }, { _ => - "shard-1" - }, PoisonPill)) - val persistentShard = system.actorOf(props) - watch(persistentShard) - - persistentShard ! StartEntity("entity-1") - expectMsg(StartEntityAck("entity-1", "shard-1")) - - persistentShard ! PoisonPill - expectTerminated(persistentShard) - - system.log.info("Starting shard again") - val secondIncarnation = system.actorOf(props) - - secondIncarnation ! GetShardStats - awaitAssert(expectMsg(ShardStats("shard-1", 1))) - } - } - -} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala index e7397bdcd6..fe660ff1ca 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala @@ -9,17 +9,21 @@ import scala.concurrent.duration._ import akka.actor.ActorRef import akka.testkit.AkkaSpec import akka.testkit.TestActors +import akka.testkit.WithLogCapturing + import scala.concurrent.duration.FiniteDuration object ProxyShardingSpec { val config = """ - akka.actor.provider = "cluster" + akka.actor.provider = cluster + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 """ } -class ProxyShardingSpec extends AkkaSpec(ProxyShardingSpec.config) { +class ProxyShardingSpec extends AkkaSpec(ProxyShardingSpec.config) with WithLogCapturing { val role = "Shard" val clusterSharding: ClusterSharding = ClusterSharding(system) diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesFailureSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesFailureSpec.scala new file mode 100644 index 0000000000..3a7a603dd2 --- /dev/null +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesFailureSpec.scala @@ -0,0 +1,290 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. 
+ */ + +package akka.cluster.sharding + +import akka.Done +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.ActorRef +import akka.actor.Props +import akka.cluster.Cluster +import akka.cluster.MemberStatus +import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.ShardRegion.ShardId +import akka.cluster.sharding.internal.RememberEntitiesShardStore +import akka.cluster.sharding.internal.RememberEntitiesShardStoreProvider +import akka.testkit.AkkaSpec +import akka.testkit.TestException +import akka.testkit.TestProbe +import akka.testkit.WithLogCapturing +import com.github.ghik.silencer.silent +import com.typesafe.config.ConfigFactory +import org.scalatest.wordspec.AnyWordSpecLike + +import scala.concurrent.duration._ + +object RememberEntitiesFailureSpec { + val config = ConfigFactory.parseString(s""" + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + akka.actor.provider = cluster + akka.remote.artery.canonical.port = 0 + akka.remote.classic.netty.tcp.port = 0 + akka.cluster.sharding.distributed-data.durable.keys = [] + akka.cluster.sharding.state-store-mode = custom + akka.cluster.sharding.custom-store = "akka.cluster.sharding.RememberEntitiesFailureSpec$$FakeStore" + # quick backoffs + akka.cluster.sharding.entity-restart-backoff = 1s + akka.cluster.sharding.shard-failure-backoff = 1s + """) + + class EntityActor extends Actor with ActorLogging { + override def receive: Receive = { + case "stop" => + log.info("Stopping myself!") + context.stop(self) + case "graceful-stop" => + context.parent ! ShardRegion.Passivate("stop") + case msg => sender() ! msg + } + } + + case class EntityEnvelope(entityId: Int, msg: Any) + + val extractEntityId: ShardRegion.ExtractEntityId = { + case EntityEnvelope(id, payload) => (id.toString, payload) + } + + val extractShardId: ShardRegion.ExtractShardId = { + case EntityEnvelope(id, _) => (id % 10).toString + } + + sealed trait Fail + case object NoResponse extends Fail + case object CrashStore extends Fail + + // outside store since shard allocation triggers initialization of store and we can't interact with the fake store actor before that + @volatile var failInitial = Map.empty[ShardId, Fail] + + case class StoreCreated(store: ActorRef, shardId: ShardId) + @silent("never used") + class FakeStore(settings: ClusterShardingSettings, typeName: String) extends RememberEntitiesShardStoreProvider { + override def shardStoreProps(shardId: ShardId): Props = FakeStoreActor.props(shardId) + } + object FakeStoreActor { + def props(shardId: ShardId): Props = Props(new FakeStoreActor(shardId)) + + case class FailAddEntity(entityId: EntityId, whichWay: Fail) + case class DoNotFailAddEntity(entityId: EntityId) + case class FailRemoveEntity(entityId: EntityId, whichWay: Fail) + case class DoNotFailRemoveEntity(entityId: EntityId) + } + class FakeStoreActor(shardId: ShardId) extends Actor with ActorLogging { + import FakeStoreActor._ + + implicit val ec = context.system.dispatcher + private var failAddEntity = Map.empty[EntityId, Fail] + private var failRemoveEntity = Map.empty[EntityId, Fail] + + context.system.eventStream.publish(StoreCreated(self, shardId)) + + override def receive: Receive = { + case RememberEntitiesShardStore.GetEntities => + failInitial.get(shardId) match { + case None => sender ! 
RememberEntitiesShardStore.RememberedEntities(Set.empty) + case Some(NoResponse) => log.debug("Sending no response for GetEntities") + case Some(CrashStore) => throw TestException("store crash on GetEntities") + } + case RememberEntitiesShardStore.AddEntity(entityId) => + failAddEntity.get(entityId) match { + case None => sender ! RememberEntitiesShardStore.UpdateDone(entityId) + case Some(NoResponse) => log.debug("Sending no response for AddEntity") + case Some(CrashStore) => throw TestException("store crash on AddEntity") + } + case RememberEntitiesShardStore.RemoveEntity(entityId) => + failRemoveEntity.get(entityId) match { + case None => sender ! RememberEntitiesShardStore.UpdateDone(entityId) + case Some(NoResponse) => log.debug("Sending no response for RemoveEntity") + case Some(CrashStore) => throw TestException("store crash on AddEntity") + } + case FailAddEntity(id, whichWay) => + failAddEntity = failAddEntity.updated(id, whichWay) + sender() ! Done + case DoNotFailAddEntity(id) => + failAddEntity = failAddEntity - id + sender() ! Done + case FailRemoveEntity(id, whichWay) => + failRemoveEntity = failRemoveEntity.updated(id, whichWay) + sender() ! Done + case DoNotFailRemoveEntity(id) => + failRemoveEntity = failRemoveEntity - id + sender() ! Done + + } + } +} + +class RememberEntitiesFailureSpec + extends AkkaSpec(RememberEntitiesFailureSpec.config) + with AnyWordSpecLike + with WithLogCapturing { + + import RememberEntitiesFailureSpec._ + + override def atStartup(): Unit = { + // Form a one node cluster + val cluster = Cluster(system) + cluster.join(cluster.selfAddress) + awaitAssert(cluster.readView.members.count(_.status == MemberStatus.Up) should ===(1)) + } + + "Remember entities handling in sharding" must { + + List(NoResponse, CrashStore).foreach { wayToFail: Fail => + s"recover when initial remember entities load fails $wayToFail" in { + log.debug("Getting entities for shard 1 will fail") + failInitial = Map("1" -> wayToFail) + + try { + val probe = TestProbe() + val sharding = ClusterSharding(system).start( + s"initial-$wayToFail", + Props[EntityActor], + ClusterShardingSettings(system).withRememberEntities(true), + extractEntityId, + extractShardId) + + sharding.tell(EntityEnvelope(1, "hello-1"), probe.ref) + probe.expectNoMessage() // message is lost because shard crashes + + log.debug("Resetting initial fail") + failInitial = Map.empty + + // shard should be restarted and eventually succeed + awaitAssert { + sharding.tell(EntityEnvelope(1, "hello-1"), probe.ref) + probe.expectMsg("hello-1") + } + + system.stop(sharding) + } finally { + failInitial = Map.empty + } + } + + s"recover when storing a start event fails $wayToFail" in { + val storeProbe = TestProbe() + system.eventStream.subscribe(storeProbe.ref, classOf[StoreCreated]) + + val sharding = ClusterSharding(system).start( + s"storeStart-$wayToFail", + Props[EntityActor], + ClusterShardingSettings(system).withRememberEntities(true), + extractEntityId, + extractShardId) + + // trigger shard start and store creation + val probe = TestProbe() + sharding.tell(EntityEnvelope(1, "hello-1"), probe.ref) + val shard1Store = storeProbe.expectMsgType[StoreCreated].store + probe.expectMsg("hello-1") + + // hit shard with other entity that will fail + shard1Store.tell(FakeStoreActor.FailAddEntity("11", wayToFail), storeProbe.ref) + storeProbe.expectMsg(Done) + + sharding.tell(EntityEnvelope(11, "hello-11"), probe.ref) + + // do we get an answer here? 
shard crashes + probe.expectNoMessage() + + val stopFailingProbe = TestProbe() + shard1Store.tell(FakeStoreActor.DoNotFailAddEntity("11"), stopFailingProbe.ref) + stopFailingProbe.expectMsg(Done) + + // it takes a while - timeout hits and then backoff + awaitAssert({ + sharding.tell(EntityEnvelope(11, "hello-11-2"), probe.ref) + probe.expectMsg("hello-11-2") + }, 10.seconds) + system.stop(sharding) + } + + s"recover on abrupt entity stop when storing a stop event fails $wayToFail" in { + val storeProbe = TestProbe() + system.eventStream.subscribe(storeProbe.ref, classOf[StoreCreated]) + + val sharding = ClusterSharding(system).start( + s"storeStopAbrupt-$wayToFail", + Props[EntityActor], + ClusterShardingSettings(system).withRememberEntities(true), + extractEntityId, + extractShardId) + + val probe = TestProbe() + + // trigger shard start and store creation + sharding.tell(EntityEnvelope(1, "hello-1"), probe.ref) + val shard1Store = storeProbe.expectMsgType[StoreCreated].store + probe.expectMsg("hello-1") + + // fail it when stopping + shard1Store.tell(FakeStoreActor.FailRemoveEntity("1", wayToFail), storeProbe.ref) + storeProbe.expectMsg(Done) + + // FIXME restart without passivating is not saved and re-started again without storing the stop so this isn't testing anything + sharding ! EntityEnvelope(1, "stop") + + shard1Store.tell(FakeStoreActor.DoNotFailRemoveEntity("1"), storeProbe.ref) + storeProbe.expectMsg(Done) + + // it takes a while - timeout hits and then backoff + awaitAssert({ + sharding.tell(EntityEnvelope(1, "hello-2"), probe.ref) + probe.expectMsg("hello-2") + }, 10.seconds) + system.stop(sharding) + } + + s"recover on graceful entity stop when storing a stop event fails $wayToFail" in { + val storeProbe = TestProbe() + system.eventStream.subscribe(storeProbe.ref, classOf[StoreCreated]) + + val sharding = ClusterSharding(system).start( + s"storeStopGraceful-$wayToFail", + Props[EntityActor], + ClusterShardingSettings(system).withRememberEntities(true), + extractEntityId, + extractShardId, + new ShardCoordinator.LeastShardAllocationStrategy(rebalanceThreshold = 1, maxSimultaneousRebalance = 3), + "graceful-stop") + + val probe = TestProbe() + + // trigger shard start and store creation + sharding.tell(EntityEnvelope(1, "hello-1"), probe.ref) + val shard1Store = storeProbe.expectMsgType[StoreCreated].store + probe.expectMsg("hello-1") + + // fail it when stopping + shard1Store.tell(FakeStoreActor.FailRemoveEntity("1", wayToFail), storeProbe.ref) + storeProbe.expectMsg(Done) + + sharding ! EntityEnvelope(1, "graceful-stop") + + shard1Store.tell(FakeStoreActor.DoNotFailRemoveEntity("1"), storeProbe.ref) + storeProbe.expectMsg(Done) + + // it takes a while? + awaitAssert({ + sharding.tell(EntityEnvelope(1, "hello-2"), probe.ref) + probe.expectMsg("hello-2") + }, 5.seconds) + system.stop(sharding) + } + } + } + +} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesSpec.scala new file mode 100644 index 0000000000..02acce0688 --- /dev/null +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RememberEntitiesSpec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2018-2020 Lightbend Inc. 
+ */ + +package akka.cluster.sharding + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.PoisonPill +import akka.actor.Props +import akka.cluster.Cluster +import akka.cluster.MemberStatus +import akka.cluster.sharding.Shard.GetShardStats +import akka.cluster.sharding.Shard.ShardStats +import akka.cluster.sharding.ShardRegion.StartEntity +import akka.cluster.sharding.ShardRegion.StartEntityAck +import akka.testkit.AkkaSpec +import akka.testkit.ImplicitSender +import akka.testkit.WithLogCapturing +import com.typesafe.config.ConfigFactory +import org.scalatest.wordspec.AnyWordSpecLike + +object RememberEntitiesSpec { + class EntityActor extends Actor { + override def receive: Receive = { + case "give-me-shard" => sender() ! context.parent + case msg => sender() ! msg + } + } + + case class EntityEnvelope(entityId: Int, msg: Any) + + val extractEntityId: ShardRegion.ExtractEntityId = { + case EntityEnvelope(id, payload) => (id.toString, payload) + } + + val extractShardId: ShardRegion.ExtractShardId = { + case EntityEnvelope(id, _) => (id % 10).toString + case StartEntity(id) => (id.toInt % 10).toString + } + + val config = ConfigFactory.parseString(""" + akka.loglevel=DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + akka.actor.provider = cluster + akka.remote.artery.canonical.port = 0 + akka.remote.classic.netty.tcp.port = 0 + akka.persistence.journal.plugin = "akka.persistence.journal.inmem" + """.stripMargin) +} + +class RememberEntitiesSpec + extends AkkaSpec(RememberEntitiesSpec.config) + with AnyWordSpecLike + with ImplicitSender + with WithLogCapturing { + + import RememberEntitiesSpec._ + + override def atStartup(): Unit = { + // Form a one node cluster + val cluster = Cluster(system) + cluster.join(cluster.selfAddress) + awaitAssert(cluster.readView.members.count(_.status == MemberStatus.Up) should ===(1)) + } + + "Persistent Shard" must { + + "remember entities started with StartEntity" in { + val sharding = ClusterSharding(system).start( + s"startEntity", + Props[EntityActor], + ClusterShardingSettings(system) + .withRememberEntities(true) + .withStateStoreMode(ClusterShardingSettings.StateStoreModePersistence), + extractEntityId, + extractShardId) + + sharding ! StartEntity("1") + expectMsg(StartEntityAck("1", "1")) + val shard = lastSender + + watch(shard) + shard ! PoisonPill + expectTerminated(shard) + + // trigger shard start by messaging other actor in it + system.log.info("Starting shard again") + sharding ! EntityEnvelope(11, "give-me-shard") + val secondShardIncarnation = expectMsgType[ActorRef] + + awaitAssert { + secondShardIncarnation ! 
GetShardStats + // the remembered 1 and 11 which we just triggered start of + expectMsg(ShardStats("1", 2)) + } + } + } + +} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RemoveInternalClusterShardingDataSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RemoveInternalClusterShardingDataSpec.scala index 375a8f27f2..0c4cdee779 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RemoveInternalClusterShardingDataSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/RemoveInternalClusterShardingDataSpec.scala @@ -9,7 +9,6 @@ import java.io.File import scala.concurrent.Await import scala.concurrent.duration._ import scala.util.Success - import akka.actor.ActorRef import akka.actor.Props import akka.cluster.Cluster @@ -23,11 +22,13 @@ import akka.persistence.SnapshotSelectionCriteria import akka.testkit.AkkaSpec import akka.testkit.ImplicitSender import akka.testkit.TestActors.EchoActor +import akka.testkit.WithLogCapturing import org.apache.commons.io.FileUtils object RemoveInternalClusterShardingDataSpec { val config = """ - akka.loglevel = INFO + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 @@ -93,7 +94,8 @@ object RemoveInternalClusterShardingDataSpec { class RemoveInternalClusterShardingDataSpec extends AkkaSpec(RemoveInternalClusterShardingDataSpec.config) - with ImplicitSender { + with ImplicitSender + with WithLogCapturing { import RemoveInternalClusterShardingDataSpec._ val storageLocations = diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala index 0c003eae72..62acdbdacb 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala @@ -10,6 +10,7 @@ import akka.actor.{ Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Prop import akka.cluster.ClusterEvent.CurrentClusterState import akka.cluster.{ Cluster, MemberStatus } import akka.testkit.TestEvent.Mute +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, DeadLettersFilter, TestProbe } import com.typesafe.config.ConfigFactory import org.apache.commons.io.FileUtils @@ -24,7 +25,8 @@ object ShardRegionSpec { val config = ConfigFactory.parseString(tempConfig).withFallback(ConfigFactory.parseString(s""" - akka.loglevel = INFO + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] akka.actor.provider = "cluster" akka.remote.classic.netty.tcp.port = 0 akka.remote.artery.canonical.port = 0 @@ -57,7 +59,7 @@ object ShardRegionSpec { } } } -class ShardRegionSpec extends AkkaSpec(ShardRegionSpec.config) { +class ShardRegionSpec extends AkkaSpec(ShardRegionSpec.config) with WithLogCapturing { import ShardRegionSpec._ import scala.concurrent.duration._ diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardSpec.scala deleted file mode 100644 index 321f8f6178..0000000000 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardSpec.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (C) 2019-2020 Lightbend Inc. 
- */ - -package akka.cluster.sharding - -import java.util.concurrent.atomic.AtomicInteger - -import akka.actor.{ Actor, ActorLogging, PoisonPill, Props } -import akka.cluster.TestLeaseExt -import akka.cluster.sharding.ShardRegion.ShardInitialized -import akka.coordination.lease.LeaseUsageSettings -import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe } - -import scala.concurrent.Future -import scala.concurrent.duration._ -import scala.util.Success -import scala.util.control.NoStackTrace - -object ShardSpec { - val config = - """ - akka.loglevel = INFO - akka.actor.provider = "cluster" - akka.remote.classic.netty.tcp.port = 0 - akka.remote.artery.canonical.port = 0 - test-lease { - lease-class = akka.cluster.TestLease - heartbeat-interval = 1s - heartbeat-timeout = 120s - lease-operation-timeout = 3s - } - """ - - class EntityActor extends Actor with ActorLogging { - override def receive: Receive = { - case msg => - log.info("Msg {}", msg) - sender() ! s"ack ${msg}" - } - } - - val numberOfShards = 5 - - case class EntityEnvelope(entityId: Int, msg: Any) - - val extractEntityId: ShardRegion.ExtractEntityId = { - case EntityEnvelope(id, payload) => (id.toString, payload) - } - - val extractShardId: ShardRegion.ExtractShardId = { - case EntityEnvelope(id, _) => (id % numberOfShards).toString - } - - case class BadLease(msg: String) extends RuntimeException(msg) with NoStackTrace -} - -class ShardSpec extends AkkaSpec(ShardSpec.config) with ImplicitSender { - - import ShardSpec._ - - val shortDuration = 100.millis - val testLeaseExt = TestLeaseExt(system) - - def leaseNameForShard(typeName: String, shardId: String) = s"${system.name}-shard-${typeName}-${shardId}" - - "A Cluster Shard" should { - "not initialize the shard until the lease is acquired" in new Setup { - parent.expectNoMessage(shortDuration) - lease.initialPromise.complete(Success(true)) - parent.expectMsg(ShardInitialized(shardId)) - } - - "retry if lease acquire returns false" in new Setup { - lease.initialPromise.complete(Success(false)) - parent.expectNoMessage(shortDuration) - lease.setNextAcquireResult(Future.successful(true)) - parent.expectMsg(ShardInitialized(shardId)) - } - - "retry if the lease acquire fails" in new Setup { - lease.initialPromise.failure(BadLease("no lease for you")) - parent.expectNoMessage(shortDuration) - lease.setNextAcquireResult(Future.successful(true)) - parent.expectMsg(ShardInitialized(shardId)) - } - - "shutdown if lease is lost" in new Setup { - val probe = TestProbe() - probe.watch(shard) - lease.initialPromise.complete(Success(true)) - parent.expectMsg(ShardInitialized(shardId)) - lease.getCurrentCallback().apply(Some(BadLease("bye bye lease"))) - probe.expectTerminated(shard) - } - } - - val shardIds = new AtomicInteger(0) - def nextShardId = s"${shardIds.getAndIncrement()}" - - trait Setup { - val shardId = nextShardId - val parent = TestProbe() - val settings = ClusterShardingSettings(system).withLeaseSettings(new LeaseUsageSettings("test-lease", 2.seconds)) - def lease = awaitAssert { - testLeaseExt.getTestLease(leaseNameForShard(typeName, shardId)) - } - - val typeName = "type1" - val shard = parent.childActorOf( - Shard.props( - typeName, - shardId, - _ => Props(new EntityActor()), - settings, - extractEntityId, - extractShardId, - PoisonPill, - system.deadLetters, - 1)) - } - -} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardWithLeaseSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardWithLeaseSpec.scala new file mode 100644 
index 0000000000..9796bf8fd4 --- /dev/null +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardWithLeaseSpec.scala @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2019-2020 Lightbend Inc. + */ + +package akka.cluster.sharding + +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.Props +import akka.cluster.Cluster +import akka.cluster.MemberStatus +import akka.cluster.TestLeaseExt +import akka.cluster.sharding.ShardRegion.ShardId +import akka.coordination.lease.LeaseUsageSettings +import akka.testkit.AkkaSpec +import akka.testkit.EventFilter +import akka.testkit.TestProbe +import akka.testkit.WithLogCapturing + +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.util.Success +import scala.util.control.NoStackTrace + +// FIXME this looks like it is the same test as ClusterShardingLeaseSpec is there any difference? +object ShardWithLeaseSpec { + val config = + """ + akka.loglevel = DEBUG + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + akka.actor.provider = "cluster" + akka.remote.classic.netty.tcp.port = 0 + akka.remote.artery.canonical.port = 0 + test-lease { + lease-class = akka.cluster.TestLease + heartbeat-interval = 1s + heartbeat-timeout = 120s + lease-operation-timeout = 3s + } + """ + + class EntityActor extends Actor with ActorLogging { + override def receive: Receive = { + case msg => + log.info("Msg {}", msg) + sender() ! s"ack ${msg}" + } + } + + val numberOfShards = 5 + + case class EntityEnvelope(entityId: Int, msg: Any) + + val extractEntityId: ShardRegion.ExtractEntityId = { + case EntityEnvelope(id, payload) => (id.toString, payload) + } + + val extractShardId: ShardRegion.ExtractShardId = { + case EntityEnvelope(id, _) => (id % numberOfShards).toString + } + + case class BadLease(msg: String) extends RuntimeException(msg) with NoStackTrace +} + +class ShardWithLeaseSpec extends AkkaSpec(ShardWithLeaseSpec.config) with WithLogCapturing { + + import ShardWithLeaseSpec._ + + val shortDuration = 100.millis + val testLeaseExt = TestLeaseExt(system) + + override def atStartup(): Unit = { + // Form a one node cluster + val cluster = Cluster(system) + cluster.join(cluster.selfAddress) + awaitAssert(cluster.readView.members.count(_.status == MemberStatus.Up) should ===(1)) + } + + "Lease handling in sharding" must { + "not initialize the shard until the lease is acquired" in new Setup { + val probe = TestProbe() + sharding.tell(EntityEnvelope(1, "hello"), probe.ref) + probe.expectNoMessage(shortDuration) + leaseFor("1").initialPromise.complete(Success(true)) + probe.expectMsg("ack hello") + } + + "retry if lease acquire returns false" in new Setup { + val probe = TestProbe() + val lease = + EventFilter.error(start = s"Failed to get lease for shard type [$typeName] id [1]", occurrences = 1).intercept { + sharding.tell(EntityEnvelope(1, "hello"), probe.ref) + val lease = leaseFor("1") + lease.initialPromise.complete(Success(false)) + probe.expectNoMessage(shortDuration) + lease + } + + lease.setNextAcquireResult(Future.successful(true)) + probe.expectMsg("ack hello") + } + + "retry if the lease acquire fails" in new Setup { + val probe = TestProbe() + val lease = + EventFilter.error(start = s"Failed to get lease for shard type [$typeName] id [1]", occurrences = 1).intercept { + sharding.tell(EntityEnvelope(1, "hello"), probe.ref) + val lease = leaseFor("1") + lease.initialPromise.failure(BadLease("no lease for you")) + probe.expectNoMessage(shortDuration) + lease + } + 
lease.setNextAcquireResult(Future.successful(true)) + probe.expectMsg("ack hello") + } + + "shutdown if lease is lost" in new Setup { + val probe = TestProbe() + sharding.tell(EntityEnvelope(1, "hello"), probe.ref) + val lease = leaseFor("1") + lease.initialPromise.complete(Success(true)) + probe.expectMsg("ack hello") + + EventFilter + .error( + start = + s"Shard type [$typeName] id [1] lease lost, stopping shard and killing [1] entities. Reason for losing lease: ${classOf[ + BadLease].getName}: bye bye lease", + occurrences = 1) + .intercept { + lease.getCurrentCallback().apply(Some(BadLease("bye bye lease"))) + sharding.tell(EntityEnvelope(1, "hello"), probe.ref) + probe.expectNoMessage(shortDuration) + } + } + } + + var typeIdx = 0 + + trait Setup { + val settings = ClusterShardingSettings(system).withLeaseSettings(new LeaseUsageSettings("test-lease", 2.seconds)) + + // unique type name for each test + val typeName = { + typeIdx += 1 + s"type$typeIdx" + } + + val sharding = + ClusterSharding(system).start(typeName, Props(new EntityActor()), settings, extractEntityId, extractShardId) + + def leaseFor(shardId: ShardId) = awaitAssert { + val leaseName = s"${system.name}-shard-${typeName}-${shardId}" + testLeaseExt.getTestLease(leaseName) + } + } + +} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/SupervisionSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/SupervisionSpec.scala index 7c1c8c8258..93c1d3df04 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/SupervisionSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/SupervisionSpec.scala @@ -8,6 +8,7 @@ import akka.actor.{ Actor, ActorLogging, ActorRef, PoisonPill, Props } import akka.cluster.Cluster import akka.cluster.sharding.ShardRegion.Passivate import akka.pattern.{ BackoffOpts, BackoffSupervisor } +import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, ImplicitSender } import com.typesafe.config.ConfigFactory @@ -17,7 +18,10 @@ object SupervisionSpec { val config = ConfigFactory.parseString(""" akka.actor.provider = "cluster" - akka.loglevel = INFO + akka.remote.artery.canonical.port = 0 + akka.remote.classic.netty.tcp.port = 0 + akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] + akka.loglevel = DEBUG """) case class Msg(id: Long, msg: Any) @@ -48,6 +52,7 @@ object SupervisionSpec { context.parent ! Passivate(StopMessage) // simulate another message causing a stop before the region sends the stop message // e.g. a persistent actor having a persist failure while processing the next message + // note that this means the StopMessage will go to dead letters context.stop(self) case "hello" => sender() ! Response(self) @@ -59,8 +64,7 @@ object SupervisionSpec { } -class SupervisionSpec extends AkkaSpec(SupervisionSpec.config) with ImplicitSender { - +class DeprecatedSupervisionSpec extends AkkaSpec(SupervisionSpec.config) with ImplicitSender with WithLogCapturing { import SupervisionSpec._ "Supervision for a sharded actor (deprecated)" must { @@ -98,6 +102,11 @@ class SupervisionSpec extends AkkaSpec(SupervisionSpec.config) with ImplicitSend expectMsgType[Response](20.seconds) } } +} + +class SupervisionSpec extends AkkaSpec(SupervisionSpec.config) with ImplicitSender { + + import SupervisionSpec._ "Supervision for a sharded actor" must { @@ -125,10 +134,16 @@ class SupervisionSpec extends AkkaSpec(SupervisionSpec.config) with ImplicitSend val response = expectMsgType[Response](5.seconds) watch(response.self) + // 1. 
passivation message is passed on from supervisor to shard (which starts buffering messages for the entity id) + // 2. child stops + // 3. the supervisor may or may not yet have gotten the stop message back from the shard + // a. if it has, it will stop immediately, and the next message will trigger the shard to restart it + // b. if it hasn't, the supervisor will back off before restarting the child; when the + // final stop message `StopMessage` comes in from the shard it will stop itself + // 4. when the supervisor stops, the shard should start it anew and deliver the buffered messages region ! Msg(10, "passivate") expectTerminated(response.self) - // This would fail before as sharded actor would be stuck passivating region ! Msg(10, "hello") expectMsgType[Response](20.seconds) } diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala index 07616580ec..ea2404380a 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala @@ -5,7 +5,6 @@ package akka.cluster.sharding.protobuf import scala.concurrent.duration._ - import akka.actor.Address import akka.actor.ExtendedActorSystem import akka.actor.Props @@ -13,6 +12,7 @@ import akka.cluster.sharding.Shard import akka.cluster.sharding.ShardCoordinator import akka.cluster.sharding.ShardRegion import akka.cluster.sharding.ShardRegion.ShardId +import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStore import akka.serialization.SerializationExtension import akka.testkit.AkkaSpec @@ -70,12 +70,12 @@ class ClusterShardingMessageSerializerSpec extends AkkaSpec { } "be able to serialize PersistentShard snapshot state" in { - checkSerialization(Shard.State(Set("e1", "e2", "e3"))) + checkSerialization(EventSourcedRememberEntitiesStore.State(Set("e1", "e2", "e3"))) } "be able to serialize PersistentShard domain events" in { - checkSerialization(Shard.EntityStarted("e1")) - checkSerialization(Shard.EntityStopped("e1")) + checkSerialization(EventSourcedRememberEntitiesStore.EntityStarted("e1")) + checkSerialization(EventSourcedRememberEntitiesStore.EntityStopped("e1")) } "be able to serialize GetShardStats" in {