allow Sink.queue concurrent pulling (#27352)
* allow Sink.queue concurrent pulling * replace methods with default parameters on two overloaded methods to pass binary compatibility check :/ * replace ⇒ with => * reformat * add javadsl * fix PR comments and add concurrency to Sink.queue * fix merge after auto resolving * duplicate changes to javadsl * revert source changes * add graceful terminations * clean up tests * optimize imports * trigger rebuild * cover the case when materializer shutdown before async callbacks were processed * vars to vals; fix require messages * disable compatibility check for @InternalApi private[akka] class
This commit is contained in:
parent
83452be2ff
commit
a614f0bee7
6 changed files with 148 additions and 60 deletions
|
|
@ -311,8 +311,11 @@ import scala.util.control.NonFatal
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
@InternalApi private[akka] final class QueueSink[T]()
|
||||
@InternalApi private[akka] final class QueueSink[T](maxConcurrentPulls: Int)
|
||||
extends GraphStageWithMaterializedValue[SinkShape[T], SinkQueueWithCancel[T]] {
|
||||
|
||||
require(maxConcurrentPulls > 0, "Max concurrent pulls must be greater than 0")
|
||||
|
||||
type Requested[E] = Promise[Option[E]]
|
||||
|
||||
val in = Inlet[T]("queueSink.in")
|
||||
|
|
@ -328,30 +331,25 @@ import scala.util.control.NonFatal
|
|||
val maxBuffer = inheritedAttributes.get[InputBuffer](InputBuffer(16, 16)).max
|
||||
require(maxBuffer > 0, "Buffer size must be greater than 0")
|
||||
|
||||
var buffer: Buffer[Received[T]] = _
|
||||
var currentRequest: Option[Requested[T]] = None
|
||||
// Allocates one additional element to hold stream closed/failure indicators
|
||||
val buffer: Buffer[Received[T]] = Buffer(maxBuffer + 1, inheritedAttributes)
|
||||
val currentRequests: Buffer[Requested[T]] = Buffer(maxConcurrentPulls, inheritedAttributes)
|
||||
|
||||
override def preStart(): Unit = {
|
||||
// Allocates one additional element to hold stream
|
||||
// closed/failure indicators
|
||||
buffer = Buffer(maxBuffer + 1, inheritedAttributes)
|
||||
setKeepGoing(true)
|
||||
pull(in)
|
||||
}
|
||||
|
||||
private val callback = getAsyncCallback[Output[T]] {
|
||||
case QueueSink.Pull(pullPromise) =>
|
||||
currentRequest match {
|
||||
case Some(_) =>
|
||||
pullPromise.failure(
|
||||
new IllegalStateException(
|
||||
"You have to wait for previous future to be resolved to send another request"))
|
||||
case None =>
|
||||
if (buffer.isEmpty) currentRequest = Some(pullPromise)
|
||||
else {
|
||||
if (buffer.used == maxBuffer) tryPull(in)
|
||||
sendDownstream(pullPromise)
|
||||
}
|
||||
if (currentRequests.isFull)
|
||||
pullPromise.failure(
|
||||
new IllegalStateException(s"Too many concurrent pulls. Specified maximum is $maxConcurrentPulls. " +
|
||||
"You have to wait for one previous future to be resolved to send another request"))
|
||||
else if (buffer.isEmpty) currentRequests.enqueue(pullPromise)
|
||||
else {
|
||||
if (buffer.used == maxBuffer) tryPull(in)
|
||||
sendDownstream(pullPromise)
|
||||
}
|
||||
case QueueSink.Cancel => completeStage()
|
||||
}
|
||||
|
|
@ -366,23 +364,28 @@ import scala.util.control.NonFatal
|
|||
}
|
||||
}
|
||||
|
||||
def enqueueAndNotify(requested: Received[T]): Unit = {
|
||||
buffer.enqueue(requested)
|
||||
currentRequest match {
|
||||
case Some(p) =>
|
||||
sendDownstream(p)
|
||||
currentRequest = None
|
||||
case None => //do nothing
|
||||
}
|
||||
}
|
||||
|
||||
def onPush(): Unit = {
|
||||
enqueueAndNotify(Success(Some(grab(in))))
|
||||
buffer.enqueue(Success(Some(grab(in))))
|
||||
if (currentRequests.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
|
||||
if (buffer.used < maxBuffer) pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = enqueueAndNotify(Success(None))
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = enqueueAndNotify(Failure(ex))
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
buffer.enqueue(Success(None))
|
||||
while (currentRequests.nonEmpty && buffer.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
|
||||
while (currentRequests.nonEmpty) currentRequests.dequeue().complete(Success(None))
|
||||
if (buffer.isEmpty) completeStage()
|
||||
}
|
||||
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = {
|
||||
buffer.enqueue(Failure(ex))
|
||||
while (currentRequests.nonEmpty && buffer.nonEmpty) currentRequests.dequeue().complete(buffer.dequeue())
|
||||
while (currentRequests.nonEmpty) currentRequests.dequeue().complete(Failure(ex))
|
||||
if (buffer.isEmpty) failStage(ex)
|
||||
}
|
||||
|
||||
override def postStop(): Unit =
|
||||
while (currentRequests.nonEmpty) currentRequests.dequeue().failure(new AbruptStageTerminationException(this))
|
||||
|
||||
setHandler(in, this)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue