Merge pull request #20376 from danxmoran/danxmoran-20288-pushpull-to-gs
Migrate NaiveTake and ByteStringBatcher to GraphStage #20288.
This commit is contained in:
commit
4de92a013e
1 changed files with 58 additions and 60 deletions
|
|
@ -6,7 +6,8 @@ package akka.stream.impl.fusing
|
|||
import akka.testkit.AkkaSpec
|
||||
import akka.util.ByteString
|
||||
import akka.stream.stage._
|
||||
import akka.stream.Supervision
|
||||
import akka.stream.{ Attributes, Supervision }
|
||||
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||
|
||||
class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
||||
import Supervision.stoppingDecider
|
||||
|
|
@ -34,7 +35,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
}
|
||||
|
||||
"work with ops that need extra pull for complete" in {
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator
|
||||
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator
|
||||
|
||||
itr.toSeq should be(Seq(1))
|
||||
}
|
||||
|
|
@ -47,29 +48,9 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
a[NoSuchElementException] should be thrownBy { itr.next() }
|
||||
}
|
||||
|
||||
"throw exceptions when chain fails" in {
|
||||
val stage = new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) ctx.fail(new ArithmeticException())
|
||||
else ctx.push(elem)
|
||||
}
|
||||
}
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
|
||||
|
||||
itr.next() should be(1)
|
||||
itr.hasNext should be(true)
|
||||
a[ArithmeticException] should be thrownBy { itr.next() }
|
||||
itr.hasNext should be(false)
|
||||
}
|
||||
|
||||
"throw exceptions when op in chain throws" in {
|
||||
val stage = new PushStage[Int, Int] {
|
||||
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
|
||||
if (elem == 2) throw new ArithmeticException()
|
||||
else ctx.push(elem)
|
||||
}
|
||||
}
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
|
||||
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
|
||||
Map((n: Int) ⇒ if (n == 2) throw new ArithmeticException() else n, stoppingDecider).toGS)).iterator
|
||||
|
||||
itr.next() should be(1)
|
||||
itr.hasNext should be(true)
|
||||
|
|
@ -90,7 +71,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
|
||||
def newItr(threshold: Int) =
|
||||
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(
|
||||
ByteStringBatcher(threshold).toGS)).iterator
|
||||
ByteStringBatcher(threshold))).iterator
|
||||
|
||||
val itr1 = newItr(20)
|
||||
itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
|
||||
|
|
@ -110,7 +91,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
|
||||
val itr4 =
|
||||
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(
|
||||
ByteStringBatcher(10).toGS)).iterator
|
||||
ByteStringBatcher(10))).iterator
|
||||
|
||||
itr4.hasNext should be(false)
|
||||
}
|
||||
|
|
@ -118,48 +99,65 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
}
|
||||
|
||||
// This op needs an extra pull round to finish
|
||||
case class NaiveTake[T](count: Int) extends PushPullStage[T, T] {
|
||||
private var left: Int = count
|
||||
case class NaiveTake[T](count: Int) extends SimpleLinearGraphStage[T] {
|
||||
|
||||
override def onPush(elem: T, ctx: Context[T]): SyncDirective = {
|
||||
left -= 1
|
||||
ctx.push(elem)
|
||||
}
|
||||
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var left: Int = count
|
||||
|
||||
override def onPull(ctx: Context[T]): SyncDirective = {
|
||||
if (left == 0) ctx.finish()
|
||||
else ctx.pull()
|
||||
}
|
||||
override def onPush(): Unit = {
|
||||
left -= 1
|
||||
push(out, grab(in))
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
if (left == 0) completeStage()
|
||||
else pull(in)
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def toString = "NaiveTake"
|
||||
}
|
||||
|
||||
case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends PushPullStage[ByteString, ByteString] {
|
||||
case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends SimpleLinearGraphStage[ByteString] {
|
||||
require(threshold > 0, "Threshold must be positive")
|
||||
|
||||
private var buf = ByteString.empty
|
||||
private var passthrough = false
|
||||
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var buf: ByteString = ByteString.empty
|
||||
private var passthrough: Boolean = false
|
||||
|
||||
override def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
if (passthrough) ctx.push(elem)
|
||||
else {
|
||||
buf = buf ++ elem
|
||||
if (buf.size >= threshold) {
|
||||
val batch = if (compact) buf.compact else buf
|
||||
passthrough = true
|
||||
buf = ByteString.empty
|
||||
ctx.push(batch)
|
||||
} else ctx.pull()
|
||||
override def onPush(): Unit = {
|
||||
val elem = grab(in)
|
||||
if (passthrough) push(out, elem)
|
||||
else {
|
||||
buf = buf ++ elem
|
||||
if (buf.size >= threshold) {
|
||||
val batch = if (compact) buf.compact else buf
|
||||
passthrough = true
|
||||
buf = ByteString.empty
|
||||
push(out, batch)
|
||||
} else pull(in)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = {
|
||||
if (isClosed(in)) {
|
||||
push(out, buf)
|
||||
completeStage()
|
||||
} else pull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (passthrough || buf.isEmpty) completeStage()
|
||||
else if (isAvailable(out)) onPull()
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): SyncDirective = {
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(buf)
|
||||
else ctx.pull()
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
|
||||
if (passthrough || buf.isEmpty) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
}
|
||||
override def toString = "ByteStringBatcher"
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue