diff --git a/akka-actor-tests/src/test/scala/akka/event/EventStreamSpec.scala b/akka-actor-tests/src/test/scala/akka/event/EventStreamSpec.scala index 5447a1eb74..745f4ca2b8 100644 --- a/akka-actor-tests/src/test/scala/akka/event/EventStreamSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/event/EventStreamSpec.scala @@ -56,9 +56,11 @@ object EventStreamSpec { trait T trait AT extends T + trait ATT extends AT trait BT extends T - class TA - class TAATBT extends TA with AT with BT + trait BTT extends BT + class CC + class CCATBT extends CC with ATT with BTT } @org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) @@ -151,75 +153,75 @@ class EventStreamSpec extends AkkaSpec(EventStreamSpec.config) { "manage sub-channels using classes and traits (update on subscribe)" in { val es = new EventStream(false) - val tm1 = new TA - val tm2 = new TAATBT + val tm1 = new CC + val tm2 = new CCATBT val a1, a2, a3, a4 = TestProbe() es.subscribe(a1.ref, classOf[AT]) must be === true es.subscribe(a2.ref, classOf[BT]) must be === true - es.subscribe(a3.ref, classOf[TA]) must be === true - es.subscribe(a4.ref, classOf[TAATBT]) must be === true + es.subscribe(a3.ref, classOf[CC]) must be === true + es.subscribe(a4.ref, classOf[CCATBT]) must be === true es.publish(tm1) es.publish(tm2) a1.expectMsgType[AT] must be === tm2 a2.expectMsgType[BT] must be === tm2 - a3.expectMsgType[TA] must be === tm1 - a3.expectMsgType[TA] must be === tm2 - a4.expectMsgType[TAATBT] must be === tm2 + a3.expectMsgType[CC] must be === tm1 + a3.expectMsgType[CC] must be === tm2 + a4.expectMsgType[CCATBT] must be === tm2 es.unsubscribe(a1.ref, classOf[AT]) must be === true es.unsubscribe(a2.ref, classOf[BT]) must be === true - es.unsubscribe(a3.ref, classOf[TA]) must be === true - es.unsubscribe(a4.ref, classOf[TAATBT]) must be === true + es.unsubscribe(a3.ref, classOf[CC]) must be === true + es.unsubscribe(a4.ref, classOf[CCATBT]) must be === true } "manage sub-channels using classes and traits (update on unsubscribe)" in { val es = new EventStream(false) - val tm1 = new TA - val tm2 = new TAATBT + val tm1 = new CC + val tm2 = new CCATBT val a1, a2, a3, a4 = TestProbe() es.subscribe(a1.ref, classOf[AT]) must be === true es.subscribe(a2.ref, classOf[BT]) must be === true - es.subscribe(a3.ref, classOf[TA]) must be === true - es.subscribe(a4.ref, classOf[TAATBT]) must be === true - es.unsubscribe(a3.ref, classOf[TA]) must be === true + es.subscribe(a3.ref, classOf[CC]) must be === true + es.subscribe(a4.ref, classOf[CCATBT]) must be === true + es.unsubscribe(a3.ref, classOf[CC]) must be === true es.publish(tm1) es.publish(tm2) a1.expectMsgType[AT] must be === tm2 a2.expectMsgType[BT] must be === tm2 a3.expectNoMsg(1 second) - a4.expectMsgType[TAATBT] must be === tm2 + a4.expectMsgType[CCATBT] must be === tm2 es.unsubscribe(a1.ref, classOf[AT]) must be === true es.unsubscribe(a2.ref, classOf[BT]) must be === true - es.unsubscribe(a4.ref, classOf[TAATBT]) must be === true + es.unsubscribe(a4.ref, classOf[CCATBT]) must be === true } "manage sub-channels using classes and traits (update on unsubscribe all)" in { val es = new EventStream(false) - val tm1 = new TA - val tm2 = new TAATBT + val tm1 = new CC + val tm2 = new CCATBT val a1, a2, a3, a4 = TestProbe() es.subscribe(a1.ref, classOf[AT]) must be === true es.subscribe(a2.ref, classOf[BT]) must be === true - es.subscribe(a3.ref, classOf[TA]) must be === true - es.subscribe(a4.ref, classOf[TAATBT]) must be === true + es.subscribe(a3.ref, classOf[CC]) must be === true + es.subscribe(a4.ref, classOf[CCATBT]) must be === true es.unsubscribe(a3.ref) es.publish(tm1) es.publish(tm2) a1.expectMsgType[AT] must be === tm2 a2.expectMsgType[BT] must be === tm2 a3.expectNoMsg(1 second) - a4.expectMsgType[TAATBT] must be === tm2 + a4.expectMsgType[CCATBT] must be === tm2 es.unsubscribe(a1.ref, classOf[AT]) must be === true es.unsubscribe(a2.ref, classOf[BT]) must be === true - es.unsubscribe(a4.ref, classOf[TAATBT]) must be === true + es.unsubscribe(a4.ref, classOf[CCATBT]) must be === true } "manage sub-channels using classes and traits (update on publish)" in { val es = new EventStream(false) - val tm1 = new TA - val tm2 = new TAATBT + val tm1 = new CC + val tm2 = new CCATBT val a1, a2 = TestProbe() es.subscribe(a1.ref, classOf[AT]) must be === true @@ -232,6 +234,44 @@ class EventStreamSpec extends AkkaSpec(EventStreamSpec.config) { es.unsubscribe(a2.ref, classOf[BT]) must be === true } + "manage sub-channels using classes and traits (unsubscribe classes used with trait)" in { + val es = new EventStream(false) + val tm1 = new CC + val tm2 = new CCATBT + val a1, a2, a3 = TestProbe() + + es.subscribe(a1.ref, classOf[AT]) must be === true + es.subscribe(a2.ref, classOf[BT]) must be === true + es.subscribe(a2.ref, classOf[CC]) must be === true + es.subscribe(a3.ref, classOf[CC]) must be === true + es.unsubscribe(a2.ref, classOf[CC]) must be === true + es.unsubscribe(a3.ref, classOf[CCATBT]) must be === true + es.publish(tm1) + es.publish(tm2) + a1.expectMsgType[AT] must be === tm2 + a2.expectMsgType[BT] must be === tm2 + a3.expectMsgType[CC] must be === tm1 + es.unsubscribe(a1.ref, classOf[AT]) must be === true + es.unsubscribe(a2.ref, classOf[BT]) must be === true + es.unsubscribe(a3.ref, classOf[CC]) must be === true + } + + "manage sub-channels using classes and traits (subscribe after publish)" in { + val es = new EventStream(false) + val tm1 = new CCATBT + val a1, a2 = TestProbe() + + es.subscribe(a1.ref, classOf[AT]) must be === true + es.publish(tm1) + a1.expectMsgType[AT] must be === tm1 + a2.expectNoMsg(1 second) + es.subscribe(a2.ref, classOf[BTT]) must be === true + es.publish(tm1) + a1.expectMsgType[AT] must be === tm1 + a2.expectMsgType[BTT] must be === tm1 + es.unsubscribe(a1.ref, classOf[AT]) must be === true + es.unsubscribe(a2.ref, classOf[BTT]) must be === true + } } private def verifyLevel(bus: LoggingBus, level: Logging.LogLevel) { diff --git a/akka-actor/src/main/scala/akka/event/EventBus.scala b/akka-actor/src/main/scala/akka/event/EventBus.scala index 63436723bd..cb83fbe806 100644 --- a/akka-actor/src/main/scala/akka/event/EventBus.scala +++ b/akka-actor/src/main/scala/akka/event/EventBus.scala @@ -137,30 +137,22 @@ trait SubchannelClassification { this: EventBus ⇒ def subscribe(subscriber: Subscriber, to: Classifier): Boolean = subscriptions.synchronized { val diff = subscriptions.addValue(to, subscriber) - if (diff.isEmpty) false - else { - cache ++= diff - true - } + addToCache(diff) + diff.nonEmpty } def unsubscribe(subscriber: Subscriber, from: Classifier): Boolean = subscriptions.synchronized { val diff = subscriptions.removeValue(from, subscriber) - if (diff.isEmpty) false - else { - removeFromCache(diff) - true - } + // removeValue(K, V) does not return the diff to remove from or add to the cache + // but instead the whole set of keys and values that should be updated in the cache + cache ++= diff + diff.nonEmpty } def unsubscribe(subscriber: Subscriber): Unit = subscriptions.synchronized { - val diff = subscriptions.removeValue(subscriber) - if (diff.nonEmpty) removeFromCache(diff) + removeFromCache(subscriptions.removeValue(subscriber)) } - private def removeFromCache(changes: Seq[(Classifier, Set[Subscriber])]): Unit = - cache ++= changes map { case (c, s) ⇒ (c, cache.getOrElse(c, Set[Subscriber]()) -- s) } - def publish(event: Event): Unit = { val c = classify(event) val recv = @@ -168,13 +160,22 @@ trait SubchannelClassification { this: EventBus ⇒ else subscriptions.synchronized { if (cache contains c) cache(c) else { - val diff = subscriptions.addKey(c) - cache ++= diff + addToCache(subscriptions.addKey(c)) cache(c) } } recv foreach (publish(event, _)) } + + private def removeFromCache(changes: Seq[(Classifier, Set[Subscriber])]): Unit = + cache = (cache /: changes) { + case (m, (c, cs)) ⇒ m.updated(c, m.getOrElse(c, Set.empty[Subscriber]) -- cs) + } + + private def addToCache(changes: Seq[(Classifier, Set[Subscriber])]): Unit = + cache = (cache /: changes) { + case (m, (c, cs)) ⇒ m.updated(c, m.getOrElse(c, Set.empty[Subscriber]) ++ cs) + } } /** diff --git a/akka-actor/src/main/scala/akka/util/SubclassifiedIndex.scala b/akka-actor/src/main/scala/akka/util/SubclassifiedIndex.scala index 45a70be27c..ae82da6407 100644 --- a/akka-actor/src/main/scala/akka/util/SubclassifiedIndex.scala +++ b/akka-actor/src/main/scala/akka/util/SubclassifiedIndex.scala @@ -17,9 +17,9 @@ trait Subclassification[K] { def isSubclass(x: K, y: K): Boolean } -object SubclassifiedIndex { +private[akka] object SubclassifiedIndex { - class Nonroot[K, V](val key: K, _values: Set[V])(implicit sc: Subclassification[K]) extends SubclassifiedIndex[K, V](_values) { + class Nonroot[K, V](override val root: SubclassifiedIndex[K, V], val key: K, _values: Set[V])(implicit sc: Subclassification[K]) extends SubclassifiedIndex[K, V](_values) { override def innerAddValue(key: K, value: V): Changes = { // break the recursion on super when key is found and transition to recursive add-to-set @@ -34,6 +34,7 @@ object SubclassifiedIndex { } else kids } + // this will return the keys and values to be removed from the cache override def innerRemoveValue(key: K, value: V): Changes = { // break the recursion on super when key is found and transition to recursive remove-from-set if (sc.isEqual(key, this.key)) removeValue(value) else super.innerRemoveValue(key, value) @@ -47,6 +48,9 @@ object SubclassifiedIndex { } else kids } + override def innerFindValues(key: K): Set[V] = + if (sc.isEqual(key, this.key)) values else super.innerFindValues(key) + override def toString = subkeys.mkString("Nonroot(" + key + ", " + values + ",\n", ",\n", ")") } @@ -66,7 +70,7 @@ object SubclassifiedIndex { * cache, e.g. HashMap, is faster than tree traversal which must use linear * scan at each level. Therefore, no value traversals are published. */ -class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: Subclassification[K]) { +private[akka] class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: Subclassification[K]) { import SubclassifiedIndex._ @@ -76,9 +80,13 @@ class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: def this()(implicit sc: Subclassification[K]) = this(Set.empty) + protected val root = this + /** * Add key to this index which inherits its value set from the most specific * super-class which is known. + * + * @return the diff that should be added to the cache */ def addKey(key: K): Changes = mergeChangesByKey(innerAddKey(key)) @@ -94,14 +102,15 @@ class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: } else Nil } if (!found) { - integrate(new Nonroot(key, values)) - Seq((key, values)) + integrate(new Nonroot(root, key, values)) :+ ((key, values)) } else ch } /** * Add value to all keys which are subclasses of the given key. If the key * is not known yet, it is inserted as if using addKey. + * + * @return the diff that should be added to the cache */ def addValue(key: K, value: V): Changes = mergeChangesByKey(innerAddValue(key, value)) @@ -115,17 +124,25 @@ class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: } if (!found) { val v = values + value - val n = new Nonroot(key, v) - integrate(n) - n.innerAddValue(key, value) :+ (key -> v) + val n = new Nonroot(root, key, v) + integrate(n) ++ n.innerAddValue(key, value) :+ (key -> v) } else ch } /** * Remove value from all keys which are subclasses of the given key. - * @return The keys and values that have been removed. + * + * @return the complete changes that should be updated in the cache */ - def removeValue(key: K, value: V): Changes = mergeChangesByKey(innerRemoveValue(key, value)) + def removeValue(key: K, value: V): Changes = + // the reason for not using the values in the returned diff is that we need to + // go through the whole tree to find all values for the "changed" keys in other + // parts of the tree as well, since new nodes might have been created + mergeChangesByKey(innerRemoveValue(key, value)) map { + case (k, _) ⇒ (k, findValues(k)) + } + + // this will return the keys and values to be removed from the cache protected def innerRemoveValue(key: K, value: V): Changes = { var found = false val ch = subkeys flatMap { n ⇒ @@ -135,32 +152,57 @@ class SubclassifiedIndex[K, V] private (private var values: Set[V])(implicit sc: } else Nil } if (!found) { - val n = new Nonroot(key, values) - integrate(n) - n.removeValue(value) + val n = new Nonroot(root, key, values) + integrate(n) ++ n.removeValue(value) } else ch } /** * Remove value from all keys in the index. + * + * @return the diff that should be removed from the cache */ def removeValue(value: V): Changes = mergeChangesByKey(subkeys flatMap (_ removeValue value)) + /** + * Find all values for a given key in the index. + */ + protected final def findValues(key: K): Set[V] = root.innerFindValues(key) + protected def innerFindValues(key: K): Set[V] = + (Set.empty[V] /: subkeys) { (s, n) ⇒ + if (sc.isSubclass(key, n.key)) + s ++ n.innerFindValues(key) + else + s + } + + /** + * Find all subkeys of a given key in the index excluding some subkeys. + */ + protected final def findSubKeysExcept(key: K, except: Vector[Nonroot[K, V]]): Set[K] = root.innerFindSubKeys(key, except) + protected def innerFindSubKeys(key: K, except: Vector[Nonroot[K, V]]): Set[K] = + (Set.empty[K] /: subkeys) { (s, n) ⇒ + if (sc.isEqual(key, n.key)) s + else n.innerFindSubKeys(key, except) ++ { + if (sc.isSubclass(n.key, key) && !except.exists(e ⇒ sc.isEqual(key, e.key))) + s + n.key + else + s + } + } + override def toString = subkeys.mkString("SubclassifiedIndex(" + values + ",\n", ",\n", ")") /** * Add new Nonroot below this node and check all existing nodes for subclass relationship. - * - * @return true if and only if the new node has received subkeys during this operation. + * Also needs to find subkeys in other parts of the tree to compensate for multiple inheritance. */ - private def integrate(n: Nonroot[K, V]) { + private def integrate(n: Nonroot[K, V]): Changes = { val (subsub, sub) = subkeys partition (k ⇒ sc.isSubclass(k.key, n.key)) - if (sub.size == subkeys.size) { - subkeys :+= n - } else { - n.subkeys = subsub - subkeys = sub :+ n - } + subkeys = sub :+ n + n.subkeys = if (subsub.nonEmpty) subsub else n.subkeys + n.subkeys ++= findSubKeysExcept(n.key, n.subkeys).map(k ⇒ new Nonroot(root, k, values)) + n.subkeys.map(n ⇒ (n.key, n.values.toSet)) } private def mergeChangesByKey(changes: Changes): Changes =