From 985510d6acfa3f09a4a08dd5ec142a1ecb3021ed Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Mon, 22 Jun 2015 08:54:42 +0200 Subject: [PATCH] =cls #17261 Use persistent shard only when rememberEntities=true --- .../scala/akka/cluster/sharding/Shard.scala | 182 ++++++++++++------ 1 file changed, 120 insertions(+), 62 deletions(-) diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala index f65be0fe96..7df2ce0c40 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala @@ -12,6 +12,8 @@ import akka.actor.Terminated import akka.cluster.sharding.Shard.ShardCommand import akka.persistence.PersistentActor import akka.persistence.SnapshotOffer +import akka.actor.Actor +import akka.persistence.RecoveryCompleted /** * INTERNAL API @@ -58,6 +60,8 @@ private[akka] object Shard { /** * Factory method for the [[akka.actor.Props]] of the [[Shard]] actor. + * If `settings.rememberEntities` is enabled the `PersistentShard` + * subclass is used, otherwise `Shard`. */ def props(typeName: String, shardId: ShardRegion.ShardId, @@ -65,9 +69,14 @@ private[akka] object Shard { settings: ClusterShardingSettings, extractEntityId: ShardRegion.ExtractEntityId, extractShardId: ShardRegion.ExtractShardId, - handOffStopMessage: Any): Props = - Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)) - .withDeploy(Deploy.local) + handOffStopMessage: Any): Props = { + if (settings.rememberEntities) + Props(new PersistentShard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)) + .withDeploy(Deploy.local) + else + Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)) + .withDeploy(Deploy.local) + } } /** @@ -85,57 +94,31 @@ private[akka] class Shard( settings: ClusterShardingSettings, extractEntityId: ShardRegion.ExtractEntityId, extractShardId: ShardRegion.ExtractShardId, - handOffStopMessage: Any) extends PersistentActor with ActorLogging { + handOffStopMessage: Any) extends Actor with ActorLogging { import ShardRegion.{ handOffStopperProps, EntityId, Msg, Passivate } import ShardCoordinator.Internal.{ HandOff, ShardStopped } import Shard.{ State, RestartEntity, EntityStopped, EntityStarted } import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage import akka.cluster.sharding.ShardRegion.ShardRegionCommand - import akka.persistence.RecoveryCompleted - - import settings.rememberEntities import settings.tuningParameters._ - override def persistenceId = s"/sharding/${typeName}Shard/${shardId}" - - override def journalPluginId: String = settings.journalPluginId - - override def snapshotPluginId: String = settings.snapshotPluginId - var state = State.Empty var idByRef = Map.empty[ActorRef, EntityId] var refById = Map.empty[EntityId, ActorRef] var passivating = Set.empty[ActorRef] var messageBuffers = Map.empty[EntityId, Vector[(Msg, ActorRef)]] - var persistCount = 0 var handOffStopper: Option[ActorRef] = None def totalBufferSize = messageBuffers.foldLeft(0) { (sum, entity) ⇒ sum + entity._2.size } def processChange[A](event: A)(handler: A ⇒ Unit): Unit = - if (rememberEntities) { - saveSnapshotWhenNeeded() - persist(event)(handler) - } else handler(event) + handler(event) - def saveSnapshotWhenNeeded(): Unit = { - persistCount += 1 - if (persistCount % snapshotAfter == 0) { - log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr) - saveSnapshot(state) - } - } + def receive = receiveCommand - override def receiveRecover: Receive = { - case EntityStarted(id) if rememberEntities ⇒ state = state.copy(state.entities + id) - case EntityStopped(id) if rememberEntities ⇒ state = state.copy(state.entities - id) - case SnapshotOffer(_, snapshot: State) ⇒ state = snapshot - case RecoveryCompleted ⇒ state.entities foreach getEntity - } - - override def receiveCommand: Receive = { + def receiveCommand: Receive = { case Terminated(ref) ⇒ receiveTerminated(ref) case msg: CoordinatorMessage ⇒ receiveCoordinatorMessage(msg) case msg: ShardCommand ⇒ receiveShardCommand(msg) @@ -178,25 +161,22 @@ private[akka] class Shard( } def receiveTerminated(ref: ActorRef): Unit = { - if (handOffStopper.exists(_ == ref)) { + if (handOffStopper.exists(_ == ref)) context stop self - } else if (idByRef.contains(ref) && handOffStopper.isEmpty) { - val id = idByRef(ref) - if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) { - //Note; because we're not persisting the EntityStopped, we don't need - // to persist the EntityStarted either. - log.debug("Starting entity [{}] again, there are buffered messages for it", id) - sendMsgBuffer(EntityStarted(id)) - } else { - if (rememberEntities && !passivating.contains(ref)) { - log.debug("Entity [{}] stopped without passivating, will restart after backoff", id) - import context.dispatcher - context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id)) - } else processChange(EntityStopped(id))(passivateCompleted) - } + else if (idByRef.contains(ref) && handOffStopper.isEmpty) + entityTerminated(ref) + } - passivating = passivating - ref + def entityTerminated(ref: ActorRef): Unit = { + val id = idByRef(ref) + if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) { + log.debug("Starting entity [{}] again, there are buffered messages for it", id) + sendMsgBuffer(EntityStarted(id)) + } else { + processChange(EntityStopped(id))(passivateCompleted) } + + passivating = passivating - ref } def passivate(entity: ActorRef, stopMessage: Any): Unit = { @@ -212,7 +192,7 @@ private[akka] class Shard( } } - // EntityStopped persistence handler + // EntityStopped handler def passivateCompleted(event: EntityStopped): Unit = { log.debug("Entity stopped [{}]", event.entityId) @@ -224,7 +204,7 @@ private[akka] class Shard( messageBuffers = messageBuffers - event.entityId } - // EntityStarted persistence handler + // EntityStarted handler def sendMsgBuffer(event: EntityStarted): Unit = { //Get the buffered messages and remove the buffer val messages = messageBuffers.getOrElse(event.entityId, Vector.empty) @@ -265,17 +245,8 @@ private[akka] class Shard( def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = { val name = URLEncoder.encode(id, "utf-8") context.child(name) match { - case Some(actor) ⇒ - actor.tell(payload, snd) - - case None if rememberEntities ⇒ - //Note; we only do this if remembering, otherwise the buffer is an overhead - messageBuffers = messageBuffers.updated(id, Vector((msg, snd))) - saveSnapshotWhenNeeded() - persist(EntityStarted(id))(sendMsgBuffer) - - case None ⇒ - getEntity(id).tell(payload, snd) + case Some(actor) ⇒ actor.tell(payload, snd) + case None ⇒ getEntity(id).tell(payload, snd) } } @@ -292,3 +263,90 @@ private[akka] class Shard( } } } + +/** + * INTERNAL API + * + * This actor creates children entity actors on demand that it is told to be + * responsible for. It is used when `rememberEntities` is enabled. + * + * @see [[ClusterSharding$ ClusterSharding extension]] + */ +private[akka] class PersistentShard( + typeName: String, + shardId: ShardRegion.ShardId, + entityProps: Props, + settings: ClusterShardingSettings, + extractEntityId: ShardRegion.ExtractEntityId, + extractShardId: ShardRegion.ExtractShardId, + handOffStopMessage: Any) extends Shard( + typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage) + with PersistentActor with ActorLogging { + + import ShardRegion.{ EntityId, Msg } + import Shard.{ State, RestartEntity, EntityStopped, EntityStarted } + import settings.tuningParameters._ + + override def persistenceId = s"/sharding/${typeName}Shard/${shardId}" + + override def journalPluginId: String = settings.journalPluginId + + override def snapshotPluginId: String = settings.snapshotPluginId + + var persistCount = 0 + + override def receive = receiveCommand + + override def processChange[A](event: A)(handler: A ⇒ Unit): Unit = { + saveSnapshotWhenNeeded() + persist(event)(handler) + } + + def saveSnapshotWhenNeeded(): Unit = { + persistCount += 1 + if (persistCount % snapshotAfter == 0) { + log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr) + saveSnapshot(state) + } + } + + override def receiveRecover: Receive = { + case EntityStarted(id) ⇒ state = state.copy(state.entities + id) + case EntityStopped(id) ⇒ state = state.copy(state.entities - id) + case SnapshotOffer(_, snapshot: State) ⇒ state = snapshot + case RecoveryCompleted ⇒ state.entities foreach getEntity + } + + override def entityTerminated(ref: ActorRef): Unit = { + val id = idByRef(ref) + if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) { + //Note; because we're not persisting the EntityStopped, we don't need + // to persist the EntityStarted either. + log.debug("Starting entity [{}] again, there are buffered messages for it", id) + sendMsgBuffer(EntityStarted(id)) + } else { + if (!passivating.contains(ref)) { + log.debug("Entity [{}] stopped without passivating, will restart after backoff", id) + import context.dispatcher + context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id)) + } else processChange(EntityStopped(id))(passivateCompleted) + } + + passivating = passivating - ref + } + + override def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = { + val name = URLEncoder.encode(id, "utf-8") + context.child(name) match { + case Some(actor) ⇒ + actor.tell(payload, snd) + + case None ⇒ + //Note; we only do this if remembering, otherwise the buffer is an overhead + messageBuffers = messageBuffers.updated(id, Vector((msg, snd))) + saveSnapshotWhenNeeded() + persist(EntityStarted(id))(sendMsgBuffer) + } + } + +}