#19862 Token bucket should start full

- also renamed and refactored various variables for better understanding
- added comments to explain features
This commit is contained in:
Endre Sándor Varga 2016-02-24 13:19:10 +01:00
parent cb91c60266
commit 5545ff8fa7
2 changed files with 98 additions and 42 deletions

View file

@ -85,14 +85,20 @@ class FlowThrottleSpec extends AkkaSpec {
val upstream = TestPublisher.probe[Int]()
val downstream = TestSubscriber.probe[Int]()
Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream))
// Exhaust bucket first
downstream.request(5)
(1 to 5) foreach upstream.sendNext
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
downstream.request(1)
upstream.sendNext(1)
upstream.sendNext(6)
downstream.expectNoMsg(100.millis)
downstream.expectNext(1)
downstream.expectNext(6)
downstream.request(5)
downstream.expectNoMsg(1200.millis)
for (i 2 to 6) upstream.sendNext(i)
downstream.receiveWithin(300.millis, 5) should be(2 to 6)
for (i 7 to 11) upstream.sendNext(i)
downstream.receiveWithin(300.millis, 5) should be(7 to 11)
downstream.cancel()
}
@ -100,21 +106,31 @@ class FlowThrottleSpec extends AkkaSpec {
val upstream = TestPublisher.probe[Int]()
val downstream = TestSubscriber.probe[Int]()
Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream))
// Exhaust bucket first
downstream.request(5)
(1 to 5) foreach upstream.sendNext
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
downstream.request(1)
upstream.sendNext(1)
upstream.sendNext(6)
downstream.expectNoMsg(100.millis)
downstream.expectNext(1)
downstream.expectNext(6)
downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards
downstream.request(5)
for (i 2 to 4) upstream.sendNext(i)
downstream.receiveWithin(100.millis, 2) should be(Seq(2, 3))
for (i 7 to 10) upstream.sendNext(i)
downstream.receiveWithin(100.millis, 2) should be(Seq(7, 8))
downstream.cancel()
}
"throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped {
Await.result(
Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.seq),
2.seconds) should ===(1 to 5) // Burst is 5 so this will not fail
an[RateExceededException] shouldBe thrownBy {
Await.result(
Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore),
Source(1 to 6).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore),
2.seconds)
}
}
@ -190,41 +206,57 @@ class FlowThrottleSpec extends AkkaSpec {
val upstream = TestPublisher.probe[Int]()
val downstream = TestSubscriber.probe[Int]()
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (_) 1, Shaping).runWith(Sink.fromSubscriber(downstream))
// Exhaust bucket first
downstream.request(5)
(1 to 5) foreach upstream.sendNext
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
downstream.request(1)
upstream.sendNext(1)
upstream.sendNext(6)
downstream.expectNoMsg(100.millis)
downstream.expectNext(1)
downstream.expectNext(6)
downstream.request(5)
downstream.expectNoMsg(1200.millis)
for (i 2 to 6) upstream.sendNext(i)
downstream.receiveWithin(300.millis, 5) should be(2 to 6)
for (i 7 to 11) upstream.sendNext(i)
downstream.receiveWithin(300.millis, 5) should be(7 to 11)
downstream.cancel()
}
"burst some elements if have enough time" in Utils.assertAllStagesStopped {
val upstream = TestPublisher.probe[Int]()
val downstream = TestSubscriber.probe[Int]()
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) if (e < 4) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream))
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) if (e < 9) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream))
// Exhaust bucket first
downstream.request(5)
(1 to 5) foreach upstream.sendNext
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
downstream.request(1)
upstream.sendNext(1)
upstream.sendNext(6)
downstream.expectNoMsg(100.millis)
downstream.expectNext(1)
downstream.expectNext(6)
downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards
downstream.request(5)
for (i 2 to 4) upstream.sendNext(i)
downstream.receiveWithin(200.millis, 2) should be(Seq(2, 3))
for (i 7 to 9) upstream.sendNext(i)
downstream.receiveWithin(200.millis, 2) should be(Seq(7, 8))
downstream.cancel()
}
"throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped {
Await.result(
Source(1 to 4).throttle(2, 200.millis, 10, identity, Enforcing).runWith(Sink.seq),
2.seconds) should ===(1 to 4) // Burst is 10 so this will not fail
an[RateExceededException] shouldBe thrownBy {
Await.result(
Source(1 to 5).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore),
Source(1 to 6).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore),
2.seconds)
}
}
"properly combine shape and throttle modes" in Utils.assertAllStagesStopped {
"properly combine shape and enforce modes" in Utils.assertAllStagesStopped {
Source(1 to 5).throttle(2, 200.millis, 0, identity, Shaping)
.throttle(1, 100.millis, 5, Enforcing)
.runWith(TestSink.probe[Int])

View file

@ -13,6 +13,30 @@ import scala.concurrent.duration.{ FiniteDuration, _ }
/**
* INTERNAL API
*/
private[stream] object Throttle {
val miniTokenBits = 30
private def tokenToMiniToken(e: Int): Long = e.toLong << Throttle.miniTokenBits
}
/**
* INTERNAL API
*/
/*
* This class tracks a token bucket in an efficient way.
*
* For accuracy, instead of tracking integer tokens the implementation tracks "miniTokens" which are 1/2^30 fraction
* of a token. This allows us to track token replenish rate as miniTokens/nanosecond which allows us to use simple
 * arithmetic without division and also less inaccuracy due to rounding on token count calculation.
*
* The replenish amount, and hence the current time is only queried if the bucket does not hold enough miniTokens, in
 * other words, the bucket is replenished on demand. In addition, to compensate for scheduler inaccuracy, the implementation
* calculates the ideal "previous time" explicitly, not relying on the scheduler to tick at that time. This means that
 * when the scheduler actually ticks, some time has elapsed since the calculated ideal tick time, and those tokens
* are added to the bucket as any calculation is always relative to the ideal tick time.
*
*/
private[stream] class Throttle[T](cost: Int,
per: FiniteDuration,
maximumBurst: Int,
@ -23,18 +47,18 @@ private[stream] class Throttle[T](cost: Int,
require(per.toMillis > 0, "per time must be > 0")
require(!(mode == ThrottleMode.Enforcing && maximumBurst < 0), "maximumBurst must be > 0 in Enforcing mode")
private val maximumBurstMiniTokens = Throttle.tokenToMiniToken(maximumBurst)
private val miniTokensPerNanos = (Throttle.tokenToMiniToken(cost).toDouble / per.toNanos).toLong
private val timerName: String = "ThrottleTimer"
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
var willStop = false
var lastTokens: Long = maximumBurst
var previousTime: Long = now()
val speed = ((cost.toDouble / per.toNanos) * 1073741824).toLong
val timerName: String = "ThrottleTimer"
var previousMiniTokens: Long = maximumBurstMiniTokens
var previousNanos: Long = System.nanoTime()
var currentElement: Option[T] = None
setHandler(in, new InHandler {
val scaledMaximumBurst = scale(maximumBurst)
override def onUpstreamFinish(): Unit =
if (isAvailable(out) && isTimerActive(timerName)) willStop = true
@ -42,26 +66,29 @@ private[stream] class Throttle[T](cost: Int,
override def onPush(): Unit = {
val elem = grab(in)
val elementCost = scale(costCalculation(elem))
val elementCostMiniTokens = Throttle.tokenToMiniToken(costCalculation(elem))
if (lastTokens >= elementCost) {
lastTokens -= elementCost
if (previousMiniTokens >= elementCostMiniTokens) {
previousMiniTokens -= elementCostMiniTokens
push(out, elem)
} else {
val currentTime = now()
val currentTokens = Math.min((currentTime - previousTime) * speed + lastTokens, scaledMaximumBurst)
if (currentTokens < elementCost)
val currentNanos = System.nanoTime()
val currentMiniTokens = Math.min(
(currentNanos - previousNanos) * miniTokensPerNanos + previousMiniTokens,
maximumBurstMiniTokens)
if (currentMiniTokens < elementCostMiniTokens)
mode match {
case Shaping
currentElement = Some(elem)
val waitTime = (elementCost - currentTokens) / speed
previousTime = currentTime + waitTime
scheduleOnce(timerName, waitTime.nanos)
val waitNanos = (elementCostMiniTokens - currentMiniTokens) / miniTokensPerNanos
previousNanos = currentNanos + waitNanos
scheduleOnce(timerName, waitNanos.nanos)
case Enforcing failStage(new RateExceededException("Maximum throttle throughput exceeded"))
}
else {
lastTokens = currentTokens - elementCost
previousTime = currentTime
previousMiniTokens = currentMiniTokens - elementCostMiniTokens
previousNanos = currentNanos
push(out, elem)
}
}
@ -71,7 +98,7 @@ private[stream] class Throttle[T](cost: Int,
override protected def onTimer(key: Any): Unit = {
push(out, currentElement.get)
currentElement = None
lastTokens = 0
previousMiniTokens = 0
if (willStop) completeStage()
}
@ -79,11 +106,8 @@ private[stream] class Throttle[T](cost: Int,
override def onPull(): Unit = pull(in)
})
override def preStart(): Unit = previousTime = now()
override def preStart(): Unit = previousNanos = System.nanoTime()
private def now(): Long = System.nanoTime()
private def scale(e: Int): Long = e.toLong << 30
}
override def toString = "Throttle"