diff --git a/akka-cluster-sharding/src/main/java/akka/cluster/sharding/protobuf/msg/ClusterShardingMessages.java b/akka-cluster-sharding/src/main/java/akka/cluster/sharding/protobuf/msg/ClusterShardingMessages.java
index 9bb12900f9..18170e75d7 100644
--- a/akka-cluster-sharding/src/main/java/akka/cluster/sharding/protobuf/msg/ClusterShardingMessages.java
+++ b/akka-cluster-sharding/src/main/java/akka/cluster/sharding/protobuf/msg/ClusterShardingMessages.java
@@ -5591,6 +5591,574 @@ public final class ClusterShardingMessages {
// @@protoc_insertion_point(class_scope:EntityStopped)
}
+ public interface ShardStatsOrBuilder
+ extends akka.protobuf.MessageOrBuilder {
+
+ // required string shard = 1;
+ /**
+ * required string shard = 1;
+ */
+ boolean hasShard();
+ /**
+ * required string shard = 1;
+ */
+ java.lang.String getShard();
+ /**
+ * required string shard = 1;
+ */
+ akka.protobuf.ByteString
+ getShardBytes();
+
+ // required int32 entityCount = 2;
+ /**
+ * required int32 entityCount = 2;
+ */
+ boolean hasEntityCount();
+ /**
+ * required int32 entityCount = 2;
+ */
+ int getEntityCount();
+ }
+ /**
+ * Protobuf type {@code ShardStats}
+ */
+ public static final class ShardStats extends
+ akka.protobuf.GeneratedMessage
+ implements ShardStatsOrBuilder {
+ // Use ShardStats.newBuilder() to construct.
+    private ShardStats(akka.protobuf.GeneratedMessage.Builder<?> builder) {
+ super(builder);
+ this.unknownFields = builder.getUnknownFields();
+ }
+ private ShardStats(boolean noInit) { this.unknownFields = akka.protobuf.UnknownFieldSet.getDefaultInstance(); }
+
+ private static final ShardStats defaultInstance;
+ public static ShardStats getDefaultInstance() {
+ return defaultInstance;
+ }
+
+ public ShardStats getDefaultInstanceForType() {
+ return defaultInstance;
+ }
+
+ private final akka.protobuf.UnknownFieldSet unknownFields;
+ @java.lang.Override
+ public final akka.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
+ private ShardStats(
+ akka.protobuf.CodedInputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ initFields();
+ int mutable_bitField0_ = 0;
+ akka.protobuf.UnknownFieldSet.Builder unknownFields =
+ akka.protobuf.UnknownFieldSet.newBuilder();
+ try {
+ boolean done = false;
+ while (!done) {
+ int tag = input.readTag();
+ switch (tag) {
+ case 0:
+ done = true;
+ break;
+ default: {
+ if (!parseUnknownField(input, unknownFields,
+ extensionRegistry, tag)) {
+ done = true;
+ }
+ break;
+ }
+ case 10: {
+ bitField0_ |= 0x00000001;
+ shard_ = input.readBytes();
+ break;
+ }
+ case 16: {
+ bitField0_ |= 0x00000002;
+ entityCount_ = input.readInt32();
+ break;
+ }
+ }
+ }
+ } catch (akka.protobuf.InvalidProtocolBufferException e) {
+ throw e.setUnfinishedMessage(this);
+ } catch (java.io.IOException e) {
+ throw new akka.protobuf.InvalidProtocolBufferException(
+ e.getMessage()).setUnfinishedMessage(this);
+ } finally {
+ this.unknownFields = unknownFields.build();
+ makeExtensionsImmutable();
+ }
+ }
+ public static final akka.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.internal_static_ShardStats_descriptor;
+ }
+
+ protected akka.protobuf.GeneratedMessage.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.internal_static_ShardStats_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.class, akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.Builder.class);
+ }
+
+    public static akka.protobuf.Parser<ShardStats> PARSER =
+        new akka.protobuf.AbstractParser<ShardStats>() {
+ public ShardStats parsePartialFrom(
+ akka.protobuf.CodedInputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ return new ShardStats(input, extensionRegistry);
+ }
+ };
+
+ @java.lang.Override
+    public akka.protobuf.Parser<ShardStats> getParserForType() {
+ return PARSER;
+ }
+
+ private int bitField0_;
+ // required string shard = 1;
+ public static final int SHARD_FIELD_NUMBER = 1;
+ private java.lang.Object shard_;
+ /**
+ * required string shard = 1;
+ */
+ public boolean hasShard() {
+ return ((bitField0_ & 0x00000001) == 0x00000001);
+ }
+ /**
+ * required string shard = 1;
+ */
+ public java.lang.String getShard() {
+ java.lang.Object ref = shard_;
+ if (ref instanceof java.lang.String) {
+ return (java.lang.String) ref;
+ } else {
+ akka.protobuf.ByteString bs =
+ (akka.protobuf.ByteString) ref;
+ java.lang.String s = bs.toStringUtf8();
+ if (bs.isValidUtf8()) {
+ shard_ = s;
+ }
+ return s;
+ }
+ }
+ /**
+ * required string shard = 1;
+ */
+ public akka.protobuf.ByteString
+ getShardBytes() {
+ java.lang.Object ref = shard_;
+ if (ref instanceof java.lang.String) {
+ akka.protobuf.ByteString b =
+ akka.protobuf.ByteString.copyFromUtf8(
+ (java.lang.String) ref);
+ shard_ = b;
+ return b;
+ } else {
+ return (akka.protobuf.ByteString) ref;
+ }
+ }
+
+ // required int32 entityCount = 2;
+ public static final int ENTITYCOUNT_FIELD_NUMBER = 2;
+ private int entityCount_;
+ /**
+ * required int32 entityCount = 2;
+ */
+ public boolean hasEntityCount() {
+ return ((bitField0_ & 0x00000002) == 0x00000002);
+ }
+ /**
+ * required int32 entityCount = 2;
+ */
+ public int getEntityCount() {
+ return entityCount_;
+ }
+
+ private void initFields() {
+ shard_ = "";
+ entityCount_ = 0;
+ }
+ private byte memoizedIsInitialized = -1;
+ public final boolean isInitialized() {
+ byte isInitialized = memoizedIsInitialized;
+ if (isInitialized != -1) return isInitialized == 1;
+
+ if (!hasShard()) {
+ memoizedIsInitialized = 0;
+ return false;
+ }
+ if (!hasEntityCount()) {
+ memoizedIsInitialized = 0;
+ return false;
+ }
+ memoizedIsInitialized = 1;
+ return true;
+ }
+
+ public void writeTo(akka.protobuf.CodedOutputStream output)
+ throws java.io.IOException {
+ getSerializedSize();
+ if (((bitField0_ & 0x00000001) == 0x00000001)) {
+ output.writeBytes(1, getShardBytes());
+ }
+ if (((bitField0_ & 0x00000002) == 0x00000002)) {
+ output.writeInt32(2, entityCount_);
+ }
+ getUnknownFields().writeTo(output);
+ }
+
+ private int memoizedSerializedSize = -1;
+ public int getSerializedSize() {
+ int size = memoizedSerializedSize;
+ if (size != -1) return size;
+
+ size = 0;
+ if (((bitField0_ & 0x00000001) == 0x00000001)) {
+ size += akka.protobuf.CodedOutputStream
+ .computeBytesSize(1, getShardBytes());
+ }
+ if (((bitField0_ & 0x00000002) == 0x00000002)) {
+ size += akka.protobuf.CodedOutputStream
+ .computeInt32Size(2, entityCount_);
+ }
+ size += getUnknownFields().getSerializedSize();
+ memoizedSerializedSize = size;
+ return size;
+ }
+
+ private static final long serialVersionUID = 0L;
+ @java.lang.Override
+ protected java.lang.Object writeReplace()
+ throws java.io.ObjectStreamException {
+ return super.writeReplace();
+ }
+
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ akka.protobuf.ByteString data)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ akka.protobuf.ByteString data,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(byte[] data)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ byte[] data,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws akka.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return PARSER.parseFrom(input);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ java.io.InputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return PARSER.parseFrom(input, extensionRegistry);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseDelimitedFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return PARSER.parseDelimitedFrom(input);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseDelimitedFrom(
+ java.io.InputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return PARSER.parseDelimitedFrom(input, extensionRegistry);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ akka.protobuf.CodedInputStream input)
+ throws java.io.IOException {
+ return PARSER.parseFrom(input);
+ }
+ public static akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parseFrom(
+ akka.protobuf.CodedInputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return PARSER.parseFrom(input, extensionRegistry);
+ }
+
+ public static Builder newBuilder() { return Builder.create(); }
+ public Builder newBuilderForType() { return newBuilder(); }
+ public static Builder newBuilder(akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats prototype) {
+ return newBuilder().mergeFrom(prototype);
+ }
+ public Builder toBuilder() { return newBuilder(this); }
+
+ @java.lang.Override
+ protected Builder newBuilderForType(
+ akka.protobuf.GeneratedMessage.BuilderParent parent) {
+ Builder builder = new Builder(parent);
+ return builder;
+ }
+ /**
+ * Protobuf type {@code ShardStats}
+ */
+ public static final class Builder extends
+        akka.protobuf.GeneratedMessage.Builder<Builder>
+ implements akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStatsOrBuilder {
+ public static final akka.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.internal_static_ShardStats_descriptor;
+ }
+
+ protected akka.protobuf.GeneratedMessage.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.internal_static_ShardStats_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.class, akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.Builder.class);
+ }
+
+ // Construct using akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.newBuilder()
+ private Builder() {
+ maybeForceBuilderInitialization();
+ }
+
+ private Builder(
+ akka.protobuf.GeneratedMessage.BuilderParent parent) {
+ super(parent);
+ maybeForceBuilderInitialization();
+ }
+ private void maybeForceBuilderInitialization() {
+ if (akka.protobuf.GeneratedMessage.alwaysUseFieldBuilders) {
+ }
+ }
+ private static Builder create() {
+ return new Builder();
+ }
+
+ public Builder clear() {
+ super.clear();
+ shard_ = "";
+ bitField0_ = (bitField0_ & ~0x00000001);
+ entityCount_ = 0;
+ bitField0_ = (bitField0_ & ~0x00000002);
+ return this;
+ }
+
+ public Builder clone() {
+ return create().mergeFrom(buildPartial());
+ }
+
+ public akka.protobuf.Descriptors.Descriptor
+ getDescriptorForType() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.internal_static_ShardStats_descriptor;
+ }
+
+ public akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats getDefaultInstanceForType() {
+ return akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.getDefaultInstance();
+ }
+
+ public akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats build() {
+ akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats result = buildPartial();
+ if (!result.isInitialized()) {
+ throw newUninitializedMessageException(result);
+ }
+ return result;
+ }
+
+ public akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats buildPartial() {
+ akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats result = new akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats(this);
+ int from_bitField0_ = bitField0_;
+ int to_bitField0_ = 0;
+ if (((from_bitField0_ & 0x00000001) == 0x00000001)) {
+ to_bitField0_ |= 0x00000001;
+ }
+ result.shard_ = shard_;
+ if (((from_bitField0_ & 0x00000002) == 0x00000002)) {
+ to_bitField0_ |= 0x00000002;
+ }
+ result.entityCount_ = entityCount_;
+ result.bitField0_ = to_bitField0_;
+ onBuilt();
+ return result;
+ }
+
+ public Builder mergeFrom(akka.protobuf.Message other) {
+ if (other instanceof akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats) {
+ return mergeFrom((akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats)other);
+ } else {
+ super.mergeFrom(other);
+ return this;
+ }
+ }
+
+ public Builder mergeFrom(akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats other) {
+ if (other == akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats.getDefaultInstance()) return this;
+ if (other.hasShard()) {
+ bitField0_ |= 0x00000001;
+ shard_ = other.shard_;
+ onChanged();
+ }
+ if (other.hasEntityCount()) {
+ setEntityCount(other.getEntityCount());
+ }
+ this.mergeUnknownFields(other.getUnknownFields());
+ return this;
+ }
+
+ public final boolean isInitialized() {
+ if (!hasShard()) {
+
+ return false;
+ }
+ if (!hasEntityCount()) {
+
+ return false;
+ }
+ return true;
+ }
+
+ public Builder mergeFrom(
+ akka.protobuf.CodedInputStream input,
+ akka.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats parsedMessage = null;
+ try {
+ parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);
+ } catch (akka.protobuf.InvalidProtocolBufferException e) {
+ parsedMessage = (akka.cluster.sharding.protobuf.msg.ClusterShardingMessages.ShardStats) e.getUnfinishedMessage();
+ throw e;
+ } finally {
+ if (parsedMessage != null) {
+ mergeFrom(parsedMessage);
+ }
+ }
+ return this;
+ }
+ private int bitField0_;
+
+ // required string shard = 1;
+ private java.lang.Object shard_ = "";
+ /**
+ * required string shard = 1;
+ */
+ public boolean hasShard() {
+ return ((bitField0_ & 0x00000001) == 0x00000001);
+ }
+ /**
+ * required string shard = 1;
+ */
+ public java.lang.String getShard() {
+ java.lang.Object ref = shard_;
+ if (!(ref instanceof java.lang.String)) {
+ java.lang.String s = ((akka.protobuf.ByteString) ref)
+ .toStringUtf8();
+ shard_ = s;
+ return s;
+ } else {
+ return (java.lang.String) ref;
+ }
+ }
+ /**
+ * required string shard = 1;
+ */
+ public akka.protobuf.ByteString
+ getShardBytes() {
+ java.lang.Object ref = shard_;
+ if (ref instanceof String) {
+ akka.protobuf.ByteString b =
+ akka.protobuf.ByteString.copyFromUtf8(
+ (java.lang.String) ref);
+ shard_ = b;
+ return b;
+ } else {
+ return (akka.protobuf.ByteString) ref;
+ }
+ }
+ /**
+ * required string shard = 1;
+ */
+ public Builder setShard(
+ java.lang.String value) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ bitField0_ |= 0x00000001;
+ shard_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * required string shard = 1;
+ */
+ public Builder clearShard() {
+ bitField0_ = (bitField0_ & ~0x00000001);
+ shard_ = getDefaultInstance().getShard();
+ onChanged();
+ return this;
+ }
+ /**
+ * required string shard = 1;
+ */
+ public Builder setShardBytes(
+ akka.protobuf.ByteString value) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ bitField0_ |= 0x00000001;
+ shard_ = value;
+ onChanged();
+ return this;
+ }
+
+ // required int32 entityCount = 2;
+ private int entityCount_ ;
+ /**
+ * required int32 entityCount = 2;
+ */
+ public boolean hasEntityCount() {
+ return ((bitField0_ & 0x00000002) == 0x00000002);
+ }
+ /**
+ * required int32 entityCount = 2;
+ */
+ public int getEntityCount() {
+ return entityCount_;
+ }
+ /**
+ * required int32 entityCount = 2;
+ */
+ public Builder setEntityCount(int value) {
+ bitField0_ |= 0x00000002;
+ entityCount_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * required int32 entityCount = 2;
+ */
+ public Builder clearEntityCount() {
+ bitField0_ = (bitField0_ & ~0x00000002);
+ entityCount_ = 0;
+ onChanged();
+ return this;
+ }
+
+ // @@protoc_insertion_point(builder_scope:ShardStats)
+ }
+
+ static {
+ defaultInstance = new ShardStats(true);
+ defaultInstance.initFields();
+ }
+
+ // @@protoc_insertion_point(class_scope:ShardStats)
+ }
+
private static akka.protobuf.Descriptors.Descriptor
internal_static_CoordinatorState_descriptor;
private static
@@ -5636,6 +6204,11 @@ public final class ClusterShardingMessages {
private static
akka.protobuf.GeneratedMessage.FieldAccessorTable
internal_static_EntityStopped_fieldAccessorTable;
+ private static akka.protobuf.Descriptors.Descriptor
+ internal_static_ShardStats_descriptor;
+ private static
+ akka.protobuf.GeneratedMessage.FieldAccessorTable
+ internal_static_ShardStats_fieldAccessorTable;
public static akka.protobuf.Descriptors.FileDescriptor
getDescriptor() {
@@ -5657,7 +6230,9 @@ public final class ClusterShardingMessages {
"\t\022\016\n\006region\030\002 \002(\t\"\037\n\013EntityState\022\020\n\010enti",
"ties\030\001 \003(\t\"!\n\rEntityStarted\022\020\n\010entityId\030" +
"\001 \002(\t\"!\n\rEntityStopped\022\020\n\010entityId\030\001 \002(\t" +
- "B&\n\"akka.cluster.sharding.protobuf.msgH\001"
+ "\"0\n\nShardStats\022\r\n\005shard\030\001 \002(\t\022\023\n\013entityC" +
+ "ount\030\002 \002(\005B&\n\"akka.cluster.sharding.prot" +
+ "obuf.msgH\001"
};
akka.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner =
new akka.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() {
@@ -5718,6 +6293,12 @@ public final class ClusterShardingMessages {
akka.protobuf.GeneratedMessage.FieldAccessorTable(
internal_static_EntityStopped_descriptor,
new java.lang.String[] { "EntityId", });
+ internal_static_ShardStats_descriptor =
+ getDescriptor().getMessageTypes().get(8);
+ internal_static_ShardStats_fieldAccessorTable = new
+ akka.protobuf.GeneratedMessage.FieldAccessorTable(
+ internal_static_ShardStats_descriptor,
+ new java.lang.String[] { "Shard", "EntityCount", });
return null;
}
};
diff --git a/akka-cluster-sharding/src/main/protobuf/ClusterShardingMessages.proto b/akka-cluster-sharding/src/main/protobuf/ClusterShardingMessages.proto
index a93b60964c..9820e971fb 100644
--- a/akka-cluster-sharding/src/main/protobuf/ClusterShardingMessages.proto
+++ b/akka-cluster-sharding/src/main/protobuf/ClusterShardingMessages.proto
@@ -48,3 +48,7 @@ message EntityStopped {
required string entityId = 1;
}
+message ShardStats {
+ required string shard = 1;
+ required int32 entityCount = 2;
+}
diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala
index ff9a27387a..5adf71fc9c 100644
--- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala
+++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/Shard.scala
@@ -9,7 +9,7 @@ import akka.actor.ActorRef
import akka.actor.Deploy
import akka.actor.Props
import akka.actor.Terminated
-import akka.cluster.sharding.Shard.ShardCommand
+import akka.cluster.sharding.Shard.{ GetCurrentShardState, ShardCommand }
import akka.persistence.PersistentActor
import akka.persistence.SnapshotOffer
import akka.actor.Actor
@@ -40,6 +40,11 @@ private[akka] object Shard {
val entityId: EntityId
}
+ /**
+ * A query for information about the shard
+ */
+ sealed trait ShardQuery
+
/**
* `State` change for starting an entity in this `Shard`
*/
@@ -50,6 +55,14 @@ private[akka] object Shard {
*/
@SerialVersionUID(1L) final case class EntityStopped(entityId: EntityId) extends StateChange
+ @SerialVersionUID(1L) case object GetCurrentShardState extends ShardQuery
+
+ @SerialVersionUID(1L) final case class CurrentShardState(shardId: ShardRegion.ShardId, entityIds: Set[EntityId])
+
+ @SerialVersionUID(1L) case object GetShardStats extends ShardQuery
+
+ @SerialVersionUID(1L) final case class ShardStats(shardId: ShardRegion.ShardId, entityCount: Int)
+
object State {
val Empty = State()
}
@@ -101,6 +114,7 @@ private[akka] class Shard(
import ShardRegion.{ handOffStopperProps, EntityId, Msg, Passivate, ShardInitialized }
import ShardCoordinator.Internal.{ HandOff, ShardStopped }
import Shard.{ State, RestartEntity, EntityStopped, EntityStarted }
+ import Shard.{ ShardQuery, GetCurrentShardState, CurrentShardState, GetShardStats, ShardStats }
import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
import akka.cluster.sharding.ShardRegion.ShardRegionCommand
import settings.tuningParameters._
@@ -129,6 +143,7 @@ private[akka] class Shard(
case msg: CoordinatorMessage ⇒ receiveCoordinatorMessage(msg)
case msg: ShardCommand ⇒ receiveShardCommand(msg)
case msg: ShardRegionCommand ⇒ receiveShardRegionCommand(msg)
+ case msg: ShardQuery ⇒ receiveShardQuery(msg)
case msg if extractEntityId.isDefinedAt(msg) ⇒ deliverMessage(msg, sender())
}
@@ -147,6 +162,11 @@ private[akka] class Shard(
case _ ⇒ unhandled(msg)
}
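+
+  // replies to queries about the entities hosted by this shard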
+ def receiveShardQuery(msg: ShardQuery): Unit = msg match {
+ case GetCurrentShardState ⇒ sender() ! CurrentShardState(shardId, refById.keySet)
+ case GetShardStats ⇒ sender() ! ShardStats(shardId, state.entities.size)
+ }
+
def handOff(replyTo: ActorRef): Unit = handOffStopper match {
case Some(_) ⇒ log.warning("HandOff shard [{}] received during existing handOff", shardId)
case None ⇒
diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
index 1e7ff93997..abc86a01aa 100644
--- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
+++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
@@ -3,6 +3,8 @@
*/
package akka.cluster.sharding
+import akka.util.Timeout
+
import scala.collection.immutable
import scala.concurrent.Future
import scala.concurrent.duration._
@@ -17,7 +19,7 @@ import akka.cluster.ddata.LWWRegister
import akka.cluster.ddata.LWWRegisterKey
import akka.cluster.ddata.Replicator._
import akka.dispatch.ExecutionContexts
-import akka.pattern.pipe
+import akka.pattern.{ AskTimeoutException, pipe }
import akka.persistence._
/**
@@ -529,6 +531,25 @@ abstract class ShardCoordinator(typeName: String, settings: ClusterShardingSetti
case None ⇒
}
+ case ShardRegion.GetClusterShardingStats(waitMax) ⇒
+ import akka.pattern.ask
+ implicit val timeout: Timeout = waitMax
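+      // ask every alive region for its ShardRegionStats, each ask bounded by the requested waitMax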
+ Future.sequence(aliveRegions.map { regionActor ⇒
+ (regionActor ? ShardRegion.GetShardRegionStats).mapTo[ShardRegion.ShardRegionStats]
+ .map(stats ⇒ regionActor -> stats)
+ }).map { allRegionStats ⇒
+ ShardRegion.ClusterShardingStats(allRegionStats.map {
+ case (region, stats) ⇒
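+            // the self region is a local actor ref whose path has no host/port in its
+            // address, so use the cluster's self address for it instead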
+ val address: Address =
+ if (region == self) Cluster(context.system).selfAddress
+ else region.path.address
+
+ address -> stats
+ }.toMap)
+ }.recover {
+ case x: AskTimeoutException ⇒ ShardRegion.ClusterShardingStats(Map.empty)
+ }.pipeTo(sender())
+
case ShardHome(_, _) ⇒
//On rebalance, we send ourselves a GetShardHome message to reallocate a
// shard. This receive handles the "response" from that message. i.e. ignores it.
diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala
index a2b3329a73..d137c7f7aa 100644
--- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala
+++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala
@@ -4,17 +4,11 @@
package akka.cluster.sharding
import java.net.URLEncoder
-import scala.collection.immutable
-import akka.actor.Actor
-import akka.actor.ActorLogging
-import akka.actor.ActorRef
-import akka.actor.ActorSelection
-import akka.actor.Address
-import akka.actor.Deploy
-import akka.actor.PoisonPill
-import akka.actor.Props
-import akka.actor.RootActorPath
-import akka.actor.Terminated
+import akka.pattern.AskTimeoutException
+import akka.util.Timeout
+
+import akka.pattern.{ ask, pipe }
+import akka.actor._
import akka.cluster.Cluster
import akka.cluster.ClusterEvent.ClusterDomainEvent
import akka.cluster.ClusterEvent.CurrentClusterState
@@ -24,6 +18,11 @@ import akka.cluster.ClusterEvent.MemberUp
import akka.cluster.Member
import akka.cluster.MemberStatus
+import scala.collection.immutable
+import scala.concurrent.duration._
+import scala.concurrent.Future
+import scala.reflect.ClassTag
+
/**
* @see [[ClusterSharding$ ClusterSharding extension]]
*/
@@ -167,19 +166,26 @@ object ShardRegion {
*/
def gracefulShutdownInstance = GracefulShutdown
- /*
+ sealed trait ShardRegionQuery
+
+ /**
* Send this message to the `ShardRegion` actor to request for [[CurrentRegions]],
* which contains the addresses of all registered regions.
- * Intended for testing purpose to see when cluster sharding is "ready".
+   * Intended for testing purposes, to see when cluster sharding is "ready" or to monitor
+ * the state of the shard regions.
*/
- @SerialVersionUID(1L) final case object GetCurrentRegions extends ShardRegionCommand
+ @SerialVersionUID(1L) final case object GetCurrentRegions extends ShardRegionQuery
+ /**
+ * Java API:
+ */
def getCurrentRegionsInstance = GetCurrentRegions
/**
* Reply to `GetCurrentRegions`
*/
@SerialVersionUID(1L) final case class CurrentRegions(regions: Set[Address]) {
+
/**
* Java API
*/
@@ -190,6 +196,102 @@ object ShardRegion {
}
+ /**
+ * Send this message to the `ShardRegion` actor to request for [[ClusterShardingStats]],
+ * which contains statistics about the currently running sharded entities in the
+   * entire cluster. If the `timeout` is reached without answers from all shard regions,
+   * the reply will contain an empty map of regions.
+   *
+   * Intended for testing purposes, to see when cluster sharding is "ready" or to monitor
+   * the state of the shard regions.
+ */
+ @SerialVersionUID(1L) case class GetClusterShardingStats(timeout: FiniteDuration) extends ShardRegionQuery
+
+ /**
+ * Reply to [[GetClusterShardingStats]], contains statistics about all the sharding regions
+ * in the cluster.
+ */
+ @SerialVersionUID(1L) final case class ClusterShardingStats(regions: Map[Address, ShardRegionStats]) {
+
+ /**
+ * Java API
+ */
+ def getRegions(): java.util.Map[Address, ShardRegionStats] = {
+ import scala.collection.JavaConverters._
+ regions.asJava
+ }
+ }
+
+ /**
+ * Send this message to the `ShardRegion` actor to request for [[ShardRegionStats]],
+ * which contains statistics about the currently running sharded entities in the
+ * entire region.
+   * Intended for testing purposes, to see when cluster sharding is "ready" or to monitor
+ * the state of the shard regions.
+ *
+ * For the statistics for the entire cluster, see [[GetClusterShardingStats$]].
+ */
+ @SerialVersionUID(1L) case object GetShardRegionStats extends ShardRegionQuery
+
+ /**
+ * Java API:
+ */
+ def getRegionStatsInstance = GetShardRegionStats
+
+ @SerialVersionUID(1L) final case class ShardRegionStats(stats: Map[ShardId, Int]) {
+
+ /**
+ * Java API
+ */
+ def getStats(): java.util.Map[ShardId, Int] = {
+ import scala.collection.JavaConverters._
+ stats.asJava
+ }
+
+ }
+
+ /**
+ * Send this message to a `ShardRegion` actor instance to request a
+ * [[CurrentShardRegionState]] which describes the current state of the region.
+ * The state contains information about what shards are running in this region
+ * and what entities are running on each of those shards.
+ */
+ @SerialVersionUID(1L) case object GetShardRegionState extends ShardRegionQuery
+
+ /**
+ * Java API:
+ */
+ def getShardRegionStateInstance = GetShardRegionState
+
+ /**
+ * Reply to [[GetShardRegionState$]]
+ *
+ * If gathering the shard information times out the set of shards will be empty.
+ */
+ @SerialVersionUID(1L) final case class CurrentShardRegionState(shards: Set[ShardState]) {
+
+ /**
+ * Java API:
+ *
+ * If gathering the shard information times out the set of shards will be empty.
+ */
+ def getShards(): java.util.Set[ShardState] = {
+ import scala.collection.JavaConverters._
+ shards.asJava
+ }
+ }
+
+ @SerialVersionUID(1L) final case class ShardState(shardId: ShardId, entityIds: Set[EntityId]) {
+
+ /**
+ * Java API:
+ */
+ def getEntityIds(): java.util.Set[EntityId] = {
+ import scala.collection.JavaConverters._
+ entityIds.asJava
+ }
+ }
+
private case object Retry extends ShardRegionCommand
/**
@@ -313,6 +415,7 @@ class ShardRegion(
case state: CurrentClusterState ⇒ receiveClusterState(state)
case msg: CoordinatorMessage ⇒ receiveCoordinatorMessage(msg)
case cmd: ShardRegionCommand ⇒ receiveCommand(cmd)
+ case query: ShardRegionQuery ⇒ receiveQuery(query)
case msg if extractEntityId.isDefinedAt(msg) ⇒ deliverMessage(msg, sender())
case msg: RestartShard ⇒ deliverMessage(msg, sender())
}
@@ -419,13 +522,26 @@ class ShardRegion(
gracefulShutdownInProgress = true
sendGracefulShutdownToCoordinator()
+ case _ ⇒ unhandled(cmd)
+ }
+
+ def receiveQuery(query: ShardRegionQuery): Unit = query match {
case GetCurrentRegions ⇒
coordinator match {
case Some(c) ⇒ c.forward(GetCurrentRegions)
case None ⇒ sender() ! CurrentRegions(Set.empty)
}
- case _ ⇒ unhandled(cmd)
+ case GetShardRegionState ⇒
+ replyToRegionStateQuery(sender())
+
+ case GetShardRegionStats ⇒
+ replyToRegionStatsQuery(sender())
+
+ case msg: GetClusterShardingStats ⇒
+      coordinator.fold(sender() ! ClusterShardingStats(Map.empty))(_ forward msg)
+
+ case _ ⇒ unhandled(query)
}
def receiveTerminated(ref: ActorRef): Unit = {
@@ -458,6 +574,33 @@ class ShardRegion(
}
}
+ def replyToRegionStateQuery(ref: ActorRef): Unit = {
+ askAllShards[Shard.CurrentShardState](Shard.GetCurrentShardState).map { shardStates ⇒
+ CurrentShardRegionState(shardStates.map {
+ case (shardId, state) ⇒ ShardRegion.ShardState(shardId, state.entityIds)
+ }.toSet)
+ }.recover {
+ case x: AskTimeoutException ⇒ CurrentShardRegionState(Set.empty)
+ }.pipeTo(ref)
+ }
+
+ def replyToRegionStatsQuery(ref: ActorRef): Unit = {
+ askAllShards[Shard.ShardStats](Shard.GetShardStats).map { shardStats ⇒
+ ShardRegionStats(shardStats.map {
+ case (shardId, stats) ⇒ (shardId, stats.entityCount)
+ }.toMap)
+ }.recover {
+ case x: AskTimeoutException ⇒ ShardRegionStats(Map.empty)
+ }.pipeTo(ref)
+ }
+
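+  // ask every started shard in this region and pair each reply with its ShardId;
+  // each individual ask is bounded by the hardcoded 3 second timeout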
+ def askAllShards[T: ClassTag](msg: Any): Future[Seq[(ShardId, T)]] = {
+ implicit val timeout: Timeout = 3.seconds
+ Future.sequence(shards.toSeq.map {
+ case (shardId, ref) ⇒ (ref ? msg).mapTo[T].map(t ⇒ (shardId, t))
+ })
+ }
+
def register(): Unit = {
coordinatorSelection.foreach(_ ! registrationMessage)
if (shardBuffers.nonEmpty && retryCount >= 5)
diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala
index 2bad30c918..be45ad81eb 100644
--- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala
+++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala
@@ -28,6 +28,7 @@ import akka.protobuf.MessageLite
private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSystem)
extends SerializerWithStringManifest with BaseSerializer {
import ShardCoordinator.Internal._
+ import Shard.{ GetShardStats, ShardStats }
import Shard.{ State ⇒ EntityState, EntityStarted, EntityStopped }
private lazy val serialization = SerializationExtension(system)
@@ -59,6 +60,9 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy
private val EntityStartedManifest = "CB"
private val EntityStoppedManifest = "CD"
+ private val GetShardStatsManifest = "DA"
+ private val ShardStatsManifest = "DB"
+
private val fromBinaryMap = collection.immutable.HashMap[String, Array[Byte] ⇒ AnyRef](
EntityStateManifest -> entityStateFromBinary,
EntityStartedManifest -> entityStartedFromBinary,
@@ -83,7 +87,10 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy
BeginHandOffAckManifest -> { bytes ⇒ BeginHandOffAck(shardIdMessageFromBinary(bytes)) },
HandOffManifest -> { bytes ⇒ HandOff(shardIdMessageFromBinary(bytes)) },
ShardStoppedManifest -> { bytes ⇒ ShardStopped(shardIdMessageFromBinary(bytes)) },
- GracefulShutdownReqManifest -> { bytes ⇒ GracefulShutdownReq(actorRefMessageFromBinary(bytes)) })
+ GracefulShutdownReqManifest -> { bytes ⇒ GracefulShutdownReq(actorRefMessageFromBinary(bytes)) },
+
+ GetShardStatsManifest -> { bytes ⇒ GetShardStats },
+ ShardStatsManifest -> { bytes ⇒ shardStatsFromBinary(bytes) })
override def manifest(obj: AnyRef): String = obj match {
case _: EntityState ⇒ EntityStateManifest
@@ -110,6 +117,9 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy
case _: HandOff ⇒ HandOffManifest
case _: ShardStopped ⇒ ShardStoppedManifest
case _: GracefulShutdownReq ⇒ GracefulShutdownReqManifest
+
+ case GetShardStats ⇒ GetShardStatsManifest
+ case _: ShardStats ⇒ ShardStatsManifest
case _ ⇒
throw new IllegalArgumentException(s"Can't serialize object of type ${obj.getClass} in [${getClass.getName}]")
}
@@ -140,6 +150,10 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy
case m: EntityState ⇒ entityStateToProto(m).toByteArray
case m: EntityStarted ⇒ entityStartedToProto(m).toByteArray
case m: EntityStopped ⇒ entityStoppedToProto(m).toByteArray
+
+ case GetShardStats ⇒ Array.emptyByteArray
+ case m: ShardStats ⇒ shardStatsToProto(m).toByteArray
+
case _ ⇒
throw new IllegalArgumentException(s"Can't serialize object of type ${obj.getClass} in [${getClass.getName}]")
}
@@ -245,6 +259,14 @@ private[akka] class ClusterShardingMessageSerializer(val system: ExtendedActorSy
private def entityStoppedFromBinary(bytes: Array[Byte]): EntityStopped =
EntityStopped(sm.EntityStopped.parseFrom(bytes).getEntityId)
+ private def shardStatsToProto(evt: ShardStats): sm.ShardStats =
+ sm.ShardStats.newBuilder().setShard(evt.shardId).setEntityCount(evt.entityCount).build()
+
+ private def shardStatsFromBinary(bytes: Array[Byte]): ShardStats = {
+ val parsed = sm.ShardStats.parseFrom(bytes)
+ ShardStats(parsed.getShard, parsed.getEntityCount)
+ }
+
private def resolveActorRef(path: String): ActorRef = {
system.provider.resolveActorRef(path)
}
diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStateSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStateSpec.scala
new file mode 100644
index 0000000000..c856c556f8
--- /dev/null
+++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStateSpec.scala
@@ -0,0 +1,185 @@
+/**
+ * Copyright (C) 2009-2015 Typesafe Inc.
+ */
+package akka.cluster.sharding
+
+import akka.actor._
+import akka.cluster.Cluster
+import akka.cluster.ClusterEvent.CurrentClusterState
+import akka.remote.testconductor.RoleName
+import akka.remote.testkit.{ MultiNodeConfig, MultiNodeSpec, STMultiNodeSpec }
+import akka.testkit.TestProbe
+import com.typesafe.config.ConfigFactory
+
+import scala.concurrent.duration._
+
+object ClusterShardingGetStateSpec {
+ case object Stop
+ case class Ping(id: Long)
+ case object Pong
+
+ class ShardedActor extends Actor with ActorLogging {
+ log.info(self.path.toString)
+ def receive = {
+ case Stop ⇒ context.stop(self)
+ case _: Ping ⇒ sender() ! Pong
+ }
+ }
+
+ val extractEntityId: ShardRegion.ExtractEntityId = {
+ case msg @ Ping(id) ⇒ (id.toString, msg)
+ }
+
+ val numberOfShards = 2
+
+ val extractShardId: ShardRegion.ExtractShardId = {
+ case Ping(id) ⇒ (id % numberOfShards).toString
+ }
+
+ val shardTypeName = "Ping"
+}
+
+object ClusterShardingGetStateSpecConfig extends MultiNodeConfig {
+ val controller = role("controller")
+ val first = role("first")
+ val second = role("second")
+
+ commonConfig(ConfigFactory.parseString("""
+ akka.loglevel = INFO
+ akka.actor.provider = "akka.cluster.ClusterActorRefProvider"
+ akka.remote.log-remote-lifecycle-events = off
+ akka.cluster.metrics.enabled = off
+ akka.cluster.auto-down-unreachable-after = 0s
+ akka.cluster.sharding {
+ coordinator-failure-backoff = 3s
+ shard-failure-backoff = 3s
+ state-store-mode = "ddata"
+ }
+ """))
+
+ nodeConfig(first, second)(ConfigFactory.parseString(
+ """akka.cluster.roles=["shard"]"""))
+
+}
+
+class ClusterShardingGetStateSpecMultiJvmNode1 extends ClusterShardingGetStateSpec
+class ClusterShardingGetStateSpecMultiJvmNode2 extends ClusterShardingGetStateSpec
+class ClusterShardingGetStateSpecMultiJvmNode3 extends ClusterShardingGetStateSpec
+
+abstract class ClusterShardingGetStateSpec extends MultiNodeSpec(ClusterShardingGetStateSpecConfig) with STMultiNodeSpec {
+
+ import ClusterShardingGetStateSpec._
+ import ClusterShardingGetStateSpecConfig._
+
+ def initialParticipants = roles.size
+
+ def startShard(): ActorRef = {
+ ClusterSharding(system).start(
+ typeName = shardTypeName,
+ entityProps = Props(new ShardedActor),
+ settings = ClusterShardingSettings(system).withRole("shard"),
+ extractEntityId = extractEntityId,
+ extractShardId = extractShardId)
+ }
+
+ def startProxy(): ActorRef = {
+ ClusterSharding(system).startProxy(
+ typeName = shardTypeName,
+ role = Some("shard"),
+ extractEntityId = extractEntityId,
+ extractShardId = extractShardId)
+ }
+
+ def join(from: RoleName): Unit = {
+ runOn(from) {
+ Cluster(system).join(node(controller).address)
+ }
+ enterBarrier(from.name + "-joined")
+ }
+
+ "Inspecting cluster sharding state" must {
+
+ "join cluster" in {
+ join(controller)
+ join(first)
+ join(second)
+
+      // make sure all nodes have joined
+ awaitAssert {
+ Cluster(system).sendCurrentClusterState(testActor)
+ expectMsgType[CurrentClusterState].members.size === 3
+ }
+
+ runOn(controller) {
+ startProxy()
+ }
+ runOn(first, second) {
+ startShard()
+ }
+
+ enterBarrier("sharding started")
+ }
+
+ "return empty state when no sharded actors has started" in {
+
+ awaitAssert {
+ val probe = TestProbe()
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+ region.tell(ShardRegion.GetCurrentRegions, probe.ref)
+ probe.expectMsgType[ShardRegion.CurrentRegions].regions.size === 0
+ }
+
+ enterBarrier("empty sharding")
+ }
+
+ "trigger sharded actors" in {
+ runOn(controller) {
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+
+ within(10.seconds) {
+ awaitAssert {
+ val pingProbe = TestProbe()
+ // trigger starting of 4 entities
+ (1 to 4).foreach(n ⇒ region.tell(Ping(n), pingProbe.ref))
+ pingProbe.receiveWhile(messages = 4) {
+ case Pong ⇒ ()
+ }
+ }
+ }
+ }
+
+ enterBarrier("sharded actors started")
+
+ }
+
+ "get shard state" in {
+ within(10.seconds) {
+ awaitAssert {
+ val probe = TestProbe()
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+ region.tell(ShardRegion.GetCurrentRegions, probe.ref)
+ val regions = probe.expectMsgType[ShardRegion.CurrentRegions].regions
+ regions.size === 2
+ regions.foreach { region ⇒
+ val path = RootActorPath(region) / "system" / "sharding" / shardTypeName
+
+ system.actorSelection(path).tell(ShardRegion.GetShardRegionState, probe.ref)
+ }
+ val states = probe.receiveWhile(messages = regions.size) {
+ case msg: ShardRegion.CurrentShardRegionState ⇒ msg
+ }
+ val allEntityIds = for {
+ state ← states
+ shard ← state.shards
+ entityId ← shard.entityIds
+ } yield entityId
+
+ allEntityIds.toSet === Set("1", "2", "3", "4")
+ }
+ }
+
+ enterBarrier("done")
+
+ }
+ }
+}
diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStatsSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStatsSpec.scala
new file mode 100644
index 0000000000..988dcf373c
--- /dev/null
+++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGetStatsSpec.scala
@@ -0,0 +1,179 @@
+/**
+ * Copyright (C) 2009-2015 Typesafe Inc.
+ */
+package akka.cluster.sharding
+
+import akka.actor._
+import akka.cluster.Cluster
+import akka.cluster.ClusterEvent.CurrentClusterState
+import akka.remote.testconductor.RoleName
+import akka.remote.testkit.{ MultiNodeConfig, MultiNodeSpec, STMultiNodeSpec }
+import akka.testkit.{ TestProbe, TestDuration }
+import com.typesafe.config.ConfigFactory
+
+import scala.concurrent.duration._
+
+object ClusterShardingGetStatsSpec {
+ case object Stop
+ case class Ping(id: Long)
+ case object Pong
+
+ class ShardedActor extends Actor with ActorLogging {
+ log.info(self.path.toString)
+ def receive = {
+ case Stop ⇒ context.stop(self)
+ case _: Ping ⇒ sender() ! Pong
+ }
+ }
+
+ val extractEntityId: ShardRegion.ExtractEntityId = {
+ case msg @ Ping(id) ⇒ (id.toString, msg)
+ }
+
+ val numberOfShards = 3
+
+ val extractShardId: ShardRegion.ExtractShardId = {
+ case Ping(id) ⇒ (id % numberOfShards).toString
+ }
+
+ val shardTypeName = "Ping"
+}
+
+object ClusterShardingGetStatsSpecConfig extends MultiNodeConfig {
+ val controller = role("controller")
+ val first = role("first")
+ val second = role("second")
+ val third = role("third")
+
+ commonConfig(ConfigFactory.parseString("""
+ akka.loglevel = INFO
+ akka.actor.provider = "akka.cluster.ClusterActorRefProvider"
+ akka.remote.log-remote-lifecycle-events = off
+ akka.cluster.metrics.enabled = off
+ akka.cluster.auto-down-unreachable-after = 0s
+ akka.cluster.sharding {
+ coordinator-failure-backoff = 3s
+ shard-failure-backoff = 3s
+ state-store-mode = "ddata"
+ }
+ """))
+
+ nodeConfig(first, second, third)(ConfigFactory.parseString(
+ """akka.cluster.roles=["shard"]"""))
+
+}
+
+class ClusterShardingGetStatsSpecMultiJvmNode1 extends ClusterShardingGetStatsSpec
+class ClusterShardingGetStatsSpecMultiJvmNode2 extends ClusterShardingGetStatsSpec
+class ClusterShardingGetStatsSpecMultiJvmNode3 extends ClusterShardingGetStatsSpec
+class ClusterShardingGetStatsSpecMultiJvmNode4 extends ClusterShardingGetStatsSpec
+
+abstract class ClusterShardingGetStatsSpec extends MultiNodeSpec(ClusterShardingGetStatsSpecConfig) with STMultiNodeSpec {
+
+ import ClusterShardingGetStatsSpec._
+ import ClusterShardingGetStatsSpecConfig._
+
+ def initialParticipants = roles.size
+
+ def startShard(): ActorRef = {
+ ClusterSharding(system).start(
+ typeName = shardTypeName,
+ entityProps = Props(new ShardedActor),
+ settings = ClusterShardingSettings(system).withRole("shard"),
+ extractEntityId = extractEntityId,
+ extractShardId = extractShardId)
+ }
+
+ def startProxy(): ActorRef = {
+ ClusterSharding(system).startProxy(
+ typeName = shardTypeName,
+ role = Some("shard"),
+ extractEntityId = extractEntityId,
+ extractShardId = extractShardId)
+ }
+
+ def join(from: RoleName): Unit = {
+ runOn(from) {
+ Cluster(system).join(node(controller).address)
+ }
+ enterBarrier(from.name + "-joined")
+ }
+
+ "Inspecting cluster sharding state" must {
+
+ "join cluster" in {
+ join(controller)
+ join(first)
+ join(second)
+ join(third)
+
+      // make sure all nodes have joined
+ awaitAssert {
+ Cluster(system).sendCurrentClusterState(testActor)
+        expectMsgType[CurrentClusterState].members.size === 4
+ }
+
+ runOn(controller) {
+ startProxy()
+ }
+ runOn(first, second, third) {
+ startShard()
+ }
+
+ enterBarrier("sharding started")
+ }
+
+ "return empty state when no sharded actors has started" in {
+
+ within(10.seconds) {
+ awaitAssert {
+ val probe = TestProbe()
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+ region.tell(ShardRegion.GetClusterShardingStats(10.seconds.dilated), probe.ref)
+ val shardStats = probe.expectMsgType[ShardRegion.ClusterShardingStats]
+ shardStats.regions.size shouldEqual 3
+ shardStats.regions.values.map(_.stats.size).sum shouldEqual 0
+ }
+ }
+
+ enterBarrier("empty sharding")
+ }
+
+ "trigger sharded actors" in {
+ runOn(controller) {
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+
+ within(10.seconds) {
+ awaitAssert {
+ val pingProbe = TestProbe()
+            // trigger starting of four entities in two of the three shards,
+            // leaving the third shard without any entities
+ (1 to 6).filterNot(_ % 3 == 0).foreach(n ⇒ region.tell(Ping(n), pingProbe.ref))
+ pingProbe.receiveWhile(messages = 4) {
+ case Pong ⇒ ()
+ }
+ }
+ }
+ }
+
+ enterBarrier("sharded actors started")
+
+ }
+
+ "get shard state" in {
+ within(10.seconds) {
+ awaitAssert {
+ val probe = TestProbe()
+ val region = ClusterSharding(system).shardRegion(shardTypeName)
+ region.tell(ShardRegion.GetClusterShardingStats(10.seconds.dilated), probe.ref)
+ val regions = probe.expectMsgType[ShardRegion.ClusterShardingStats].regions
+ regions.size shouldEqual 3
+ regions.values.flatMap(_.stats.values).sum shouldEqual 4
+ }
+ }
+
+ enterBarrier("done")
+
+ }
+ }
+}
diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala
index b9a9bc226b..d6ee7f85ca 100644
--- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala
+++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializerSpec.scala
@@ -69,5 +69,13 @@ class ClusterShardingMessageSerializerSpec extends AkkaSpec {
checkSerialization(Shard.EntityStarted("e1"))
checkSerialization(Shard.EntityStopped("e1"))
}
+
+ "be able to serializable GetShardStats" in {
+ checkSerialization(Shard.GetShardStats)
+ }
+
+ "be able to serializable ShardStats" in {
+ checkSerialization(Shard.ShardStats("a", 23))
+ }
}
}
diff --git a/akka-docs/rst/java/cluster-sharding.rst b/akka-docs/rst/java/cluster-sharding.rst
index 021396b448..3b70820e27 100644
--- a/akka-docs/rst/java/cluster-sharding.rst
+++ b/akka-docs/rst/java/cluster-sharding.rst
@@ -343,3 +343,18 @@ if needed.
Custom shard allocation strategy can be defined in an optional parameter to
``ClusterSharding.start``. See the API documentation of ``AbstractShardAllocationStrategy`` for details
of how to implement a custom shard allocation strategy.
+
+
+Inspecting cluster sharding state
+---------------------------------
+Two requests to inspect the cluster state are available:
+
+``ShardRegion.getShardRegionStateInstance()`` (the ``ShardRegion.GetShardRegionState`` message) will reply with a ``ShardRegion.CurrentShardRegionState``
+that contains the ``ShardId`` of the shards running in the Region and the ``EntityId`` of the entities that are alive for each of them.
+
+``ShardRegion.GetClusterShardingStats``, constructed with a timeout, will query all the regions in the cluster and reply with a
+``ShardRegion.ClusterShardingStats`` containing the ``ShardId`` of the shards running in each region and a count
+of the entities that are alive in each shard.
+
+The purpose of these messages is testing and monitoring; they are not intended as a way of
+sending messages directly to the individual entities.
\ No newline at end of file
diff --git a/akka-docs/rst/scala/cluster-sharding.rst b/akka-docs/rst/scala/cluster-sharding.rst
index 634f4eaa74..6fa8fede74 100644
--- a/akka-docs/rst/scala/cluster-sharding.rst
+++ b/akka-docs/rst/scala/cluster-sharding.rst
@@ -344,3 +344,18 @@ if needed.
Custom shard allocation strategy can be defined in an optional parameter to
``ClusterSharding.start``. See the API documentation of ``ShardAllocationStrategy`` for details of
how to implement a custom shard allocation strategy.
+
+
+Inspecting cluster sharding state
+---------------------------------
+Two requests to inspect the cluster state are available:
+
+``ShardRegion.GetShardRegionState`` will reply with a ``ShardRegion.CurrentShardRegionState`` that contains
+the ``ShardId`` of the shards running in the Region and the ``EntityId`` of the entities that are alive for each of them.
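+
+A minimal sketch of such a query, assuming a shard type named ``counter`` has been started and
+an ``ActorSystem`` called ``system`` is in scope::
+
+    import scala.concurrent.duration._
+    import akka.pattern.ask
+    import akka.util.Timeout
+
+    implicit val timeout: Timeout = 5.seconds
+    val region = ClusterSharding(system).shardRegion("counter")
+    // the reply contains one ShardState per shard running in this region
+    val regionState = (region ? ShardRegion.GetShardRegionState)
+      .mapTo[ShardRegion.CurrentShardRegionState]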
+
+``ShardRegion.GetClusterShardingStats`` will query all the regions in the cluster and reply with a
+``ShardRegion.ClusterShardingStats`` containing the ``ShardId`` of the shards running in each region and a count
+of the entities that are alive in each shard.
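+
+A similar sketch for the cluster-wide statistics, with the same assumptions::
+
+    import scala.concurrent.duration._
+    import akka.pattern.ask
+    import akka.util.Timeout
+
+    implicit val timeout: Timeout = 10.seconds
+    val region = ClusterSharding(system).shardRegion("counter")
+    // the reply maps the address of each region to the entity count per ShardId in that region
+    val clusterStats = (region ? ShardRegion.GetClusterShardingStats(10.seconds))
+      .mapTo[ShardRegion.ClusterShardingStats]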
+
+The purpose of these messages is testing and monitoring; they are not intended as a way of
+sending messages directly to the individual entities.
\ No newline at end of file