diff --git a/akka-stream/src/main/scala/akka/stream/actor/ActorConsumer.scala b/akka-stream/src/main/scala/akka/stream/actor/ActorConsumer.scala new file mode 100644 index 0000000000..e5c9ba967a --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/actor/ActorConsumer.scala @@ -0,0 +1,273 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.actor + +import java.util.concurrent.ConcurrentHashMap +import org.reactivestreams.api.Consumer +import org.reactivestreams.spi.Subscriber +import org.reactivestreams.spi.Subscription +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.ActorSystem +import akka.actor.ExtendedActorSystem +import akka.actor.Extension +import akka.actor.ExtensionId +import akka.actor.ExtensionIdProvider + +object ActorConsumer { + + /** + * Attach a [[ActorConsumer]] actor as a [[org.reactivestreams.Consumer]] + * to a [[org.reactivestreams.Producer]] or [[akka.stream.Flow]]. + */ + def apply[T](ref: ActorRef): Consumer[T] = ActorConsumerImpl(ref) + + /** + * Java API: Attach a [[ActorConsumer]] actor as a [[org.reactivestreams.Consumer]] + * to a [[org.reactivestreams.Producer]] or [[akka.stream.Flow]]. + */ + def create[T](ref: ActorRef): Consumer[T] = apply(ref) + + @SerialVersionUID(1L) case class OnNext(element: Any) + @SerialVersionUID(1L) case object OnComplete + @SerialVersionUID(1L) case class OnError(cause: Throwable) + + /** + * INTERNAL API + */ + @SerialVersionUID(1L) private[akka] case class OnSubscribe(subscription: Subscription) + + /** + * An [[ActorConsumer]] defines a `RequestStrategy` to control the stream back pressure. + */ + trait RequestStrategy { + /** + * Invoked by the [[ActorConsumer]] after each incoming message to + * determine how many more elements to request from the stream. + * + * @param remainingRequested current remaining number of elements that + * have been requested from upstream but not received yet + * @return demand of more elements from the stream, returning 0 means that no + * more elements will be requested + */ + def requestDemand(remainingRequested: Int): Int + } + + /** + * Requests one more element when `remainingRequested` is 0, i.e. + * max one element in flight. + */ + case object OneByOneRequestStrategy extends RequestStrategy { + def requestDemand(remainingRequested: Int): Int = + if (remainingRequested == 0) 1 else 0 + } + + /** + * When request is only controlled with manual calls to + * [[ActorConsumer#request]]. + */ + case object ZeroRequestStrategy extends RequestStrategy { + def requestDemand(remainingRequested: Int): Int = 0 + } + + object WatermarkRequestStrategy { + /** + * Create [[WatermarkRequestStrategy]] with `lowWatermark` as half of + * the specifed `highWatermark`. + */ + def apply(highWatermark: Int): WatermarkRequestStrategy = + WatermarkRequestStrategy(highWatermark, lowWatermark = math.max(1, highWatermark / 2)) + } + + /** + * Requests up to the `highWatermark` when the `remainingRequested` is + * below the `lowWatermark`. This a good strategy when the actor performs work itself. + */ + case class WatermarkRequestStrategy(highWatermark: Int, lowWatermark: Int) extends RequestStrategy { + def requestDemand(remainingRequested: Int): Int = + if (remainingRequested < lowWatermark) + highWatermark - remainingRequested + else 0 + } + + /** + * Requests up to the `max` and also takes the number of messages + * that have been queued internally or delegated to other actors into account. + * Concrete subclass must implement [[#inFlightInternally]]. + * It will request elements in minimum batches of the defined [[#batchSize]]. + */ + abstract class MaxInFlightRequestStrategy(max: Int) extends RequestStrategy { + + /** + * Concrete subclass must implement this method to define how many + * messages that are currently in progress or queued. + */ + def inFlightInternally: Int + + /** + * Elements will be requested in minimum batches of this size. + * Default is 5. Subclass may override to define the batch size. + */ + def batchSize: Int = 5 + + override def requestDemand(remainingRequested: Int): Int = { + val batch = math.min(batchSize, max) + if ((remainingRequested + inFlightInternally) <= (max - batch)) + math.max(0, max - remainingRequested - inFlightInternally) + else 0 + } + } +} + +/** + * Extend/mixin this trait in your [[akka.actor.Actor]] to make it a + * stream consumer with full control of stream back pressure. It will receive + * [[ActorConsumer.OnNext]], [[ActorConsumer.OnComplete]] and [[ActorConsumer.OnError]] + * messages from the stream. It can also receive other, non-stream messages, in + * the same way as any actor. + * + * Attach the actor as a [[org.reactivestreams.Consumer]] to the stream with + * [[ActorConsumer#apply]]. + * + * Subclass must define the [[RequestStrategy]] to control stream back pressure. + * After each incoming message the `ActorConsumer` will automatically invoke + * the [[RequestStrategy#requestDemand]] and propagate the returned demand to the stream. + * The provided [[ActorConsumer.WatermarkRequestStrategy]] is a good strategy if the actor + * performs work itself. + * The provided [[ActorConsumer.MaxInFlightRequestStrategy]] is useful if messages are + * queued internally or delegated to other actors. + * You can also implement a custom [[RequestStrategy]] or call [[#request]] manually + * together with [[ActorConsumer.ZeroRequestStrategy]] or some other strategy. In that case + * you must also call [[#request]] when the actor is started or when it is ready, otherwise + * it will not receive any elements. + */ +trait ActorConsumer extends Actor { + import ActorConsumer._ + + private val state = ActorConsumerState(context.system) + private var subscription: Option[Subscription] = None + private var requested = 0 + private var canceled = false + + protected def requestStrategy: RequestStrategy + + protected[akka] override def aroundReceive(receive: Receive, msg: Any): Unit = msg match { + case _: OnNext ⇒ + requested -= 1 + if (!canceled) { + super.aroundReceive(receive, msg) + request(requestStrategy.requestDemand(requested)) + } + case OnSubscribe(sub) ⇒ + if (subscription.isEmpty) { + subscription = Some(sub) + if (canceled) + sub.cancel() + else if (requested != 0) + sub.requestMore(requested) + } else + sub.cancel() + case _: OnError ⇒ + if (!canceled) super.aroundReceive(receive, msg) + case _ ⇒ + super.aroundReceive(receive, msg) + request(requestStrategy.requestDemand(requested)) + } + + protected[akka] override def aroundPreStart(): Unit = { + super.aroundPreStart() + request(requestStrategy.requestDemand(requested)) + } + + protected[akka] override def aroundPostRestart(reason: Throwable): Unit = { + state.get(self) foreach { s ⇒ + // restore previous state + subscription = s.subscription + requested = s.requested + canceled = s.canceled + } + state.remove(self) + super.aroundPostRestart(reason) + request(requestStrategy.requestDemand(requested)) + } + + protected[akka] override def aroundPreRestart(reason: Throwable, message: Option[Any]): Unit = { + // some state must survive restart + state.set(self, ActorConsumerState.State(subscription, requested, canceled)) + super.aroundPreRestart(reason, message) + } + + protected[akka] override def aroundPostStop(): Unit = { + state.remove(self) + if (!canceled) subscription.foreach(_.cancel()) + super.aroundPostStop() + } + + /** + * Request a number of elements from upstream. + */ + protected def request(elements: Int): Unit = + if (elements > 0 && !canceled) { + // if we don't have a subscription yet, it will be requested when it arrives + subscription.foreach(_.requestMore(elements)) + requested += elements + } + + /** + * Cancel upstream subscription. No more elements will + * be delivered after cancel. + */ + protected def cancel(): Unit = { + subscription.foreach(_.cancel()) + canceled = true + } + +} + +/** + * INTERNAL API + */ +private[akka] case class ActorConsumerImpl[T](ref: ActorRef) extends Consumer[T] { + override val getSubscriber: Subscriber[T] = new ActorSubscriber[T](ref) +} + +/** + * INTERNAL API + */ +private[akka] final class ActorSubscriber[T](val impl: ActorRef) extends Subscriber[T] { + override def onError(cause: Throwable): Unit = impl ! ActorConsumer.OnError(cause) + override def onComplete(): Unit = impl ! ActorConsumer.OnComplete + override def onNext(element: T): Unit = impl ! ActorConsumer.OnNext(element) + override def onSubscribe(subscription: Subscription): Unit = impl ! ActorConsumer.OnSubscribe(subscription) +} + +/** + * INTERNAL API + * Some state must survive restarts. + */ +private[akka] object ActorConsumerState extends ExtensionId[ActorConsumerState] with ExtensionIdProvider { + override def get(system: ActorSystem): ActorConsumerState = super.get(system) + + override def lookup = ActorConsumerState + + override def createExtension(system: ExtendedActorSystem): ActorConsumerState = + new ActorConsumerState + + case class State(subscription: Option[Subscription], requested: Int, canceled: Boolean) + +} + +/** + * INTERNAL API + */ +private[akka] class ActorConsumerState extends Extension { + import ActorConsumerState.State + private val state = new ConcurrentHashMap[ActorRef, State] + + def get(ref: ActorRef): Option[State] = Option(state.get(ref)) + + def set(ref: ActorRef, s: State): Unit = state.put(ref, s) + + def remove(ref: ActorRef): Unit = state.remove(ref) +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala index f8d58d1689..aeb3aa97f6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala @@ -21,6 +21,7 @@ import akka.actor.ExtensionId import akka.actor.ExtendedActorSystem import akka.actor.ActorSystem import akka.actor.Extension +import akka.stream.actor.ActorConsumer /** * INTERNAL API @@ -195,11 +196,11 @@ private[akka] class ActorBasedFlowMaterializer( private def consume[In, Out](ops: List[Ast.AstNode], flowName: String): Consumer[In] = { val c = ops match { case Nil ⇒ - new ActorConsumer[Any](context.actorOf(ActorConsumer.props(settings, blackholeTransform), + ActorConsumer[Any](context.actorOf(ActorConsumerProps.props(settings, blackholeTransform), name = s"$flowName-1-consume")) case head :: tail ⇒ val opsSize = ops.size - val c = new ActorConsumer[Any](context.actorOf(ActorConsumer.props(settings, head), + val c = ActorConsumer[Any](context.actorOf(ActorConsumerProps.props(settings, head), name = s"$flowName-$opsSize-${head.name}")) processorChain(c, tail, flowName, ops.size - 1) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorConsumer.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorConsumer.scala index 146ecb1034..39406f9d9a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorConsumer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorConsumer.scala @@ -12,34 +12,12 @@ import Ast.{ AstNode, Transform } import akka.actor.{ Actor, ActorLogging, ActorRef, Props, actorRef2Scala } import akka.stream.MaterializerSettings import akka.stream.Transformer +import akka.stream.actor.ActorConsumer.{ OnNext, OnError, OnComplete, OnSubscribe } /** * INTERNAL API */ -private[akka] class ActorSubscriber[T]( final val impl: ActorRef) extends Subscriber[T] { - override def onError(cause: Throwable): Unit = impl ! OnError(cause) - override def onComplete(): Unit = impl ! OnComplete - override def onNext(element: T): Unit = impl ! OnNext(element) - override def onSubscribe(subscription: Subscription): Unit = impl ! OnSubscribe(subscription) -} - -/** - * INTERNAL API - */ -private[akka] trait ActorConsumerLike[T] extends Consumer[T] { - def impl: ActorRef - override val getSubscriber: Subscriber[T] = new ActorSubscriber[T](impl) -} - -/** - * INTERNAL API - */ -private[akka] class ActorConsumer[T]( final val impl: ActorRef) extends ActorConsumerLike[T] - -/** - * INTERNAL API - */ -private[akka] object ActorConsumer { +private[akka] object ActorConsumerProps { import Ast._ def props(settings: MaterializerSettings, op: AstNode) = op match { diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala index 19e4dbda6f..2c700e7440 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala @@ -4,12 +4,16 @@ package akka.stream.impl import org.reactivestreams.api.Processor -import org.reactivestreams.spi.{ Subscription, Subscriber } +import org.reactivestreams.spi.Subscriber import akka.actor._ import akka.stream.MaterializerSettings import akka.event.LoggingReceive import java.util.Arrays import scala.util.control.NonFatal +import org.reactivestreams.api.Consumer +import akka.stream.actor.ActorSubscriber +import akka.stream.actor.ActorConsumer.{ OnNext, OnError, OnComplete, OnSubscribe } +import org.reactivestreams.spi.Subscription /** * INTERNAL API @@ -32,7 +36,9 @@ private[akka] object ActorProcessor { /** * INTERNAL API */ -private[akka] class ActorProcessor[I, O]( final val impl: ActorRef) extends Processor[I, O] with ActorConsumerLike[I] with ActorProducerLike[O] +private[akka] class ActorProcessor[I, O]( final val impl: ActorRef) extends Processor[I, O] with Consumer[I] with ActorProducerLike[O] { + override val getSubscriber: Subscriber[I] = new ActorSubscriber[I](impl) +} /** * INTERNAL API diff --git a/akka-stream/src/main/scala/akka/stream/impl/Messages.scala b/akka-stream/src/main/scala/akka/stream/impl/Messages.scala index e03b3a5278..d78544f794 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Messages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Messages.scala @@ -5,18 +5,20 @@ package akka.stream.impl import org.reactivestreams.spi.Subscription -// FIXME INTERNAL API - -case class OnSubscribe(subscription: Subscription) -// TODO performance improvement: skip wrapping ordinary elements in OnNext -case class OnNext(element: Any) -case object OnComplete -case class OnError(cause: Throwable) - -case object SubscribePending - -case class RequestMore(subscription: ActorSubscription[_], demand: Int) -case class Cancel(subscriptions: ActorSubscription[_]) - -case class ExposedPublisher(publisher: ActorPublisher[Any]) +/** + * INTERNAL API + */ +private[akka] case object SubscribePending +/** + * INTERNAL API + */ +private[akka] case class RequestMore(subscription: ActorSubscription[_], demand: Int) +/** + * INTERNAL API + */ +private[akka] case class Cancel(subscriptions: ActorSubscription[_]) +/** + * INTERNAL API + */ +private[akka] case class ExposedPublisher(publisher: ActorPublisher[Any]) diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala index 0011eacf8b..23aa1a704f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala @@ -7,6 +7,7 @@ import akka.stream.MaterializerSettings import akka.actor.{ Actor, Terminated, ActorRef } import org.reactivestreams.spi.{ Subscriber, Subscription } import org.reactivestreams.api.Producer +import akka.stream.actor.ActorConsumer.{ OnNext, OnError, OnComplete, OnSubscribe } /** * INTERNAL API diff --git a/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala index 0ccd7ffc9e..493adac17b 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala @@ -6,8 +6,7 @@ package akka.stream import akka.stream.testkit.AkkaSpec import akka.stream.testkit.ScriptedTest import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -import akka.stream.impl.OnNext -import akka.stream.impl.OnComplete +import akka.stream.actor.ActorConsumer.{ OnNext, OnComplete } import akka.stream.impl.RequestMore class FlowTakeSpec extends AkkaSpec with ScriptedTest { diff --git a/akka-stream/src/test/scala/akka/stream/actor/ActorConsumerSpec.scala b/akka-stream/src/test/scala/akka/stream/actor/ActorConsumerSpec.scala new file mode 100644 index 0000000000..50db444ec0 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/actor/ActorConsumerSpec.scala @@ -0,0 +1,273 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.actor + +import scala.concurrent.duration._ +import akka.actor.ActorRef +import akka.actor.Props +import akka.stream.FlowMaterializer +import akka.stream.MaterializerSettings +import akka.stream.scaladsl.Flow +import akka.stream.testkit.AkkaSpec +import akka.stream.actor.ActorConsumer.RequestStrategy +import akka.actor.Actor +import akka.routing.ActorRefRoutee +import akka.routing.Router +import akka.routing.RoundRobinRoutingLogic +import akka.testkit.ImplicitSender +import scala.util.control.NoStackTrace + +object ActorConsumerSpec { + + def manualConsumerProps(probe: ActorRef): Props = + Props(new ManualConsumer(probe)).withDispatcher("akka.test.stream-dispatcher") + + class ManualConsumer(probe: ActorRef) extends ActorConsumer { + import ActorConsumer._ + + override val requestStrategy = ZeroRequestStrategy + + def receive = { + case next @ OnNext(elem) ⇒ probe ! next + case complete @ OnComplete ⇒ probe ! complete + case err @ OnError(cause) ⇒ probe ! err + case "ready" ⇒ request(elements = 2) + case "boom" ⇒ throw new RuntimeException("boom") with NoStackTrace + case "requestAndCancel" ⇒ { request(1); cancel() } + } + } + + def requestStrategyConsumerProps(probe: ActorRef, strat: RequestStrategy): Props = + Props(new RequestStrategyConsumer(probe, strat)).withDispatcher("akka.test.stream-dispatcher") + + class RequestStrategyConsumer(probe: ActorRef, strat: RequestStrategy) extends ActorConsumer { + import ActorConsumer._ + + override val requestStrategy = strat + + def receive = { + case next @ OnNext(elem) ⇒ probe ! next + case complete @ OnComplete ⇒ probe ! complete + } + } + + case class Msg(id: Int, replyTo: ActorRef) + case class Work(id: Int) + case class Reply(id: Int) + case class Done(id: Int) + + def streamerProps: Props = + Props(new Streamer).withDispatcher("akka.test.stream-dispatcher") + + class Streamer extends ActorConsumer { + import ActorConsumer._ + var queue = Map.empty[Int, ActorRef] + + val router = { + val routees = Vector.fill(3) { + ActorRefRoutee(context.actorOf(Props[Worker].withDispatcher(context.props.dispatcher))) + } + Router(RoundRobinRoutingLogic(), routees) + } + + override val requestStrategy = new MaxInFlightRequestStrategy(max = 10) { + override def inFlightInternally: Int = queue.size + } + + def receive = { + case OnNext(Msg(id, replyTo)) ⇒ + queue += (id -> replyTo) + assert(queue.size <= 10, s"queued too many: ${queue.size}") + router.route(Work(id), self) + case Reply(id) ⇒ + queue(id) ! Done(id) + queue -= id + } + } + + class Worker extends Actor { + def receive = { + case Work(id) ⇒ + // ... + sender() ! Reply(id) + } + } +} + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class ActorConsumerSpec extends AkkaSpec with ImplicitSender { + import ActorConsumerSpec._ + import ActorConsumer._ + + val materializer = FlowMaterializer(MaterializerSettings(dispatcher = "akka.test.stream-dispatcher")) + + "An ActorConsumer" must { + + "receive requested elements" in { + val ref = system.actorOf(manualConsumerProps(testActor)) + Flow(List(1, 2, 3)).produceTo(materializer, ActorConsumer(ref)) + expectNoMsg(200.millis) + ref ! "ready" // requesting 2 + expectMsg(OnNext(1)) + expectMsg(OnNext(2)) + expectNoMsg(200.millis) + ref ! "ready" + expectMsg(OnNext(3)) + expectMsg(OnComplete) + } + + "signal error" in { + val ref = system.actorOf(manualConsumerProps(testActor)) + val e = new RuntimeException("simulated") with NoStackTrace + Flow(() ⇒ throw e).produceTo(materializer, ActorConsumer(ref)) + ref ! "ready" + expectMsg(OnError(e)) + } + + "remember requested after restart" in { + val ref = system.actorOf(manualConsumerProps(testActor)) + Flow(1 to 7).produceTo(materializer, ActorConsumer(ref)) + ref ! "ready" + expectMsg(OnNext(1)) + expectMsg(OnNext(2)) + expectNoMsg(200.millis) // nothing requested + ref ! "boom" + ref ! "ready" + ref ! "ready" + ref ! "boom" + (3 to 6) foreach { n ⇒ expectMsg(OnNext(n)) } + expectNoMsg(200.millis) + ref ! "ready" + expectMsg(OnNext(7)) + expectMsg(OnComplete) + } + + "not deliver more after cancel" in { + val ref = system.actorOf(manualConsumerProps(testActor)) + Flow(1 to 5).produceTo(materializer, ActorConsumer(ref)) + ref ! "ready" + expectMsg(OnNext(1)) + expectMsg(OnNext(2)) + ref ! "requestAndCancel" + expectNoMsg(200.millis) + } + + "work with OneByOneRequestStrategy" in { + val ref = system.actorOf(requestStrategyConsumerProps(testActor, OneByOneRequestStrategy)) + Flow(1 to 17).produceTo(materializer, ActorConsumer(ref)) + for (n ← 1 to 17) expectMsg(OnNext(n)) + expectMsg(OnComplete) + } + + "work with WatermarkRequestStrategy" in { + val ref = system.actorOf(requestStrategyConsumerProps(testActor, WatermarkRequestStrategy(highWatermark = 10))) + Flow(1 to 17).produceTo(materializer, ActorConsumer(ref)) + for (n ← 1 to 17) expectMsg(OnNext(n)) + expectMsg(OnComplete) + } + + "suport custom max in flight request strategy with child workers" in { + val ref = system.actorOf(streamerProps) + val N = 117 + Flow(1 to N).map(Msg(_, testActor)).produceTo(materializer, ActorConsumer(ref)) + receiveN(N).toSet should be((1 to N).map(Done(_)).toSet) + } + + } + + "Provided RequestStragies" must { + "implement OneByOne correctly" in { + val strat = OneByOneRequestStrategy + strat.requestDemand(0) should be(1) + strat.requestDemand(1) should be(0) + strat.requestDemand(2) should be(0) + } + + "implement Zero correctly" in { + val strat = ZeroRequestStrategy + strat.requestDemand(0) should be(0) + strat.requestDemand(1) should be(0) + strat.requestDemand(2) should be(0) + } + + "implement Watermark correctly" in { + val strat = WatermarkRequestStrategy(highWatermark = 10) + strat.requestDemand(0) should be(10) + strat.requestDemand(9) should be(0) + strat.requestDemand(6) should be(0) + strat.requestDemand(5) should be(0) + strat.requestDemand(4) should be(6) + } + + "implement MaxInFlight with batchSize=1 correctly" in { + var queue = Set.empty[String] + val strat = new MaxInFlightRequestStrategy(max = 10) { + override def batchSize: Int = 1 + def inFlightInternally: Int = queue.size + } + strat.requestDemand(0) should be(10) + strat.requestDemand(9) should be(1) + queue += "a" + strat.requestDemand(0) should be(9) + strat.requestDemand(8) should be(1) + strat.requestDemand(9) should be(0) + queue += "b" + queue += "c" + strat.requestDemand(5) should be(2) + ('d' to 'j') foreach { queue += _.toString } + queue.size should be(10) + strat.requestDemand(0) should be(0) + strat.requestDemand(1) should be(0) + queue += "g" + strat.requestDemand(0) should be(0) + strat.requestDemand(1) should be(0) + } + + "implement MaxInFlight with batchSize=3 correctly" in { + var queue = Set.empty[String] + val strat = new MaxInFlightRequestStrategy(max = 10) { + override def batchSize: Int = 3 + override def inFlightInternally: Int = queue.size + } + strat.requestDemand(0) should be(10) + queue += "a" + strat.requestDemand(9) should be(0) + queue += "b" + strat.requestDemand(8) should be(0) + queue += "c" + strat.requestDemand(7) should be(0) + queue += "d" + strat.requestDemand(6) should be(0) + queue -= "a" // 3 remaining in queue + strat.requestDemand(6) should be(0) + queue -= "b" // 2 remaining in queue + strat.requestDemand(6) should be(0) + queue -= "c" // 1 remaining in queue + strat.requestDemand(6) should be(3) + } + + "implement MaxInFlight with batchSize=max correctly" in { + var queue = Set.empty[String] + val strat = new MaxInFlightRequestStrategy(max = 3) { + override def batchSize: Int = 5 // will be bounded to max + override def inFlightInternally: Int = queue.size + } + strat.requestDemand(0) should be(3) + queue += "a" + strat.requestDemand(2) should be(0) + queue += "b" + strat.requestDemand(1) should be(0) + queue += "c" + strat.requestDemand(0) should be(0) + queue -= "a" + strat.requestDemand(0) should be(0) + queue -= "b" + strat.requestDemand(0) should be(0) + queue -= "c" + strat.requestDemand(0) should be(3) + } + + } + +}