diff --git a/akka-actor/src/main/scala/akka/routing/ConsistentHash.scala b/akka-actor/src/main/scala/akka/routing/ConsistentHash.scala index fca0837662..79c31cda33 100644 --- a/akka-actor/src/main/scala/akka/routing/ConsistentHash.scala +++ b/akka-actor/src/main/scala/akka/routing/ConsistentHash.scala @@ -4,7 +4,8 @@ package akka.routing -import scala.collection.immutable.TreeMap +import scala.collection.immutable.SortedMap +import scala.reflect.ClassTag import java.util.Arrays /** @@ -17,16 +18,18 @@ import java.util.Arrays * hash, i.e. make sure it is different for different nodes. * */ -class ConsistentHash[T] private (nodes: Map[Int, T], virtualNodesFactor: Int) { +class ConsistentHash[T: ClassTag] private (nodes: SortedMap[Int, T], virtualNodesFactor: Int) { import ConsistentHash._ if (virtualNodesFactor < 1) throw new IllegalArgumentException("virtualNodesFactor must be >= 1") - // sorted hash values of the nodes - private val (nodeHashRing: Array[Int], nodeRing: Vector[T]) = { - val (nhr: IndexedSeq[Int], nr: IndexedSeq[AnyRef]) = nodes.toArray.sortBy(_._1).unzip - (nhr.toArray, Vector[T]() ++ nr) + // arrays for fast binary search and access + // nodeHashRing is the sorted hash values of the nodes + // nodeRing is the nodes sorted in the same order as nodeHashRing, i.e. same index + private val (nodeHashRing: Array[Int], nodeRing: Array[T]) = { + val (nhr: Seq[Int], nr: Seq[T]) = nodes.toSeq.unzip + (nhr.toArray, nr.toArray) } /** @@ -102,8 +105,8 @@ class ConsistentHash[T] private (nodes: Map[Int, T], virtualNodesFactor: Int) { } object ConsistentHash { - def apply[T](nodes: Iterable[T], virtualNodesFactor: Int) = { - new ConsistentHash(TreeMap.empty[Int, T] ++ + def apply[T: ClassTag](nodes: Iterable[T], virtualNodesFactor: Int): ConsistentHash[T] = { + new ConsistentHash(SortedMap.empty[Int, T] ++ (for (node ← nodes; vnode ← 1 to virtualNodesFactor) yield (nodeHashFor(node, vnode) -> node)), virtualNodesFactor) } @@ -112,9 +115,9 @@ object ConsistentHash { * Factory method to create a ConsistentHash * JAVA API */ - def create[T](nodes: java.lang.Iterable[T], virtualNodesFactor: Int) = { + def create[T](nodes: java.lang.Iterable[T], virtualNodesFactor: Int): ConsistentHash[T] = { import scala.collection.JavaConverters._ - apply(nodes.asScala, virtualNodesFactor) + apply(nodes.asScala, virtualNodesFactor)(ClassTag(classOf[Any].asInstanceOf[Class[T]])) } private def nodeHashFor(node: Any, vnode: Int): Int =