From 5545ff8fa7b0d288a06f82a017480682de7982be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Endre=20S=C3=A1ndor=20Varga?= Date: Wed, 24 Feb 2016 13:19:10 +0100 Subject: [PATCH] #19862 Token bucket should start full - also renamed and refactored various variables for better understanding - added comments to explain features --- .../stream/scaladsl/FlowThrottleSpec.scala | 72 +++++++++++++------ .../scala/akka/stream/impl/Throttle.scala | 68 ++++++++++++------ 2 files changed, 98 insertions(+), 42 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala index 933324ccd3..3368a9523e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowThrottleSpec.scala @@ -85,14 +85,20 @@ class FlowThrottleSpec extends AkkaSpec { val upstream = TestPublisher.probe[Int]() val downstream = TestSubscriber.probe[Int]() Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream)) + + // Exhaust bucket first + downstream.request(5) + (1 to 5) foreach upstream.sendNext + downstream.receiveWithin(300.millis, 5) should be(1 to 5) + downstream.request(1) - upstream.sendNext(1) + upstream.sendNext(6) downstream.expectNoMsg(100.millis) - downstream.expectNext(1) + downstream.expectNext(6) downstream.request(5) downstream.expectNoMsg(1200.millis) - for (i ← 2 to 6) upstream.sendNext(i) - downstream.receiveWithin(300.millis, 5) should be(2 to 6) + for (i ← 7 to 11) upstream.sendNext(i) + downstream.receiveWithin(300.millis, 5) should be(7 to 11) downstream.cancel() } @@ -100,21 +106,31 @@ class FlowThrottleSpec extends AkkaSpec { val upstream = TestPublisher.probe[Int]() val downstream = TestSubscriber.probe[Int]() Source.fromPublisher(upstream).throttle(1, 200.millis, 5, Shaping).runWith(Sink.fromSubscriber(downstream)) + + // Exhaust bucket first + downstream.request(5) + (1 to 5) foreach upstream.sendNext + downstream.receiveWithin(300.millis, 5) should be(1 to 5) + downstream.request(1) - upstream.sendNext(1) + upstream.sendNext(6) downstream.expectNoMsg(100.millis) - downstream.expectNext(1) + downstream.expectNext(6) downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards downstream.request(5) - for (i ← 2 to 4) upstream.sendNext(i) - downstream.receiveWithin(100.millis, 2) should be(Seq(2, 3)) + for (i ← 7 to 10) upstream.sendNext(i) + downstream.receiveWithin(100.millis, 2) should be(Seq(7, 8)) downstream.cancel() } "throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped { + Await.result( + Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.seq), + 2.seconds) should ===(1 to 5) // Burst is 5 so this will not fail + an[RateExceededException] shouldBe thrownBy { Await.result( - Source(1 to 5).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore), + Source(1 to 6).throttle(1, 200.millis, 5, Enforcing).runWith(Sink.ignore), 2.seconds) } } @@ -190,41 +206,57 @@ class FlowThrottleSpec extends AkkaSpec { val upstream = TestPublisher.probe[Int]() val downstream = TestSubscriber.probe[Int]() Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (_) ⇒ 1, Shaping).runWith(Sink.fromSubscriber(downstream)) + + // Exhaust bucket first + downstream.request(5) + (1 to 5) foreach upstream.sendNext + downstream.receiveWithin(300.millis, 5) should be(1 to 5) + downstream.request(1) - upstream.sendNext(1) + upstream.sendNext(6) downstream.expectNoMsg(100.millis) - downstream.expectNext(1) + downstream.expectNext(6) downstream.request(5) downstream.expectNoMsg(1200.millis) - for (i ← 2 to 6) upstream.sendNext(i) - downstream.receiveWithin(300.millis, 5) should be(2 to 6) + for (i ← 7 to 11) upstream.sendNext(i) + downstream.receiveWithin(300.millis, 5) should be(7 to 11) downstream.cancel() } "burst some elements if have enough time" in Utils.assertAllStagesStopped { val upstream = TestPublisher.probe[Int]() val downstream = TestSubscriber.probe[Int]() - Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) ⇒ if (e < 4) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream)) + Source.fromPublisher(upstream).throttle(2, 400.millis, 5, (e) ⇒ if (e < 9) 1 else 20, Shaping).runWith(Sink.fromSubscriber(downstream)) + + // Exhaust bucket first + downstream.request(5) + (1 to 5) foreach upstream.sendNext + downstream.receiveWithin(300.millis, 5) should be(1 to 5) + downstream.request(1) - upstream.sendNext(1) + upstream.sendNext(6) downstream.expectNoMsg(100.millis) - downstream.expectNext(1) + downstream.expectNext(6) downstream.expectNoMsg(500.millis) //wait to receive 2 in burst afterwards downstream.request(5) - for (i ← 2 to 4) upstream.sendNext(i) - downstream.receiveWithin(200.millis, 2) should be(Seq(2, 3)) + for (i ← 7 to 9) upstream.sendNext(i) + downstream.receiveWithin(200.millis, 2) should be(Seq(7, 8)) downstream.cancel() } "throw exception when exceeding throughtput in enforced mode" in Utils.assertAllStagesStopped { + Await.result( + Source(1 to 4).throttle(2, 200.millis, 10, identity, Enforcing).runWith(Sink.seq), + 2.seconds) should ===(1 to 4) // Burst is 10 so this will not fail + an[RateExceededException] shouldBe thrownBy { Await.result( - Source(1 to 5).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore), + Source(1 to 6).throttle(2, 200.millis, 0, identity, Enforcing).runWith(Sink.ignore), 2.seconds) } } - "properly combine shape and throttle modes" in Utils.assertAllStagesStopped { + "properly combine shape and enforce modes" in Utils.assertAllStagesStopped { Source(1 to 5).throttle(2, 200.millis, 0, identity, Shaping) .throttle(1, 100.millis, 5, Enforcing) .runWith(TestSink.probe[Int]) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala b/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala index e9ea2fcd70..b252e333fb 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala @@ -13,6 +13,30 @@ import scala.concurrent.duration.{ FiniteDuration, _ } /** * INTERNAL API */ +private[stream] object Throttle { + + val miniTokenBits = 30 + + private def tokenToMiniToken(e: Int): Long = e.toLong << Throttle.miniTokenBits +} + +/** + * INTERNAL API + */ +/* + * This class tracks a token bucket in an efficient way. + * + * For accuracy, instead of tracking integer tokens the implementation tracks "miniTokens" which are 1/2^30 fraction + * of a token. This allows us to track token replenish rate as miniTokens/nanosecond which allows us to use simple + * arithmetic without division and also less inaccuracy due to rounding on token count caculation. + * + * The replenish amount, and hence the current time is only queried if the bucket does not hold enough miniTokens, in + * other words, replenishing the bucket is *on-need*. In addition, to compensate scheduler inaccuracy, the implementation + * calculates the ideal "previous time" explicitly, not relying on the scheduler to tick at that time. This means that + * when the scheduler actually ticks, some time has been elapsed since the calculated ideal tick time, and those tokens + * are added to the bucket as any calculation is always relative to the ideal tick time. + * + */ private[stream] class Throttle[T](cost: Int, per: FiniteDuration, maximumBurst: Int, @@ -23,18 +47,18 @@ private[stream] class Throttle[T](cost: Int, require(per.toMillis > 0, "per time must be > 0") require(!(mode == ThrottleMode.Enforcing && maximumBurst < 0), "maximumBurst must be > 0 in Enforcing mode") + private val maximumBurstMiniTokens = Throttle.tokenToMiniToken(maximumBurst) + private val miniTokensPerNanos = (Throttle.tokenToMiniToken(cost).toDouble / per.toNanos).toLong + private val timerName: String = "ThrottleTimer" + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { var willStop = false - var lastTokens: Long = maximumBurst - var previousTime: Long = now() - - val speed = ((cost.toDouble / per.toNanos) * 1073741824).toLong - val timerName: String = "ThrottleTimer" + var previousMiniTokens: Long = maximumBurstMiniTokens + var previousNanos: Long = System.nanoTime() var currentElement: Option[T] = None setHandler(in, new InHandler { - val scaledMaximumBurst = scale(maximumBurst) override def onUpstreamFinish(): Unit = if (isAvailable(out) && isTimerActive(timerName)) willStop = true @@ -42,26 +66,29 @@ private[stream] class Throttle[T](cost: Int, override def onPush(): Unit = { val elem = grab(in) - val elementCost = scale(costCalculation(elem)) + val elementCostMiniTokens = Throttle.tokenToMiniToken(costCalculation(elem)) - if (lastTokens >= elementCost) { - lastTokens -= elementCost + if (previousMiniTokens >= elementCostMiniTokens) { + previousMiniTokens -= elementCostMiniTokens push(out, elem) } else { - val currentTime = now() - val currentTokens = Math.min((currentTime - previousTime) * speed + lastTokens, scaledMaximumBurst) - if (currentTokens < elementCost) + val currentNanos = System.nanoTime() + val currentMiniTokens = Math.min( + (currentNanos - previousNanos) * miniTokensPerNanos + previousMiniTokens, + maximumBurstMiniTokens) + + if (currentMiniTokens < elementCostMiniTokens) mode match { case Shaping ⇒ currentElement = Some(elem) - val waitTime = (elementCost - currentTokens) / speed - previousTime = currentTime + waitTime - scheduleOnce(timerName, waitTime.nanos) + val waitNanos = (elementCostMiniTokens - currentMiniTokens) / miniTokensPerNanos + previousNanos = currentNanos + waitNanos + scheduleOnce(timerName, waitNanos.nanos) case Enforcing ⇒ failStage(new RateExceededException("Maximum throttle throughput exceeded")) } else { - lastTokens = currentTokens - elementCost - previousTime = currentTime + previousMiniTokens = currentMiniTokens - elementCostMiniTokens + previousNanos = currentNanos push(out, elem) } } @@ -71,7 +98,7 @@ private[stream] class Throttle[T](cost: Int, override protected def onTimer(key: Any): Unit = { push(out, currentElement.get) currentElement = None - lastTokens = 0 + previousMiniTokens = 0 if (willStop) completeStage() } @@ -79,11 +106,8 @@ private[stream] class Throttle[T](cost: Int, override def onPull(): Unit = pull(in) }) - override def preStart(): Unit = previousTime = now() + override def preStart(): Unit = previousNanos = System.nanoTime() - private def now(): Long = System.nanoTime() - - private def scale(e: Int): Long = e.toLong << 30 } override def toString = "Throttle"