// pekko/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala
/**
 * Copyright (C) 2015-2016 Lightbend Inc. <http://www.lightbend.com>
 */
package akka.stream.impl
import akka.stream.ThrottleMode.{ Enforcing, Shaping }
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.stage._
import akka.stream._
import akka.util.NanoTimeTokenBucket
import scala.concurrent.duration.{ FiniteDuration, _ }
/**
 * INTERNAL API
 *
 * Linear stage that rate-limits elements using a [[akka.util.NanoTimeTokenBucket]] created with capacity
 * `maximumBurst` and `per.toNanos / cost` nanoseconds between tokens.
 *
 * Every element is offered to the bucket with the cost returned by `costCalculation`. If the bucket reports
 * no delay, the element is pushed immediately. Otherwise:
 *  - in `ThrottleMode.Enforcing` the stage fails with a [[RateExceededException]];
 *  - in `ThrottleMode.Shaping` the element is held back and pushed when a timer for the reported delay fires.
 *
 * @param cost            number of cost units permitted per `per` interval; must be > 0
 * @param per             interval over which `cost` units are permitted; must be > 0 and at least `cost` nanoseconds
 * @param maximumBurst    capacity of the token bucket; must be >= 0 in Enforcing mode
 * @param costCalculation computes the cost charged for a single element
 * @param mode            whether over-rate elements are delayed (Shaping) or fail the stage (Enforcing)
 */
class Throttle[T](
  cost:            Int,
  per:             FiniteDuration,
  maximumBurst:    Int,
  costCalculation: (T) => Int,
  mode:            ThrottleMode)
  extends SimpleLinearGraphStage[T] {

  require(cost > 0, "cost must be > 0")
  require(per.toNanos > 0, "per time must be > 0")
  // Only negative bursts are rejected here, so the message states the bound actually enforced (>= 0).
  require(!(mode == ThrottleMode.Enforcing && maximumBurst < 0), "maximumBurst must be >= 0 in Enforcing mode")
  // Guarantees nanosBetweenTokens below is at least 1 after integer division.
  require(per.toNanos >= cost, "Rates larger than 1 unit / nanosecond are not supported")

  // There is some loss of precision here because of rounding, but this only happens if nanosBetweenTokens is very
  // small which is usually at rates where that precision is highly unlikely anyway as the overhead of this stage
  // is likely higher than the required accuracy interval.
  private val nanosBetweenTokens = per.toNanos / cost
  private val timerName: String = "ThrottleTimer"

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
    private val tokenBucket = new NanoTimeTokenBucket(maximumBurst, nanosBetweenTokens)

    // Set when upstream finishes while an element is still waiting for the throttle timer; the stage then
    // completes in onTimer right after that element has been pushed.
    var willStop = false
    // Element held back (Shaping mode only) until the throttle timer fires; reset to null once pushed.
    var currentElement: T = _
    val enforcing = mode match {
      case Enforcing => true
      case Shaping   => false
    }

    override def preStart(): Unit = tokenBucket.init()

    // This scope is here just to not retain an extra reference to the handler below.
    // We can't put this code into preStart() because setHandler() must be called before that.
    {
      val handler = new InHandler with OutHandler {
        override def onUpstreamFinish(): Unit =
          // The timer is only active while an element is pending, so defer completion until onTimer pushes it.
          if (isAvailable(out) && isTimerActive(timerName)) willStop = true
          else completeStage()

        override def onPush(): Unit = {
          val elem = grab(in)
          // Named elemCost to avoid shadowing the constructor parameter `cost` (units per `per`).
          val elemCost = costCalculation(elem)
          val delayNanos = tokenBucket.offer(elemCost)

          if (delayNanos == 0L) push(out, elem)
          else if (enforcing) failStage(new RateExceededException("Maximum throttle throughput exceeded."))
          else {
            // Shaping: park the element and emit it once the delay requested by the bucket has elapsed.
            currentElement = elem
            scheduleOnce(timerName, delayNanos.nanos)
          }
        }

        override def onPull(): Unit = pull(in)
      }

      setHandler(in, handler)
      setHandler(out, handler)
      // After this point, we no longer need the `handler` so it can just fall out of scope.
    }

    override protected def onTimer(key: Any): Unit = {
      push(out, currentElement)
      // Drop the reference so the emitted element is not retained by the stage.
      currentElement = null.asInstanceOf[T]
      if (willStop) completeStage()
    }
  }

  override def toString = "Throttle"
}