Merge pull request #1630 from gakesson/master

Improved semantics for BoundedBlockingQueue
This commit is contained in:
Patrik Nordwall 2013-09-11 23:15:06 -07:00
commit 3fa2aeae8d

View file

@ -10,7 +10,7 @@ import java.util.{ AbstractQueue, Queue, Collection, Iterator }
import annotation.tailrec
/**
* BoundedBlockingQueue wraps any Queue and turns the result into a BlockingQueue with a limited capacity
* BoundedBlockingQueue wraps any Queue and turns the result into a BlockingQueue with a limited capacity.
* @param maxCapacity - the maximum capacity of this Queue, needs to be > 0
* @param backing - the backing Queue
* @tparam E - The type of the contents of this Queue
@ -36,24 +36,36 @@ class BoundedBlockingQueue[E <: AnyRef](
def put(e: E) { //Blocks until not full
if (e eq null) throw new NullPointerException
lock.lock()
lock.lockInterruptibly()
try {
while (backing.size() == maxCapacity)
notFull.await()
require(backing.offer(e))
notEmpty.signal()
@tailrec def putElement() {
if (backing.size() < maxCapacity) {
require(backing.offer(e))
notEmpty.signal()
} else {
notFull.await()
putElement()
}
}
putElement()
} finally lock.unlock()
}
def take(): E = { //Blocks until not empty
lock.lockInterruptibly()
try {
while (backing.size() == 0)
notEmpty.await()
val e = backing.poll()
require(e ne null)
notFull.signal()
e
@tailrec def takeElement(): E = {
if (!backing.isEmpty()) {
val e = backing.poll()
require(e ne null)
notFull.signal()
e
} else {
notEmpty.await()
takeElement()
}
}
takeElement()
} finally lock.unlock()
}
@ -72,50 +84,34 @@ class BoundedBlockingQueue[E <: AnyRef](
def offer(e: E, timeout: Long, unit: TimeUnit): Boolean = { //Tries to do it within the timeout, return false if fail
if (e eq null) throw new NullPointerException
var nanos = unit.toNanos(timeout)
lock.lockInterruptibly()
try {
@tailrec def awaitNotFull(ns: Long): Boolean =
if (backing.size() == maxCapacity) {
if (ns > 0) awaitNotFull(notFull.awaitNanos(ns))
else false
} else true
if (awaitNotFull(nanos)) {
require(backing.offer(e)) //Should never fail
notEmpty.signal()
true
} else false
@tailrec def offerElement(remainingNanos: Long): Boolean = {
if (backing.size() < maxCapacity) {
require(backing.offer(e)) //Should never fail
notEmpty.signal()
true
} else if (remainingNanos <= 0) false
else offerElement(notFull.awaitNanos(remainingNanos))
}
offerElement(unit.toNanos(timeout))
} finally lock.unlock()
}
def poll(timeout: Long, unit: TimeUnit): E = { //Tries to do it within the timeout, returns null if fail
var nanos = unit.toNanos(timeout)
lock.lockInterruptibly()
try {
var result: E = null.asInstanceOf[E]
var hasResult = false
while (!hasResult) {
hasResult = backing.poll() match {
case null if nanos <= 0
result = null.asInstanceOf[E]
true
case null
try {
nanos = notEmpty.awaitNanos(nanos)
} catch {
case ie: InterruptedException
notEmpty.signal()
throw ie
}
false
case e
@tailrec def pollElement(remainingNanos: Long): E = {
backing.poll() match {
case null if remainingNanos <= 0 null.asInstanceOf[E]
case null pollElement(notEmpty.awaitNanos(remainingNanos))
case e {
notFull.signal()
result = e
true
e
}
}
}
result
pollElement(unit.toNanos(timeout))
} finally lock.unlock()
}
@ -124,7 +120,7 @@ class BoundedBlockingQueue[E <: AnyRef](
try {
backing.poll() match {
case null null.asInstanceOf[E]
case e
case e
notFull.signal()
e
}
@ -152,7 +148,7 @@ class BoundedBlockingQueue[E <: AnyRef](
lock.lock()
try {
backing.clear()
notFull.signal()
notFull.signalAll()
} finally lock.unlock()
}
@ -178,21 +174,22 @@ class BoundedBlockingQueue[E <: AnyRef](
def drainTo(c: Collection[_ >: E], maxElements: Int): Int = {
if (c eq null) throw new NullPointerException
if (c eq this) throw new IllegalArgumentException
if (c eq backing) throw new IllegalArgumentException
if (maxElements <= 0) 0
else {
lock.lock()
try {
@tailrec def drainOne(n: Int): Int =
@tailrec def drainOne(n: Int = 0): Int = {
if (n < maxElements) {
backing.poll() match {
case null n
case e c add e; drainOne(n + 1)
}
} else {
notFull.signal()
n
}
drainOne(0)
} else n
}
val n = drainOne()
if (n > 0) notFull.signalAll()
n
} finally lock.unlock()
}
}