Optimizations and correctness fixes to the AffinityPool

Notably:
* Reimplementation of ImmutableIntMap for much faster performance
* Benchmark for ImmutableIntMap added to akka-bench-jmh
* Many small performance improvements to the impl of AffinityPool
* Correctness fixes for pool lifecycle management
This commit is contained in:
Viktor Klang 2017-07-20 11:03:08 +02:00 committed by Konrad `ktoso` Malawski
parent aec87a94c4
commit 912a6f33e9
3 changed files with 410 additions and 326 deletions

View file

@ -6,28 +6,107 @@ package akka.dispatch.affinity
import java.lang.invoke.MethodHandles import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType.methodType import java.lang.invoke.MethodType.methodType
import java.util
import java.util.Collections import java.util.Collections
import java.util.concurrent.TimeUnit.MICROSECONDS import java.util.concurrent.TimeUnit.MICROSECONDS
import java.util.concurrent._ import java.util.concurrent._
import java.util.concurrent.atomic.{ AtomicInteger, AtomicReference } import java.util.concurrent.atomic.{ AtomicInteger, AtomicReference }
import java.util.concurrent.locks.{ Lock, LockSupport, ReentrantLock } import java.util.concurrent.locks.LockSupport
import java.lang.Integer.reverseBytes
import akka.dispatch._ import akka.dispatch._
import akka.util.Helpers.Requiring import akka.util.Helpers.Requiring
import com.typesafe.config.Config import com.typesafe.config.Config
import scala.annotation.tailrec import akka.annotation.{ InternalApi, ApiMayChange }
import java.lang.Integer.reverseBytes import akka.event.Logging
import akka.util.{ ImmutableIntMap, OptionVal, ReentrantGuard }
import akka.annotation.InternalApi
import akka.annotation.ApiMayChange
import akka.util.ImmutableIntMap
import akka.util.OptionVal
import scala.annotation.{ tailrec, switch }
import scala.collection.mutable import scala.collection.mutable
import scala.util.control.NonFatal import scala.util.control.NonFatal
@InternalApi
@ApiMayChange
private[affinity] object AffinityPool {
type PoolState = Int
// PoolState: waiting to be initialized
final val Uninitialized = 0
// PoolState: currently in the process of initializing
final val Initializing = 1
// PoolState: accepts new tasks and processes tasks that are enqueued
final val Running = 2
// PoolState: does not accept new tasks, processes tasks that are in the queue
final val ShuttingDown = 3
// PoolState: does not accept new tasks, does not process tasks in queue
final val ShutDown = 4
// PoolState: all threads have been stopped, does not process tasks and does not accept new ones
final val Terminated = 5
// Method handle to JDK9+ onSpinWait method
private val onSpinWaitMethodHandle =
try
OptionVal.Some(MethodHandles.lookup.findStatic(classOf[Thread], "onSpinWait", methodType(classOf[Unit])))
catch {
case NonFatal(_) OptionVal.None
}
type IdleState = Int
// IdleState: Initial state
final val Initial = 0
// IdleState: Spinning
final val Spinning = 1
// IdleState: Yielding
final val Yielding = 2
// IdleState: Parking
final val Parking = 3
// Following are auxiliary class and trait definitions
private final class IdleStrategy(idleCpuLevel: Int) {
private[this] val maxSpins = 1100 * idleCpuLevel - 1000
private[this] val maxYields = 5 * idleCpuLevel
private[this] val minParkPeriodNs = 1
private[this] val maxParkPeriodNs = MICROSECONDS.toNanos(250 - ((80 * (idleCpuLevel - 1)) / 3))
private[this] var state: IdleState = Initial
private[this] var turns = 0L
private[this] var parkPeriodNs = 0L
@inline private[this] final def transitionTo(newState: IdleState): Unit = {
state = newState
turns = 0
}
def idle(): Unit = {
(state: @switch) match {
case Initial
transitionTo(Spinning)
case Spinning
onSpinWaitMethodHandle match {
case OptionVal.Some(m) m.invokeExact()
case OptionVal.None
}
turns += 1
if (turns > maxSpins)
transitionTo(Yielding)
case Yielding
turns += 1
if (turns > maxYields) {
parkPeriodNs = minParkPeriodNs
transitionTo(Parking)
} else Thread.`yield`()
case Parking
LockSupport.parkNanos(parkPeriodNs)
parkPeriodNs = Math.min(parkPeriodNs << 1, maxParkPeriodNs)
}
}
final def reset(): Unit = transitionTo(Initial)
}
private final class BoundedAffinityTaskQueue(capacity: Int) extends AbstractBoundedNodeQueue[Runnable](capacity)
}
/** /**
* An [[ExecutorService]] implementation which pins actor to particular threads * An [[ExecutorService]] implementation which pins actor to particular threads
* and guaranteed that an actor's [[Mailbox]] will e run on the thread it used * and guaranteed that an actor's [[Mailbox]] will e run on the thread it used
@ -39,92 +118,85 @@ import scala.util.control.NonFatal
@InternalApi @InternalApi
@ApiMayChange @ApiMayChange
private[akka] class AffinityPool( private[akka] class AffinityPool(
parallelism: Int, id: String,
affinityGroupSize: Int, parallelism: Int,
tf: ThreadFactory, affinityGroupSize: Int,
idleCpuLevel: Int, threadFactory: ThreadFactory,
fairDistributionThreshold: Int, idleCpuLevel: Int,
rejectionHandler: RejectionHandler) final val fairDistributionThreshold: Int,
rejectionHandler: RejectionHandler)
extends AbstractExecutorService { extends AbstractExecutorService {
if (parallelism <= 0) if (parallelism <= 0)
throw new IllegalArgumentException("Size of pool cannot be less or equal to 0") throw new IllegalArgumentException("Size of pool cannot be less or equal to 0")
import AffinityPool._
// Held while starting/shutting down workers/pool in order to make // Held while starting/shutting down workers/pool in order to make
// the operations linear and enforce atomicity. An example of that would be // the operations linear and enforce atomicity. An example of that would be
// adding a worker. We want the creation of the worker, addition // adding a worker. We want the creation of the worker, addition
// to the set and starting to worker to be an atomic action. Using // to the set and starting to worker to be an atomic action. Using
// a concurrent set would not give us that // a concurrent set would not give us that
private val bookKeepingLock = new ReentrantLock() private val bookKeepingLock = new ReentrantGuard()
// condition used for awaiting termination // condition used for awaiting termination
private val terminationCondition = bookKeepingLock.newCondition() private val terminationCondition = bookKeepingLock.newCondition()
// indicates the current state of the pool // indicates the current state of the pool
@volatile final private var poolState: PoolState = Running @volatile final private var poolState: PoolState = Uninitialized
private final val workQueues = Array.fill(parallelism)(new BoundedTaskQueue(affinityGroupSize)) private final val workQueues = Array.fill(parallelism)(new BoundedAffinityTaskQueue(affinityGroupSize))
private final val workers = mutable.Set[ThreadPoolWorker]() private final val workers = mutable.Set[AffinityPoolWorker]()
// a counter that gets incremented every time a task is queued
private val executionCounter: AtomicInteger = new AtomicInteger(0)
// maps a runnable to an index of a worker queue // maps a runnable to an index of a worker queue
private val runnableToWorkerQueueIndex = new AtomicReference(ImmutableIntMap.empty) private[this] final val hashCache = new AtomicReference(ImmutableIntMap.empty)
private def locked[T](l: Lock)(body: T) = {
l.lock()
try {
body
} finally {
l.unlock()
}
}
private def getQueueForRunnable(command: Runnable) = {
private def getQueueForRunnable(command: Runnable): BoundedAffinityTaskQueue = {
val runnableHash = command.hashCode() val runnableHash = command.hashCode()
def sbhash(i: Int) = reverseBytes(i * 0x9e3775cd) * 0x9e3775cd def indexFor(h: Int): Int =
Math.abs(reverseBytes(h * 0x9e3775cd) * 0x9e3775cd) % parallelism // In memory of Phil Bagwell
def getNext = executionCounter.incrementAndGet() % parallelism
def updateIfAbsentAndGetQueueIndex(
workerQueueIndex: AtomicReference[ImmutableIntMap],
runnableHash: Int, queueIndex: Int): Int = {
@tailrec
def updateIndex(): Unit = {
val prev = workerQueueIndex.get()
if (!runnableToWorkerQueueIndex.compareAndSet(prev, prev.updateIfAbsent(runnableHash, queueIndex))) {
updateIndex()
}
}
updateIndex()
workerQueueIndex.get().get(runnableHash) // can safely call get..
}
val workQueueIndex = val workQueueIndex =
if (fairDistributionThreshold == 0 || runnableToWorkerQueueIndex.get().size > fairDistributionThreshold) if (fairDistributionThreshold == 0)
Math.abs(sbhash(runnableHash)) % parallelism indexFor(runnableHash)
else else {
updateIfAbsentAndGetQueueIndex(runnableToWorkerQueueIndex, runnableHash, getNext) @tailrec
def cacheLookup(prev: ImmutableIntMap, hash: Int): Int = {
val existingIndex = prev.get(runnableHash)
if (existingIndex >= 0) existingIndex
else if (prev.size > fairDistributionThreshold) indexFor(hash)
else {
val index = prev.size % parallelism
if (hashCache.compareAndSet(prev, prev.updated(runnableHash, index)))
index // Successfully added key
else
cacheLookup(hashCache.get(), hash) // Try again
}
}
cacheLookup(hashCache.get(), runnableHash)
}
workQueues(workQueueIndex) workQueues(workQueueIndex)
} }
//fires up initial workers def start(): this.type =
locked(bookKeepingLock) { bookKeepingLock.withGuard {
workQueues.foreach(q addWorker(workers, q)) if (poolState == Uninitialized) {
} poolState = Initializing
workQueues.foreach(q addWorker(workers, q))
private def addWorker(workers: mutable.Set[ThreadPoolWorker], q: BoundedTaskQueue): Unit = { poolState = Running
locked(bookKeepingLock) { }
val worker = new ThreadPoolWorker(q, new IdleStrategy(idleCpuLevel)) this
workers.add(worker)
worker.startWorker()
} }
}
private def tryEnqueue(command: Runnable) = getQueueForRunnable(command).add(command) // WARNING: Only call while holding the bookKeepingLock
private def addWorker(workers: mutable.Set[AffinityPoolWorker], q: BoundedAffinityTaskQueue): Unit = {
val worker = new AffinityPoolWorker(q, new IdleStrategy(idleCpuLevel))
workers.add(worker)
worker.start()
}
/** /**
* Each worker should go through that method while terminating. * Each worker should go through that method while terminating.
@ -140,26 +212,24 @@ private[akka] class AffinityPool(
* own termination * own termination
* *
*/ */
private def onWorkerExit(w: ThreadPoolWorker, abruptTermination: Boolean): Unit = private def onWorkerExit(w: AffinityPoolWorker, abruptTermination: Boolean): Unit =
locked(bookKeepingLock) { bookKeepingLock.withGuard {
workers.remove(w) workers.remove(w)
if (workers.isEmpty && !abruptTermination && poolState >= ShuttingDown) { if (abruptTermination && poolState == Running)
addWorker(workers, w.q)
else if (workers.isEmpty && !abruptTermination && poolState >= ShuttingDown) {
poolState = ShutDown // transition to shutdown and try to transition to termination poolState = ShutDown // transition to shutdown and try to transition to termination
attemptPoolTermination() attemptPoolTermination()
} }
if (abruptTermination && poolState == Running)
addWorker(workers, w.q)
} }
override def execute(command: Runnable): Unit = { override def execute(command: Runnable): Unit = {
if (command == null) val queue = getQueueForRunnable(command) // Will throw NPE if command is null
throw new NullPointerException if (poolState >= ShuttingDown || !queue.add(command))
if (!(poolState == Running && tryEnqueue(command)))
rejectionHandler.reject(command, this) rejectionHandler.reject(command, this)
} }
override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = { override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = {
// recurse until pool is terminated or time out reached // recurse until pool is terminated or time out reached
@tailrec @tailrec
def awaitTermination(nanos: Long): Boolean = { def awaitTermination(nanos: Long): Boolean = {
@ -168,23 +238,21 @@ private[akka] class AffinityPool(
else awaitTermination(terminationCondition.awaitNanos(nanos)) else awaitTermination(terminationCondition.awaitNanos(nanos))
} }
locked(bookKeepingLock) { bookKeepingLock.withGuard {
// need to hold the lock to avoid monitor exception // need to hold the lock to avoid monitor exception
awaitTermination(unit.toNanos(timeout)) awaitTermination(unit.toNanos(timeout))
} }
} }
private def attemptPoolTermination() = // WARNING: Only call while holding the bookKeepingLock
locked(bookKeepingLock) { private def attemptPoolTermination(): Unit =
if (workers.isEmpty && poolState == ShutDown) { if (workers.isEmpty && poolState == ShutDown) {
poolState = Terminated poolState = Terminated
terminationCondition.signalAll() terminationCondition.signalAll()
}
} }
override def shutdownNow(): util.List[Runnable] = override def shutdownNow(): java.util.List[Runnable] =
locked(bookKeepingLock) { bookKeepingLock.withGuard {
poolState = ShutDown poolState = ShutDown
workers.foreach(_.stop()) workers.foreach(_.stop())
attemptPoolTermination() attemptPoolTermination()
@ -193,173 +261,78 @@ private[akka] class AffinityPool(
} }
override def shutdown(): Unit = override def shutdown(): Unit =
locked(bookKeepingLock) { bookKeepingLock.withGuard {
poolState = ShuttingDown poolState = ShuttingDown
// interrupts only idle workers.. so others can process their queues // interrupts only idle workers.. so others can process their queues
workers.foreach(_.stopIfIdle()) workers.foreach(_.stopIfIdle())
attemptPoolTermination() attemptPoolTermination()
} }
override def isShutdown: Boolean = poolState == ShutDown override def isShutdown: Boolean = poolState >= ShutDown
override def isTerminated: Boolean = poolState == Terminated override def isTerminated: Boolean = poolState == Terminated
// Following are auxiliary class and trait definitions override def toString: String =
s"${Logging.simpleName(this)}(id = $id, parallelism = $parallelism, affinityGroupSize = $affinityGroupSize, threadFactory = $threadFactory, idleCpuLevel = $idleCpuLevel, fairDistributionThreshold = $fairDistributionThreshold, rejectionHandler = $rejectionHandler)"
private sealed trait PoolState extends Ordered[PoolState] { private[this] final class AffinityPoolWorker( final val q: BoundedAffinityTaskQueue, final val idleStrategy: IdleStrategy) extends Runnable {
def order: Int final val thread: Thread = threadFactory.newThread(this)
override def compare(that: PoolState): Int = this.order compareTo that.order @volatile private[this] var executing: Boolean = false
}
// accepts new tasks and processes tasks that are enqueued final def start(): Unit =
private case object Running extends PoolState { if (thread eq null) throw new IllegalStateException(s"Was not able to allocate worker thread for ${AffinityPool.this}")
override val order: Int = 0 else thread.start()
}
// does not accept new tasks, processes tasks that are in the queue override final def run(): Unit = {
private case object ShuttingDown extends PoolState { // Returns true if it executed something, false otherwise
override def order: Int = 1 def executeNext(): Boolean = {
} val c = q.poll()
if (c ne null) {
// does not accept new tasks, does not process tasks in queue executing = true
private case object ShutDown extends PoolState { try
override def order: Int = 2 c.run()
} finally
executing = false
// all threads have been stopped, does not process tasks and does not accept new ones idleStrategy.reset()
private case object Terminated extends PoolState { true
override def order: Int = 3 } else {
} idleStrategy.idle() // if not wait for a bit
false
private final class IdleStrategy(val idleCpuLevel: Int) { }
private val maxSpins = 1100 * idleCpuLevel - 1000
private val maxYields = 5 * idleCpuLevel
private val minParkPeriodNs = 1
private val maxParkPeriodNs = MICROSECONDS.toNanos(280 - 30 * idleCpuLevel)
private sealed trait State
private case object NotIdle extends State
private case object Spinning extends State
private case object Yielding extends State
private case object Parking extends State
private var state: State = NotIdle
private var spins = 0L
private var yields = 0L
private var parkPeriodNs = 0L
private val onSpinWaitMethodHandle =
try
OptionVal.Some(MethodHandles.lookup.findStatic(classOf[Thread], "onSpinWait", methodType(classOf[Unit])))
catch {
case NonFatal(_) OptionVal.None
} }
def idle(): Unit = {
state match {
case NotIdle
state = Spinning
spins += 1
case Spinning
onSpinWaitMethodHandle match {
case OptionVal.Some(m) m.invokeExact()
case OptionVal.None
}
spins += 1
if (spins > maxSpins) {
state = Yielding
yields = 0
}
case Yielding
yields += 1
if (yields > maxYields) {
state = Parking
parkPeriodNs = minParkPeriodNs
} else Thread.`yield`()
case Parking
LockSupport.parkNanos(parkPeriodNs)
parkPeriodNs = Math.min(parkPeriodNs << 1, maxParkPeriodNs)
}
}
def reset(): Unit = {
spins = 0
yields = 0
state = NotIdle
}
}
private final class BoundedTaskQueue(capacity: Int) extends AbstractBoundedNodeQueue[Runnable](capacity)
private final class ThreadPoolWorker(val q: BoundedTaskQueue, val idleStrategy: IdleStrategy) extends Runnable {
private sealed trait WorkerState
private case object NotStarted extends WorkerState
private case object InExecution extends WorkerState
private case object Idle extends WorkerState
val thread: Thread = tf.newThread(this)
@volatile private var workerState: WorkerState = NotStarted
def startWorker(): Unit = {
workerState = Idle
thread.start()
}
private def runCommand(command: Runnable) = {
workerState = InExecution
try
command.run()
finally
workerState = Idle
}
override def run(): Unit = {
/** /**
* Determines whether the worker can keep running or not. * We keep running as long as we are Running
* In order to continue polling for tasks three conditions * or we're ShuttingDown but we still have tasks to execute,
* need to be satisfied: * and we're not interrupted.
*
* 1) pool state is less than Shutting down or queue
* is not empty (e.g pool state is ShuttingDown but there are still messages to process)
*
* 2) the thread backing up this worker has not been interrupted
*
* 3) We are not in ShutDown state (in which we should not be processing any enqueued tasks)
*/ */
def shouldKeepRunning = @tailrec def runLoop(): Unit =
(poolState < ShuttingDown || !q.isEmpty) && if (!Thread.interrupted()) {
!Thread.interrupted() && (poolState: @switch) match {
poolState != ShutDown case Uninitialized ()
case Initializing | Running
executeNext()
runLoop()
case ShuttingDown
if (executeNext()) runLoop()
else ()
case ShutDown | Terminated ()
}
}
var abruptTermination = true var abruptTermination = true
try { try {
while (shouldKeepRunning) { runLoop()
val c = q.poll()
if (c ne null) {
runCommand(c)
idleStrategy.reset()
} else // if not wait for a bit
idleStrategy.idle()
}
abruptTermination = false // if we have reached here, our termination is not due to an exception abruptTermination = false // if we have reached here, our termination is not due to an exception
} finally { } finally {
onWorkerExit(this, abruptTermination) onWorkerExit(this, abruptTermination)
} }
} }
def stop() = def stop(): Unit = if (!thread.isInterrupted) thread.interrupt()
if (!thread.isInterrupted && workerState != NotStarted)
thread.interrupt()
def stopIfIdle() = def stopIfIdle(): Unit = if (!executing) stop()
if (workerState == Idle)
stop()
} }
} }
/** /**
@ -370,7 +343,7 @@ private[akka] class AffinityPool(
private[akka] final class AffinityPoolConfigurator(config: Config, prerequisites: DispatcherPrerequisites) private[akka] final class AffinityPoolConfigurator(config: Config, prerequisites: DispatcherPrerequisites)
extends ExecutorServiceConfigurator(config, prerequisites) { extends ExecutorServiceConfigurator(config, prerequisites) {
private final val MaxfairDistributionThreshold = 2048 private final val MaxFairDistributionThreshold = 2048
private val poolSize = ThreadPoolConfig.scaledPoolSize( private val poolSize = ThreadPoolConfig.scaledPoolSize(
config.getInt("parallelism-min"), config.getInt("parallelism-min"),
@ -382,21 +355,21 @@ private[akka] final class AffinityPoolConfigurator(config: Config, prerequisites
1 <= level && level <= 10, "idle-cpu-level must be between 1 and 10") 1 <= level && level <= 10, "idle-cpu-level must be between 1 and 10")
private val fairDistributionThreshold = config.getInt("fair-work-distribution-threshold").requiring(thr private val fairDistributionThreshold = config.getInt("fair-work-distribution-threshold").requiring(thr
0 <= thr && thr <= MaxfairDistributionThreshold, s"idle-cpu-level must be between 1 and $MaxfairDistributionThreshold") 0 <= thr && thr <= MaxFairDistributionThreshold, s"fair-work-distribution-threshold must be between 0 and $MaxFairDistributionThreshold")
private val rejectionHandlerFCQN = config.getString("rejection-handler-factory") private val rejectionHandlerFCQN = config.getString("rejection-handler-factory")
private val rejectionHandlerFactory = prerequisites.dynamicAccess private val rejectionHandlerFactory = prerequisites.dynamicAccess
.createInstanceFor[RejectionHandlerFactory](rejectionHandlerFCQN, Nil).recover({ .createInstanceFor[RejectionHandlerFactory](rejectionHandlerFCQN, Nil).recover({
case exception throw new IllegalArgumentException( case exception throw new IllegalArgumentException(
s"Cannot instantiate RejectionHandlerFactory (rejection-handler-factory = $rejectionHandlerFCQN),make sure it has an accessible empty constructor", s"Cannot instantiate RejectionHandlerFactory (rejection-handler-factory = $rejectionHandlerFCQN), make sure it has an accessible empty constructor",
exception) exception)
}).get }).get
override def createExecutorServiceFactory(id: String, threadFactory: ThreadFactory): ExecutorServiceFactory = override def createExecutorServiceFactory(id: String, threadFactory: ThreadFactory): ExecutorServiceFactory =
new ExecutorServiceFactory { new ExecutorServiceFactory {
override def createExecutorService: ExecutorService = override def createExecutorService: ExecutorService =
new AffinityPool(poolSize, taskQueueSize, threadFactory, idleCpuLevel, fairDistributionThreshold, rejectionHandlerFactory.create()) new AffinityPool(id, poolSize, taskQueueSize, threadFactory, idleCpuLevel, fairDistributionThreshold, rejectionHandlerFactory.create()).start()
} }
} }
@ -416,7 +389,7 @@ trait RejectionHandlerFactory {
private[akka] final class DefaultRejectionHandlerFactory extends RejectionHandlerFactory { private[akka] final class DefaultRejectionHandlerFactory extends RejectionHandlerFactory {
private class DefaultRejectionHandler extends RejectionHandler { private class DefaultRejectionHandler extends RejectionHandler {
override def reject(command: Runnable, service: ExecutorService): Unit = override def reject(command: Runnable, service: ExecutorService): Unit =
throw new RejectedExecutionException(s"Task ${command.toString} rejected from ${service.toString}") throw new RejectedExecutionException(s"Task $command rejected from $service")
} }
override def create(): RejectionHandler = new DefaultRejectionHandler() override def create(): RejectionHandler = new DefaultRejectionHandler()
} }

View file

@ -3,139 +3,138 @@
*/ */
package akka.util package akka.util
import java.util.Arrays import java.util.Arrays
import akka.annotation.InternalApi import akka.annotation.InternalApi
import scala.annotation.tailrec import scala.annotation.tailrec
/** /**
* INTERNAL API * INTERNAL API
*/ */
@InternalApi private[akka] object ImmutableIntMap { @InternalApi private[akka] object ImmutableIntMap {
val empty: ImmutableIntMap = final val empty: ImmutableIntMap = new ImmutableIntMap(Array.emptyIntArray, 0)
new ImmutableIntMap(Array.emptyIntArray, Array.empty)
private final val MaxScanLength = 10
} }
/** /**
* INTERNAL API * INTERNAL API
* Specialized Map for primitive `Int` keys to avoid allocations (boxing). * Specialized Map for primitive `Int` keys and values to avoid allocations (boxing).
* Keys and values are backed by arrays and lookup is performed with binary * Keys and values are encoded consecutively in a single Int array and does copy-on-write with no
* search. It's intended for rather small (<1000) maps. * structural sharing, it's intended for rather small maps (<1000 elements).
*/ */
@InternalApi private[akka] final class ImmutableIntMap private ( @InternalApi private[akka] final class ImmutableIntMap private (private final val kvs: Array[Int], final val size: Int) {
private val keys: Array[Int], private val values: Array[Int]) {
final val size: Int = keys.length private[this] final def indexForKey(key: Int): Int = {
// Custom implementation of binary search since we encode key + value in consecutive indicies.
// We do the binary search on half the size of the array then project to the full size.
// >>> 1 for division by 2: https://research.googleblog.com/2006/06/extra-extra-read-all-about-it-nearly.html
@tailrec def find(lo: Int, hi: Int): Int =
if (lo <= hi) {
val lohi = lo + hi // Since we search in half the array we don't need to div by 2 to find the real index of key
val idx = lohi & ~1 // Since keys are in even slots, we get the key idx from lo+hi by removing the lowest bit if set (odd)
val k = kvs(idx)
if (k == key) idx
else if (k < key) find((lohi >>> 1) + 1, hi)
else /* if (k > key) */ find(lo, (lohi >>> 1) - 1)
} else ~(lo << 1) // same as -((lo*2)+1): Item should be placed, negated to indicate no match
find(0, size - 1)
}
/** /**
* Worst case `O(log n)`, allocation free. * Worst case `O(log n)`, allocation free.
* Will return Int.MinValue if not found, so beware of storing Int.MinValues
*/ */
def get(key: Int): Int = { final def get(key: Int): Int = {
val i = Arrays.binarySearch(keys, key) // same binary search as in `indexforKey` replicated here for performance reasons.
if (i >= 0) values(i) @tailrec def find(lo: Int, hi: Int): Int =
else Int.MinValue // cant use null, cant use OptionVal, other option is to throw an exception... if (lo <= hi) {
val lohi = lo + hi // Since we search in half the array we don't need to div by 2 to find the real index of key
val k = kvs(lohi & ~1) // Since keys are in even slots, we get the key idx from lo+hi by removing the lowest bit if set (odd)
if (k == key) kvs(lohi | 1) // lohi, if odd, already points to the value-index, if even, we set the lowest bit to add 1
else if (k < key) find((lohi >>> 1) + 1, hi)
else /* if (k > key) */ find(lo, (lohi >>> 1) - 1)
} else Int.MinValue
find(0, size - 1)
} }
/** /**
* Worst case `O(log n)`, allocation free. * Worst case `O(log n)`, allocation free.
*/ */
def contains(key: Int): Boolean = { final def contains(key: Int): Boolean = indexForKey(key) >= 0
Arrays.binarySearch(keys, key) >= 0
}
def updateIfAbsent(key: Int, value: Int): ImmutableIntMap = {
if (contains(key))
this
else
updated(key, value)
}
/** /**
* Worst case `O(log n)`, creates new `ImmutableIntMap` * Worst case `O(n)`, creates new `ImmutableIntMap`
* with copies of the internal arrays for the keys and * with the given key and value if that key is not yet present in the map.
* values.
*/ */
def updated(key: Int, value: Int): ImmutableIntMap = { final def updateIfAbsent(key: Int, value: Int): ImmutableIntMap =
if (size == 0) if (size > 0) {
new ImmutableIntMap(Array(key), Array(value)) val i = indexForKey(key)
else { if (i >= 0) this
val i = Arrays.binarySearch(keys, key) else insert(key, value, i)
} else new ImmutableIntMap(Array(key, value), 1)
/**
* Worst case `O(n)`, creates new `ImmutableIntMap`
* with the given key with the given value.
*/
final def updated(key: Int, value: Int): ImmutableIntMap =
if (size > 0) {
val i = indexForKey(key)
if (i >= 0) { if (i >= 0) {
// existing key, replace value val valueIndex = i + 1
val newValues = new Array[Int](values.length) if (kvs(valueIndex) != value)
System.arraycopy(values, 0, newValues, 0, values.length) update(value, valueIndex)
newValues(i) = value else
new ImmutableIntMap(keys, newValues) this // If no change no need to copy anything
} else { } else insert(key, value, i)
// insert the entry at the right position, and keep the arrays sorted } else new ImmutableIntMap(Array(key, value), 1)
val j = -(i + 1)
val newKeys = new Array[Int](size + 1)
System.arraycopy(keys, 0, newKeys, 0, j)
newKeys(j) = key
System.arraycopy(keys, j, newKeys, j + 1, keys.length - j)
val newValues = new Array[Int](size + 1) private[this] final def update(value: Int, valueIndex: Int): ImmutableIntMap = {
System.arraycopy(values, 0, newValues, 0, j) val newKvs = kvs.clone()
newValues(j) = value newKvs(valueIndex) = value
System.arraycopy(values, j, newValues, j + 1, values.length - j) new ImmutableIntMap(newKvs, size)
new ImmutableIntMap(newKeys, newValues)
}
}
} }
def remove(key: Int): ImmutableIntMap = { private[this] final def insert(key: Int, value: Int, index: Int): ImmutableIntMap = {
val i = Arrays.binarySearch(keys, key) val at = ~index // ~n == -(n + 1): insert the entry at the right positionkeep the array sorted
val newKvs = new Array[Int](kvs.length + 2)
System.arraycopy(kvs, 0, newKvs, 0, at)
newKvs(at) = key
newKvs(at + 1) = value
System.arraycopy(kvs, at, newKvs, at + 2, kvs.length - at)
new ImmutableIntMap(newKvs, size + 1)
}
/**
* Worst case `O(n)`, creates new `ImmutableIntMap`
* without the given key.
*/
final def remove(key: Int): ImmutableIntMap = {
val i = indexForKey(key)
if (i >= 0) { if (i >= 0) {
if (size == 1) if (size > 1) {
ImmutableIntMap.empty val newKvs = new Array[Int](kvs.length - 2)
else { System.arraycopy(kvs, 0, newKvs, 0, i)
val newKeys = new Array[Int](size - 1) System.arraycopy(kvs, i + 2, newKvs, i, kvs.length - i - 2)
System.arraycopy(keys, 0, newKeys, 0, i) new ImmutableIntMap(newKvs, size - 1)
System.arraycopy(keys, i + 1, newKeys, i, keys.length - i - 1) } else ImmutableIntMap.empty
} else this
val newValues = new Array[Int](size - 1)
System.arraycopy(values, 0, newValues, 0, i)
System.arraycopy(values, i + 1, newValues, i, values.length - i - 1)
new ImmutableIntMap(newKeys, newValues)
}
} else
this
} }
/** /**
* All keys * All keys
*/ */
def keysIterator: Iterator[Int] = final def keysIterator: Iterator[Int] =
keys.iterator if (size < 1) Iterator.empty
else Iterator.range(0, kvs.length - 1, 2).map(kvs.apply)
override def toString: String = override final def toString: String =
keysIterator.map(key s"$key -> ${get(key)}").mkString("ImmutableIntMap(", ", ", ")") if (size < 1) "ImmutableIntMap()"
else Iterator.range(0, kvs.length - 1, 2).map(i s"${kvs(i)} -> ${kvs(i + 1)}").mkString("ImmutableIntMap(", ", ", ")")
override def hashCode: Int = { override final def hashCode: Int = Arrays.hashCode(kvs)
var result = HashCode.SEED
result = HashCode.hash(result, keys)
result = HashCode.hash(result, values)
result
}
override def equals(obj: Any): Boolean = obj match { override final def equals(obj: Any): Boolean = obj match {
case other: ImmutableIntMap case other: ImmutableIntMap Arrays.equals(kvs, other.kvs)
if (other eq this) true case _ false
else if (size != other.size) false
else if (size == 0 && other.size == 0) true
else {
@tailrec def check(i: Int): Boolean = {
if (i < 0) true
else if (keys(i) == other.keys(i) && values(i) == other.values(i))
check(i - 1) // recur, next elem
else false
}
check(size - 1)
}
case _ false
} }
} }

View file

@ -0,0 +1,112 @@
/**
* Copyright (C) 2014-2017 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.util
import org.openjdk.jmh.annotations._
import java.util.concurrent.TimeUnit
import scala.annotation.tailrec
@State(Scope.Benchmark)
@BenchmarkMode(Array(Mode.Throughput))
@Fork(1)
@Threads(1)
@Warmup(iterations = 10, time = 5, timeUnit = TimeUnit.MICROSECONDS, batchSize = 1)
@Measurement(iterations = 10, time = 15, timeUnit = TimeUnit.MICROSECONDS, batchSize = 1)
class ImmutableIntMapBench {
@tailrec private[this] final def add(n: Int, c: ImmutableIntMap = ImmutableIntMap.empty): ImmutableIntMap =
if (n >= 0) add(n - 1, c.updated(n, n))
else c
@tailrec private[this] final def contains(n: Int, by: Int, to: Int, in: ImmutableIntMap, b: Boolean): Boolean =
if (n <= to) {
val result = in.contains(n)
contains(n + by, by, to, in, result)
} else b
@tailrec private[this] final def get(n: Int, by: Int, to: Int, in: ImmutableIntMap, b: Int): Int =
if (n <= to) {
val result = in.get(n)
get(n + by, by, to, in, result)
} else b
@tailrec private[this] final def hashCode(n: Int, in: ImmutableIntMap, b: Int): Int =
if (n >= 0) {
val result = in.hashCode
hashCode(n - 1, in, result)
} else b
@tailrec private[this] final def updateIfAbsent(n: Int, by: Int, to: Int, in: ImmutableIntMap): ImmutableIntMap =
if (n <= to) updateIfAbsent(n + by, by, to, in.updateIfAbsent(n, n))
else in
@tailrec private[this] final def getKey(iterations: Int, key: Int, from: ImmutableIntMap): ImmutableIntMap = {
if (iterations > 0 && key != Int.MinValue) {
val k = from.get(key)
getKey(iterations - 1, k, from)
} else from
}
val odd1000 = (0 to 1000).iterator.filter(_ % 2 == 1).foldLeft(ImmutableIntMap.empty)((l, i) => l.updated(i, i))
@Benchmark
@OperationsPerInvocation(1)
def add1(): ImmutableIntMap = add(1)
@Benchmark
@OperationsPerInvocation(10)
def add10(): ImmutableIntMap = add(10)
@Benchmark
@OperationsPerInvocation(100)
def add100(): ImmutableIntMap = add(100)
@Benchmark
@OperationsPerInvocation(1000)
def add1000(): ImmutableIntMap = add(1000)
@Benchmark
@OperationsPerInvocation(10000)
def add10000(): ImmutableIntMap = add(10000)
@Benchmark
@OperationsPerInvocation(500)
def contains(): Boolean = contains(n = 1, by = 2, to = odd1000.size, in = odd1000, b = false)
@Benchmark
@OperationsPerInvocation(500)
def notcontains(): Boolean = contains(n = 0, by = 2, to = odd1000.size, in = odd1000, b = false)
@Benchmark
@OperationsPerInvocation(500)
def get(): Int = get(n = 1, by = 2, to = odd1000.size, in = odd1000, b = Int.MinValue)
@Benchmark
@OperationsPerInvocation(500)
def notget(): Int = get(n = 0, by = 2, to = odd1000.size, in = odd1000, b = Int.MinValue)
@Benchmark
@OperationsPerInvocation(500)
def updateNotAbsent(): ImmutableIntMap = updateIfAbsent(n = 1, by = 2, to = odd1000.size, in = odd1000)
@Benchmark
@OperationsPerInvocation(500)
def updateAbsent(): ImmutableIntMap = updateIfAbsent(n = 0, by = 2, to = odd1000.size, in = odd1000)
@Benchmark
@OperationsPerInvocation(10000)
def hashcode(): Int = hashCode(10000, odd1000, 0)
@Benchmark
@OperationsPerInvocation(1000)
def getMidElement(): ImmutableIntMap = getKey(iterations = 1000, key = 249, from = odd1000)
@Benchmark
@OperationsPerInvocation(1000)
def getLoElement(): ImmutableIntMap = getKey(iterations = 1000, key = 1, from = odd1000)
@Benchmark
@OperationsPerInvocation(1000)
def getHiElement(): ImmutableIntMap = getKey(iterations = 1000, key = 999, from = odd1000)
}