#19862 Token bucket should start full
- also renamed and refactored various variables for better understanding - added comments to explain features
This commit is contained in:
parent
cb91c60266
commit
5545ff8fa7
2 changed files with 98 additions and 42 deletions
|
|
@ -85,14 +85,20 @@ class FlowThrottleSpec extends AkkaSpec {
|
|||
val upstream = TestPublisher.probe[Int]()
|
||||
val downstream = TestSubscriber.probe[Int]()
|
||||
Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream))
|
||||
|
||||
// Exhaust bucket first
|
||||
downstream.request(5)
|
||||
(1 to 5) foreach upstream.sendNext
|
||||
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
|
||||
|
||||
downstream.request(1)
|
||||
upstream.sendNext(1)
|
||||
upstream.sendNext(6)
|
||||
downstream.expectNoMsg(100.millis)
|
||||
downstream.expectNext(1)
|
||||
downstream.expectNext(6)
|
||||
downstream.request(5)
|
||||
downstream.expectNoMsg(1200.millis)
|
||||
for (i ← 2 to 6) upstream.sendNext(i)
|
||||
downstream.receiveWithin(300.millis, 5) should be(2 to 6)
|
||||
for (i ← 7 to 11) upstream.sendNext(i)
|
||||
downstream.receiveWithin(300.millis, 5) should be(7 to 11)
|
||||
downstream.cancel()
|
||||
}
|
||||
|
||||
|
|
@ -100,21 +106,31 @@ class FlowThrottleSpec extends AkkaSpec {
|
|||
val upstream = TestPublisher.probe[Int]()
|
||||
val downstream = TestSubscriber.probe[Int]()
|
||||
Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream))
|
||||
|
||||
// Exhaust bucket first
|
||||
downstream.request(5)
|
||||
(1 to 5) foreach upstream.sendNext
|
||||
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
|
||||
|
||||
downstream.request(1)
|
||||
upstream.sendNext(1)
|
||||
upstream.sendNext(6)
|
||||
downstream.expectNoMsg(100.millis)
|
||||
downstream.expectNext(1)
|
||||
downstream.expectNext(6)
|
||||
downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards
|
||||
downstream.request(5)
|
||||
for (i ← 2 to 4) upstream.sendNext(i)
|
||||
downstream.receiveWithin(100.millis, 2) should be(Seq(2, 3))
|
||||
for (i ← 7 to 10) upstream.sendNext(i)
|
||||
downstream.receiveWithin(100.millis, 2) should be(Seq(7, 8))
|
||||
downstream.cancel()
|
||||
}
|
||||
|
||||
"throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped {
|
||||
Await.result(
|
||||
Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.seq),
|
||||
2.seconds) should ===(1 to 5) // Burst is 5 so this will not fail
|
||||
|
||||
an[RateExceededException] shouldBe thrownBy {
|
||||
Await.result(
|
||||
Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore),
|
||||
Source(1 to 6).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore),
|
||||
2.seconds)
|
||||
}
|
||||
}
|
||||
|
|
@ -190,41 +206,57 @@ class FlowThrottleSpec extends AkkaSpec {
|
|||
val upstream = TestPublisher.probe[Int]()
|
||||
val downstream = TestSubscriber.probe[Int]()
|
||||
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (_) ⇒ 1, Shaping).runWith(Sink.fromSubscriber(downstream))
|
||||
|
||||
// Exhaust bucket first
|
||||
downstream.request(5)
|
||||
(1 to 5) foreach upstream.sendNext
|
||||
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
|
||||
|
||||
downstream.request(1)
|
||||
upstream.sendNext(1)
|
||||
upstream.sendNext(6)
|
||||
downstream.expectNoMsg(100.millis)
|
||||
downstream.expectNext(1)
|
||||
downstream.expectNext(6)
|
||||
downstream.request(5)
|
||||
downstream.expectNoMsg(1200.millis)
|
||||
for (i ← 2 to 6) upstream.sendNext(i)
|
||||
downstream.receiveWithin(300.millis, 5) should be(2 to 6)
|
||||
for (i ← 7 to 11) upstream.sendNext(i)
|
||||
downstream.receiveWithin(300.millis, 5) should be(7 to 11)
|
||||
downstream.cancel()
|
||||
}
|
||||
|
||||
"burst some elements if have enough time" in Utils.assertAllStagesStopped {
|
||||
val upstream = TestPublisher.probe[Int]()
|
||||
val downstream = TestSubscriber.probe[Int]()
|
||||
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) ⇒ if (e < 4) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream))
|
||||
Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) ⇒ if (e < 9) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream))
|
||||
|
||||
// Exhaust bucket first
|
||||
downstream.request(5)
|
||||
(1 to 5) foreach upstream.sendNext
|
||||
downstream.receiveWithin(300.millis, 5) should be(1 to 5)
|
||||
|
||||
downstream.request(1)
|
||||
upstream.sendNext(1)
|
||||
upstream.sendNext(6)
|
||||
downstream.expectNoMsg(100.millis)
|
||||
downstream.expectNext(1)
|
||||
downstream.expectNext(6)
|
||||
downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards
|
||||
downstream.request(5)
|
||||
for (i ← 2 to 4) upstream.sendNext(i)
|
||||
downstream.receiveWithin(200.millis, 2) should be(Seq(2, 3))
|
||||
for (i ← 7 to 9) upstream.sendNext(i)
|
||||
downstream.receiveWithin(200.millis, 2) should be(Seq(7, 8))
|
||||
downstream.cancel()
|
||||
}
|
||||
|
||||
"throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped {
|
||||
Await.result(
|
||||
Source(1 to 4).throttle(2, 200.millis, 10, identity, Enforcing).runWith(Sink.seq),
|
||||
2.seconds) should ===(1 to 4) // Burst is 10 so this will not fail
|
||||
|
||||
an[RateExceededException] shouldBe thrownBy {
|
||||
Await.result(
|
||||
Source(1 to 5).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore),
|
||||
Source(1 to 6).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore),
|
||||
2.seconds)
|
||||
}
|
||||
}
|
||||
|
||||
"properly combine shape and throttle modes" in Utils.assertAllStagesStopped {
|
||||
"properly combine shape and enforce modes" in Utils.assertAllStagesStopped {
|
||||
Source(1 to 5).throttle(2, 200.millis, 0, identity, Shaping)
|
||||
.throttle(1, 100.millis, 5, Enforcing)
|
||||
.runWith(TestSink.probe[Int])
|
||||
|
|
|
|||
|
|
@ -13,6 +13,30 @@ import scala.concurrent.duration.{ FiniteDuration, _ }
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[stream] object Throttle {
|
||||
|
||||
val miniTokenBits = 30
|
||||
|
||||
private def tokenToMiniToken(e: Int): Long = e.toLong << Throttle.miniTokenBits
|
||||
}
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
/*
|
||||
* This class tracks a token bucket in an efficient way.
|
||||
*
|
||||
* For accuracy, instead of tracking integer tokens the implementation tracks "miniTokens" which are 1/2^30 fraction
|
||||
* of a token. This allows us to track token replenish rate as miniTokens/nanosecond which allows us to use simple
|
||||
* arithmetic without division and also less inaccuracy due to rounding on token count caculation.
|
||||
*
|
||||
* The replenish amount, and hence the current time is only queried if the bucket does not hold enough miniTokens, in
|
||||
* other words, replenishing the bucket is *on-need*. In addition, to compensate scheduler inaccuracy, the implementation
|
||||
* calculates the ideal "previous time" explicitly, not relying on the scheduler to tick at that time. This means that
|
||||
* when the scheduler actually ticks, some time has been elapsed since the calculated ideal tick time, and those tokens
|
||||
* are added to the bucket as any calculation is always relative to the ideal tick time.
|
||||
*
|
||||
*/
|
||||
private[stream] class Throttle[T](cost: Int,
|
||||
per: FiniteDuration,
|
||||
maximumBurst: Int,
|
||||
|
|
@ -23,18 +47,18 @@ private[stream] class Throttle[T](cost: Int,
|
|||
require(per.toMillis > 0, "per time must be > 0")
|
||||
require(!(mode == ThrottleMode.Enforcing && maximumBurst < 0), "maximumBurst must be > 0 in Enforcing mode")
|
||||
|
||||
private val maximumBurstMiniTokens = Throttle.tokenToMiniToken(maximumBurst)
|
||||
private val miniTokensPerNanos = (Throttle.tokenToMiniToken(cost).toDouble / per.toNanos).toLong
|
||||
private val timerName: String = "ThrottleTimer"
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
||||
var willStop = false
|
||||
var lastTokens: Long = maximumBurst
|
||||
var previousTime: Long = now()
|
||||
|
||||
val speed = ((cost.toDouble / per.toNanos) * 1073741824).toLong
|
||||
val timerName: String = "ThrottleTimer"
|
||||
var previousMiniTokens: Long = maximumBurstMiniTokens
|
||||
var previousNanos: Long = System.nanoTime()
|
||||
|
||||
var currentElement: Option[T] = None
|
||||
|
||||
setHandler(in, new InHandler {
|
||||
val scaledMaximumBurst = scale(maximumBurst)
|
||||
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (isAvailable(out) && isTimerActive(timerName)) willStop = true
|
||||
|
|
@ -42,26 +66,29 @@ private[stream] class Throttle[T](cost: Int,
|
|||
|
||||
override def onPush(): Unit = {
|
||||
val elem = grab(in)
|
||||
val elementCost = scale(costCalculation(elem))
|
||||
val elementCostMiniTokens = Throttle.tokenToMiniToken(costCalculation(elem))
|
||||
|
||||
if (lastTokens >= elementCost) {
|
||||
lastTokens -= elementCost
|
||||
if (previousMiniTokens >= elementCostMiniTokens) {
|
||||
previousMiniTokens -= elementCostMiniTokens
|
||||
push(out, elem)
|
||||
} else {
|
||||
val currentTime = now()
|
||||
val currentTokens = Math.min((currentTime - previousTime) * speed + lastTokens, scaledMaximumBurst)
|
||||
if (currentTokens < elementCost)
|
||||
val currentNanos = System.nanoTime()
|
||||
val currentMiniTokens = Math.min(
|
||||
(currentNanos - previousNanos) * miniTokensPerNanos + previousMiniTokens,
|
||||
maximumBurstMiniTokens)
|
||||
|
||||
if (currentMiniTokens < elementCostMiniTokens)
|
||||
mode match {
|
||||
case Shaping ⇒
|
||||
currentElement = Some(elem)
|
||||
val waitTime = (elementCost - currentTokens) / speed
|
||||
previousTime = currentTime + waitTime
|
||||
scheduleOnce(timerName, waitTime.nanos)
|
||||
val waitNanos = (elementCostMiniTokens - currentMiniTokens) / miniTokensPerNanos
|
||||
previousNanos = currentNanos + waitNanos
|
||||
scheduleOnce(timerName, waitNanos.nanos)
|
||||
case Enforcing ⇒ failStage(new RateExceededException("Maximum throttle throughput exceeded"))
|
||||
}
|
||||
else {
|
||||
lastTokens = currentTokens - elementCost
|
||||
previousTime = currentTime
|
||||
previousMiniTokens = currentMiniTokens - elementCostMiniTokens
|
||||
previousNanos = currentNanos
|
||||
push(out, elem)
|
||||
}
|
||||
}
|
||||
|
|
@ -71,7 +98,7 @@ private[stream] class Throttle[T](cost: Int,
|
|||
override protected def onTimer(key: Any): Unit = {
|
||||
push(out, currentElement.get)
|
||||
currentElement = None
|
||||
lastTokens = 0
|
||||
previousMiniTokens = 0
|
||||
if (willStop) completeStage()
|
||||
}
|
||||
|
||||
|
|
@ -79,11 +106,8 @@ private[stream] class Throttle[T](cost: Int,
|
|||
override def onPull(): Unit = pull(in)
|
||||
})
|
||||
|
||||
override def preStart(): Unit = previousTime = now()
|
||||
override def preStart(): Unit = previousNanos = System.nanoTime()
|
||||
|
||||
private def now(): Long = System.nanoTime()
|
||||
|
||||
private def scale(e: Int): Long = e.toLong << 30
|
||||
}
|
||||
|
||||
override def toString = "Throttle"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue