diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala index aec14d5eaa..7912dcecd9 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala @@ -20,14 +20,12 @@ import akka.actor.Terminated import akka.actor.Timers import akka.annotation.InternalStableApi import akka.cluster.Cluster -import akka.cluster.sharding.internal.EntityRecoveryStrategy import akka.cluster.sharding.internal.RememberEntitiesShardStore import akka.cluster.sharding.internal.RememberEntitiesShardStore.GetEntities import akka.cluster.sharding.internal.RememberEntitiesProvider import akka.cluster.sharding.internal.RememberEntityStarter import akka.coordination.lease.scaladsl.Lease import akka.coordination.lease.scaladsl.LeaseProvider -import akka.dispatch.ExecutionContexts import akka.event.LoggingAdapter import akka.pattern.pipe import akka.util.MessageBufferMap @@ -57,12 +55,6 @@ private[akka] object Shard { */ final case class RestartTerminatedEntity(entity: EntityId) extends RememberEntityCommand - /** - * When initialising a shard with remember entities enabled the following message is used - * to restart batches of entity actors at a time. - */ - final case class RestartRememberedEntities(entity: Set[EntityId]) extends RememberEntityCommand - /** * If the shard id extractor is changed, remembered entities will start in a different shard * and this message is sent to the shard to not leak `entityId -> RememberedButNotStarted` entries @@ -544,34 +536,16 @@ private[akka] class Shard( log.debug("Shard initialized") if (ids.nonEmpty) { entities.alreadyRemembered(ids) - restartRememberedEntities(ids) + log.debug("Restarting set of [{}] entities", ids.size) + context.actorOf( + RememberEntityStarter.props(context.parent, self, shardId, ids, settings), + "RememberEntitiesStarter") } context.parent ! ShardInitialized(shardId) context.become(idle) unstashAll() } - def restartRememberedEntities(ids: Set[EntityId]): Unit = { - log.debug( - "Shard starting [{}] remembered entities using strategy [{}]", - ids.size, - rememberedEntitiesRecoveryStrategy) - // FIXME Separation of concerns: shouldn't this future juggling be part of the RememberEntityStarter actor instead? - rememberedEntitiesRecoveryStrategy.recoverEntities(ids).foreach { scheduledRecovery => - scheduledRecovery - .filter(_.nonEmpty)(ExecutionContexts.parasitic) - .map(RestartRememberedEntities)(ExecutionContexts.parasitic) - .pipeTo(self) - } - } - - def restartEntities(ids: Set[EntityId]): Unit = { - log.debug("Restarting set of [{}] entities", ids.size) - context.actorOf( - RememberEntityStarter.props(context.parent, self, shardId, ids, settings), - "RememberEntitiesStarter") - } - // ===== shard up and running ===== // when not remembering entities, we stay in this state all the time @@ -739,7 +713,6 @@ private[akka] class Shard( s"Unexpected state for [$entityId] when getting RestartTerminatedEntity: [$other]") } - case RestartRememberedEntities(ids) => restartEntities(ids) case EntitiesMovedToOtherShard(movedEntityIds) => log.info( "Clearing [{}] remembered entities started elsewhere because of changed shard id extractor", @@ -1059,15 +1032,4 @@ private[akka] class Shard( log.debug("Shard [{}] shutting down", shardId) } - private def rememberedEntitiesRecoveryStrategy: EntityRecoveryStrategy = { - import settings.tuningParameters._ - entityRecoveryStrategy match { - case "all" => EntityRecoveryStrategy.allStrategy() - case "constant" => - EntityRecoveryStrategy.constantStrategy( - context.system, - entityRecoveryConstantRateStrategyFrequency, - entityRecoveryConstantRateStrategyNumberOfEntities) - } - } } diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala deleted file mode 100644 index b00a756bbc..0000000000 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/EntityRecoveryStrategy.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (C) 2009-2020 Lightbend Inc. - */ - -package akka.cluster.sharding.internal - -import akka.actor.ActorSystem -import akka.annotation.InternalApi -import akka.cluster.sharding.ShardRegion -import akka.util.PrettyDuration - -import scala.collection.immutable.Set -import scala.concurrent.Future -import scala.concurrent.duration.FiniteDuration - -/** - * INTERNAL API - */ -@InternalApi -private[akka] object EntityRecoveryStrategy { - def allStrategy(): EntityRecoveryStrategy = new AllAtOnceEntityRecoveryStrategy() - - def constantStrategy( - actorSystem: ActorSystem, - frequency: FiniteDuration, - numberOfEntities: Int): EntityRecoveryStrategy = - new ConstantRateEntityRecoveryStrategy(actorSystem, frequency, numberOfEntities) -} - -/** - * INTERNAL API - */ -@InternalApi -private[akka] trait EntityRecoveryStrategy { - - import ShardRegion.EntityId - - import scala.concurrent.Future - - def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] -} - -/** - * INTERNAL API - */ -@InternalApi -private[akka] final class AllAtOnceEntityRecoveryStrategy extends EntityRecoveryStrategy { - - import ShardRegion.EntityId - - override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = - if (entities.isEmpty) Set.empty else Set(Future.successful(entities)) - - override def toString: EntityId = "AllAtOnceEntityRecoveryStrategy" -} - -final class ConstantRateEntityRecoveryStrategy( - actorSystem: ActorSystem, - frequency: FiniteDuration, - numberOfEntities: Int) - extends EntityRecoveryStrategy { - - import ShardRegion.EntityId - import actorSystem.dispatcher - import akka.pattern.after - - override def recoverEntities(entities: Set[EntityId]): Set[Future[Set[EntityId]]] = - entities - .grouped(numberOfEntities) - .foldLeft((frequency, Set[Future[Set[EntityId]]]())) { - case ((interval, scheduledEntityIds), entityIds) => - (interval + frequency, scheduledEntityIds + scheduleEntities(interval, entityIds)) - } - ._2 - - private def scheduleEntities(interval: FiniteDuration, entityIds: Set[EntityId]): Future[Set[EntityId]] = - after(interval, actorSystem.scheduler)(Future.successful[Set[EntityId]](entityIds)) - - override def toString: EntityId = - s"ConstantRateEntityRecoveryStrategy(${PrettyDuration.format(frequency)}, $numberOfEntities)" -} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala deleted file mode 100644 index c754a64b89..0000000000 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntitiesStarter.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (C) 2009-2020 Lightbend Inc. - */ - -package akka.cluster.sharding.internal - -import akka.actor.Actor -import akka.actor.ActorLogging -import akka.actor.ActorRef -import akka.actor.NoSerializationVerificationNeeded -import akka.actor.Props -import akka.actor.Timers -import akka.annotation.InternalApi -import akka.cluster.sharding.ClusterShardingSettings -import akka.cluster.sharding.Shard -import akka.cluster.sharding.ShardRegion - -import scala.collection.immutable.Set - -/** - * INTERNAL API - */ -@InternalApi -private[akka] object RememberEntityStarter { - def props( - region: ActorRef, - shard: ActorRef, - shardId: ShardRegion.ShardId, - ids: Set[ShardRegion.EntityId], - settings: ClusterShardingSettings) = - Props(new RememberEntityStarter(region, shard, shardId, ids, settings)) - - private case object Tick extends NoSerializationVerificationNeeded -} - -/** - * INTERNAL API: Actor responsible for starting entities when rememberEntities is enabled - */ -@InternalApi -private[akka] final class RememberEntityStarter( - region: ActorRef, - shard: ActorRef, - shardId: ShardRegion.ShardId, - ids: Set[ShardRegion.EntityId], - settings: ClusterShardingSettings) - extends Actor - with ActorLogging - with Timers { - - import RememberEntityStarter.Tick - - private var waitingForAck = ids - private var entitiesMoved = Set.empty[ShardRegion.ShardId] - - sendStart(ids) - - val tickTask = { - val resendInterval = settings.tuningParameters.retryInterval - timers.startTimerWithFixedDelay(Tick, Tick, resendInterval) - } - - def sendStart(ids: Set[ShardRegion.EntityId]): Unit = { - // these go through the region rather the directly to the shard - // so that shard mapping changes are picked up - ids.foreach(id => region ! ShardRegion.StartEntity(id)) - } - - override def receive: Receive = { - case ShardRegion.StartEntityAck(entityId, ackFromShardId) => - waitingForAck -= entityId - if (shardId != ackFromShardId) entitiesMoved += entityId - if (waitingForAck.isEmpty) { - if (entitiesMoved.nonEmpty) shard ! Shard.EntitiesMovedToOtherShard(ids) - context.stop(self) - } - - case Tick => - sendStart(waitingForAck) - - } -} diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntityStarter.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntityStarter.scala new file mode 100644 index 0000000000..f17749fffe --- /dev/null +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/internal/RememberEntityStarter.scala @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2009-2020 Lightbend Inc. + */ + +package akka.cluster.sharding.internal + +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.ActorRef +import akka.actor.NoSerializationVerificationNeeded +import akka.actor.Props +import akka.actor.Timers +import akka.annotation.InternalApi +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.Shard +import akka.cluster.sharding.ShardRegion +import akka.cluster.sharding.ShardRegion.EntityId +import akka.cluster.sharding.ShardRegion.ShardId + +import scala.collection.immutable.Set +import scala.concurrent.ExecutionContext + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object RememberEntityStarter { + def props( + region: ActorRef, + shard: ActorRef, + shardId: ShardRegion.ShardId, + ids: Set[ShardRegion.EntityId], + settings: ClusterShardingSettings) = + Props(new RememberEntityStarter(region, shard, shardId, ids, settings)) + + private final case class StartBatch(batchSize: Int) extends NoSerializationVerificationNeeded + private case object ResendUnAcked extends NoSerializationVerificationNeeded +} + +/** + * INTERNAL API: Actor responsible for starting entities when rememberEntities is enabled + */ +@InternalApi +private[akka] final class RememberEntityStarter( + region: ActorRef, + shard: ActorRef, + shardId: ShardRegion.ShardId, + ids: Set[ShardRegion.EntityId], + settings: ClusterShardingSettings) + extends Actor + with ActorLogging + with Timers { + + implicit val ec: ExecutionContext = context.dispatcher + import RememberEntityStarter._ + + private var idsLeftToStart = Set.empty[EntityId] + private var waitingForAck = Set.empty[EntityId] + private var entitiesMoved = Set.empty[EntityId] + + log.debug( + "Shard starting [{}] remembered entities using strategy [{}]", + ids.size, + settings.tuningParameters.entityRecoveryStrategy) + + settings.tuningParameters.entityRecoveryStrategy match { + case "all" => + idsLeftToStart = Set.empty + startBatch(ids) + case "constant" => + import settings.tuningParameters + idsLeftToStart = ids + timers.startTimerWithFixedDelay( + "constant", + StartBatch(tuningParameters.entityRecoveryConstantRateStrategyNumberOfEntities), + tuningParameters.entityRecoveryConstantRateStrategyFrequency) + startBatch(tuningParameters.entityRecoveryConstantRateStrategyNumberOfEntities) + } + timers.startTimerWithFixedDelay("retry", ResendUnAcked, settings.tuningParameters.retryInterval) + + override def receive: Receive = { + case StartBatch(batchSize) => startBatch(batchSize) + case ShardRegion.StartEntityAck(entityId, ackFromShardId) => onAck(entityId, ackFromShardId) + case ResendUnAcked => retryUnacked() + } + + private def onAck(entityId: EntityId, ackFromShardId: ShardId): Unit = { + idsLeftToStart -= entityId + waitingForAck -= entityId + if (shardId != ackFromShardId) entitiesMoved += entityId + if (waitingForAck.isEmpty && idsLeftToStart.isEmpty) { + if (entitiesMoved.nonEmpty) { + log.info("Found [{}] entities moved to new shard(s)", entitiesMoved.size) + shard ! Shard.EntitiesMovedToOtherShard(entitiesMoved) + } + context.stop(self) + } + } + + private def startBatch(batchSize: Int): Unit = { + log.debug("Starting batch of [{}] remembered entities", batchSize) + val (batch, newIdsLeftToStart) = idsLeftToStart.splitAt(batchSize) + idsLeftToStart = newIdsLeftToStart + startBatch(batch) + } + + private def startBatch(entityIds: Set[EntityId]): Unit = { + // these go through the region rather the directly to the shard + // so that shard id extractor changes make them start on the right shard + waitingForAck = waitingForAck.union(entityIds) + entityIds.foreach(entityId => region ! ShardRegion.StartEntity(entityId)) + } + + private def retryUnacked(): Unit = { + if (waitingForAck.nonEmpty) { + log.debug("Found [{}] remembered entities waiting for StartEntityAck, retrying", waitingForAck.size) + waitingForAck.foreach { id => + // for now we just retry all (as that was the existing behavior spread out over starter and shard) + // but in the future it could perhaps make sense to batch also the retries to avoid thundering herd + region ! ShardRegion.StartEntity(id) + } + } + } + +} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala deleted file mode 100644 index c3bc9c804b..0000000000 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/AllAtOnceEntityRecoveryStrategySpec.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (C) 2018-2020 Lightbend Inc. - */ - -package akka.cluster.sharding - -import akka.cluster.sharding.ShardRegion.EntityId -import akka.cluster.sharding.internal.EntityRecoveryStrategy -import akka.testkit.AkkaSpec - -class AllAtOnceEntityRecoveryStrategySpec extends AkkaSpec { - val strategy = EntityRecoveryStrategy.allStrategy() - - "AllAtOnceEntityRecoveryStrategy" must { - "recover entities" in { - val entities = Set[EntityId]("1", "2", "3", "4", "5") - val result = strategy.recoverEntities(entities) - result.size should ===(1) - // the Future is completed immediately for allStrategy - result.head.value.get.get should ===(entities) - } - - "not recover when no entities to recover" in { - val result = strategy.recoverEntities(Set[EntityId]()) - result.size should ===(0) - } - } -} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala deleted file mode 100644 index 8083a5d010..0000000000 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (C) 2018-2020 Lightbend Inc. - */ - -package akka.cluster.sharding - -import scala.concurrent.{ Await, Future } -import scala.concurrent.duration._ - -import akka.cluster.sharding.internal.EntityRecoveryStrategy -import akka.cluster.sharding.ShardRegion.EntityId -import akka.testkit.{ AkkaSpec, TimingTest } - -class ConstantRateEntityRecoveryStrategySpec extends AkkaSpec { - - val strategy = EntityRecoveryStrategy.constantStrategy(system, 1.second, 2) - "ConstantRateEntityRecoveryStrategy" must { - "recover entities" taggedAs TimingTest in { - import system.dispatcher - val entities = Set[EntityId]("1", "2", "3", "4", "5") - val startTime = System.nanoTime() - val resultWithTimes = - strategy.recoverEntities(entities).map(_.map(entityIds => entityIds -> (System.nanoTime() - startTime).nanos)) - - val result = - Await.result(Future.sequence(resultWithTimes), 6.seconds).toVector.sortBy { case (_, duration) => duration } - result.size should ===(3) - - val scheduledEntities = result.map(_._1) - scheduledEntities(0).size should ===(2) - scheduledEntities(1).size should ===(2) - scheduledEntities(2).size should ===(1) - scheduledEntities.flatten.toSet should ===(entities) - - val timesMillis = result.map(_._2.toMillis) - - // scheduling will not happen too early - timesMillis(0) should ===(1400L +- 500) - timesMillis(1) should ===(2400L +- 500L) - timesMillis(2) should ===(3400L +- 500L) - } - - "not recover when no entities to recover" in { - val result = strategy.recoverEntities(Set[EntityId]()) - result.size should ===(0) - } - } -} diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/internal/RememberEntitiesStarterSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/internal/RememberEntitiesStarterSpec.scala new file mode 100644 index 0000000000..7637fc1b78 --- /dev/null +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/internal/RememberEntitiesStarterSpec.scala @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2018-2020 Lightbend Inc. + */ + +package akka.cluster.sharding.internal + +import akka.cluster.sharding.ClusterShardingSettings +import akka.cluster.sharding.Shard +import akka.cluster.sharding.ShardRegion +import akka.cluster.sharding.ShardRegion.ShardId +import akka.testkit.AkkaSpec +import akka.testkit.TestProbe +import com.typesafe.config.ConfigFactory + +import scala.concurrent.duration._ + +class RememberEntitiesStarterSpec extends AkkaSpec { + + var shardIdCounter = 1 + def nextShardId(): ShardId = { + val id = s"ShardId$shardIdCounter" + shardIdCounter += 1 + id + } + + "The RememberEntitiesStarter" must { + "try start all entities directly with entity-recovery-strategy = all (default)" in { + val regionProbe = TestProbe() + val shardProbe = TestProbe() + val shardId = nextShardId() + + val defaultSettings = ClusterShardingSettings(system) + + val rememberEntityStarter = system.actorOf( + RememberEntityStarter.props(regionProbe.ref, shardProbe.ref, shardId, Set("1", "2", "3"), defaultSettings)) + + watch(rememberEntityStarter) + val startedEntityIds = (1 to 3).map { _ => + val start = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start.entityId, shardId) + start.entityId + }.toSet + startedEntityIds should ===(Set("1", "2", "3")) + + // the starter should then stop itself, not sending anything more to the shard or region + expectTerminated(rememberEntityStarter) + shardProbe.expectNoMessage() + regionProbe.expectNoMessage() + } + + "retry start all entities with no ack with entity-recovery-strategy = all (default)" in { + val regionProbe = TestProbe() + val shardProbe = TestProbe() + val shardId = nextShardId() + + val customSettings = ClusterShardingSettings( + ConfigFactory + .parseString( + // the restarter somewhat surprisingly uses this for no-ack-retry. Tune it down to speed up test + """ + retry-interval = 1 second + """) + .withFallback(system.settings.config.getConfig("akka.cluster.sharding"))) + + val rememberEntityStarter = system.actorOf( + RememberEntityStarter.props(regionProbe.ref, shardProbe.ref, shardId, Set("1", "2", "3"), customSettings)) + + watch(rememberEntityStarter) + (1 to 3).foreach { _ => + regionProbe.expectMsgType[ShardRegion.StartEntity] + } + val startedOnSecondTry = (1 to 3).map { _ => + val start = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start.entityId, shardId) + start.entityId + }.toSet + startedOnSecondTry should ===(Set("1", "2", "3")) + + // should stop itself, not sending anything to the shard + expectTerminated(rememberEntityStarter) + shardProbe.expectNoMessage() + } + + "inform the shard when entities has been reallocated to different shard id" in { + val regionProbe = TestProbe() + val shardProbe = TestProbe() + val shardId = nextShardId() + + val customSettings = ClusterShardingSettings( + ConfigFactory + .parseString( + // the restarter somewhat surprisingly uses this for no-ack-retry. Tune it down to speed up test + """ + retry-interval = 1 second + """) + .withFallback(system.settings.config.getConfig("akka.cluster.sharding"))) + + val rememberEntityStarter = system.actorOf( + RememberEntityStarter.props(regionProbe.ref, shardProbe.ref, shardId, Set("1", "2", "3"), customSettings)) + + watch(rememberEntityStarter) + val start1 = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start1.entityId, shardId) // keep on current shard + + val start2 = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start2.entityId, shardId = "Relocated1") + + val start3 = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start3.entityId, shardId = "Relocated2") + + shardProbe.expectMsg(Shard.EntitiesMovedToOtherShard(Set("2", "3"))) + expectTerminated(rememberEntityStarter) + } + + "try start all entities in a throttled way with entity-recovery-strategy = constant" in { + val regionProbe = TestProbe() + val shardProbe = TestProbe() + val shardId = nextShardId() + + val customSettings = ClusterShardingSettings( + ConfigFactory + .parseString( + // slow constant restart + """ + entity-recovery-strategy = constant + entity-recovery-constant-rate-strategy { + frequency = 2 s + number-of-entities = 2 + } + retry-interval = 1 second + """) + .withFallback(system.settings.config.getConfig("akka.cluster.sharding"))) + + val rememberEntityStarter = system.actorOf( + RememberEntityStarter + .props(regionProbe.ref, shardProbe.ref, shardId, Set("1", "2", "3", "4", "5"), customSettings)) + + def recieveStartAndAck() = { + val start = regionProbe.expectMsgType[ShardRegion.StartEntity] + regionProbe.lastSender ! ShardRegion.StartEntityAck(start.entityId, shardId) + } + + watch(rememberEntityStarter) + // first batch should be immediate + recieveStartAndAck() + recieveStartAndAck() + // second batch holding off (with some room for unstable test env) + regionProbe.expectNoMessage(600.millis) + + // second batch should be immediate + recieveStartAndAck() + recieveStartAndAck() + // third batch holding off + regionProbe.expectNoMessage(600.millis) + + recieveStartAndAck() + + // the starter should then stop itself, not sending anything more to the shard or region + expectTerminated(rememberEntityStarter) + shardProbe.expectNoMessage() + regionProbe.expectNoMessage() + } + + } +}