allow Sink.queue concurrent pulling (#27352)

* allow Sink.queue concurrent pulling

* replace methods with default parameters on two overloaded methods to pass binary compatibility check :/

* replace ⇒ with =>

* reformat

* add javadsl

* fix PR comments and add concurrency to Sink.queue

* fix merge after auto resolving

* duplicate changes to javadsl

* revert source changes

* add graceful terminations

* clean up tests

* optimize imports

* trigger rebuild

* cover the case when materializer shutdown before async callbacks were processed

* vars to vals; fix require messages

* disable compatibility check for @InternalApi private[akka] class
This commit is contained in:
Yakiv Yereskovskyi 2020-01-25 03:33:39 +07:00 committed by Patrik Nordwall
parent 83452be2ff
commit a614f0bee7
6 changed files with 148 additions and 60 deletions

View file

@ -311,8 +311,11 @@ import scala.util.control.NonFatal
/**
* INTERNAL API
*/
@InternalApi private[akka] final class QueueSink[T]()
@InternalApi private[akka] final class QueueSink[T](maxConcurrentPulls: Int)
extends GraphStageWithMaterializedValue[SinkShape[T], SinkQueueWithCancel[T]] {
require(maxConcurrentPulls > 0, "Max concurrent pulls must be greater than 0")
type Requested[E] = Promise[Option[E]]
val in = Inlet[T]("queueSink.in")
@ -328,30 +331,25 @@ import scala.util.control.NonFatal
val maxBuffer = inheritedAttributes.get[InputBuffer](InputBuffer(16, 16)).max
require(maxBuffer > 0, "Buffer size must be greater than 0")
var buffer: Buffer[Received[T]] = _
var currentRequest: Option[Requested[T]] = None
// Allocates one additional element to hold stream closed/failure indicators
val buffer: Buffer[Received[T]] = Buffer(maxBuffer + 1, inheritedAttributes)
val currentRequests: Buffer[Requested[T]] = Buffer(maxConcurrentPulls, inheritedAttributes)
override def preStart(): Unit = {
// Allocates one additional element to hold stream
// closed/failure indicators
buffer = Buffer(maxBuffer + 1, inheritedAttributes)
setKeepGoing(true)
pull(in)
}
private val callback = getAsyncCallback[Output[T]] {
case QueueSink.Pull(pullPromise) =>
currentRequest match {
case Some(_) =>
pullPromise.failure(
new IllegalStateException(
"You have to wait for previous future to be resolved to send another request"))
case None =>
if (buffer.isEmpty) currentRequest = Some(pullPromise)
else {
if (buffer.used == maxBuffer) tryPull(in)
sendDownstream(pullPromise)
}
if (currentRequests.isFull)
pullPromise.failure(
new IllegalStateException(s"Too many concurrent pulls. Specified maximum is $maxConcurrentPulls. " +
"You have to wait for one previous future to be resolved to send another request"))
else if (buffer.isEmpty) currentRequests.enqueue(pullPromise)
else {
if (buffer.used == maxBuffer) tryPull(in)
sendDownstream(pullPromise)
}
case QueueSink.Cancel => completeStage()
}
@ -366,23 +364,28 @@ import scala.util.control.NonFatal
}
}
def enqueueAndNotify(requested: Received[T]): Unit = {
buffer.enqueue(requested)
currentRequest match {
case Some(p) =>
sendDownstream(p)
currentRequest = None
case None => //do nothing
}
}
def onPush(): Unit = {
enqueueAndNotify(Success(Some(grab(in))))
buffer.enqueue(Success(Some(grab(in))))
if (currentRequests.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
if (buffer.used < maxBuffer) pull(in)
}
override def onUpstreamFinish(): Unit = enqueueAndNotify(Success(None))
override def onUpstreamFailure(ex: Throwable): Unit = enqueueAndNotify(Failure(ex))
override def onUpstreamFinish(): Unit = {
buffer.enqueue(Success(None))
while (currentRequests.nonEmpty && buffer.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
while (currentRequests.nonEmpty) currentRequests.dequeue().complete(Success(None))
if (buffer.isEmpty) completeStage()
}
override def onUpstreamFailure(ex: Throwable): Unit = {
buffer.enqueue(Failure(ex))
while (currentRequests.nonEmpty && buffer.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
while (currentRequests.nonEmpty) currentRequests.dequeue().complete(Failure(ex))
if (buffer.isEmpty) failStage(ex)
}
override def postStop(): Unit =
while (currentRequests.nonEmpty) currentRequests.dequeue().failure(new AbruptStageTerminationException(this))
setHandler(in, this)