diff --git a/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala b/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala index 156b90e626..df7da8bf37 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Throttle.scala @@ -4,20 +4,21 @@ package akka.stream.impl -import scala.concurrent.duration.{ FiniteDuration, _ } - import akka.annotation.InternalApi +import akka.stream.ThrottleMode.Enforcing import akka.stream._ -import akka.stream.ThrottleMode.{ Enforcing, Shaping } import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.stage._ import akka.util.NanoTimeTokenBucket +import scala.concurrent.duration.{ FiniteDuration, _ } + /** * INTERNAL API */ @InternalApi private[akka] object Throttle { final val AutomaticMaximumBurst = -1 + private case object TimerKey } /** @@ -45,57 +46,43 @@ import akka.util.NanoTimeTokenBucket else maximumBurst require(!(mode == ThrottleMode.Enforcing && effectiveMaximumBurst < 0), "maximumBurst must be > 0 in Enforcing mode") - private val timerName: String = "ThrottleTimer" + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new TimerGraphStageLogic(shape) with InHandler with OutHandler { + private val tokenBucket = new NanoTimeTokenBucket(effectiveMaximumBurst, nanosBetweenTokens) + private var currentElement: T = _ - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { - private val tokenBucket = new NanoTimeTokenBucket(effectiveMaximumBurst, nanosBetweenTokens) + override def preStart(): Unit = tokenBucket.init() - var willStop = false - var currentElement: T = _ - val enforcing = mode match { - case Enforcing => true - case Shaping => false - } - - override def preStart(): Unit = tokenBucket.init() - - // This scope is here just to not retain an extra reference to the handler below. - // We can't put this code into preRestart() because setHandler() must be called before that. - { - val handler = new InHandler with OutHandler { - override def onUpstreamFinish(): Unit = - if (isAvailable(out) && isTimerActive(timerName)) willStop = true - else completeStage() - - override def onPush(): Unit = { - val elem = grab(in) - val cost = costCalculation(elem) - val delayNanos = tokenBucket.offer(cost) - - if (delayNanos == 0L) push(out, elem) - else { - if (enforcing) failStage(new RateExceededException("Maximum throttle throughput exceeded.")) - else { - currentElement = elem - scheduleOnce(timerName, delayNanos.nanos) - } - } + override def onUpstreamFinish(): Unit = + if (!(isAvailable(out) && isTimerActive(Throttle.TimerKey))) { + completeStage() } - override def onPull(): Unit = pull(in) + override def onPush(): Unit = { + val elem = grab(in) + val cost = costCalculation(elem) + val delayNanos = tokenBucket.offer(cost) + + if (delayNanos == 0L) push(out, elem) + else { + if (mode eq Enforcing) failStage(new RateExceededException("Maximum throttle throughput exceeded.")) + else { + currentElement = elem + scheduleOnce(Throttle.TimerKey, delayNanos.nanos) + } + } } - setHandlers(in, out, handler) - // After this point, we no longer need the `handler` so it can just fall out of scope. - } + override def onPull(): Unit = pull(in) - override protected def onTimer(key: Any): Unit = { - push(out, currentElement) - currentElement = null.asInstanceOf[T] - if (willStop) completeStage() - } + override protected def onTimer(key: Any): Unit = { + push(out, currentElement) + currentElement = null.asInstanceOf[T] + if (isClosed(in)) completeStage() + } - } + setHandlers(in, out, this) + } override def toString = "Throttle" }