Merge pull request #20376 from danxmoran/danxmoran-20288-pushpull-to-gs

Migrate NaiveTake and ByteStringBatcher to GraphStage #20288.
This commit is contained in:
Patrik Nordwall 2016-04-28 15:19:02 +02:00
commit 4de92a013e

View file

@ -6,7 +6,8 @@ package akka.stream.impl.fusing
import akka.testkit.AkkaSpec
import akka.util.ByteString
import akka.stream.stage._
import akka.stream.Supervision
import akka.stream.{ Attributes, Supervision }
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
import Supervision.stoppingDecider
@ -34,7 +35,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
}
"work with ops that need extra pull for complete" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator
itr.toSeq should be(Seq(1))
}
@ -47,29 +48,9 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
a[NoSuchElementException] should be thrownBy { itr.next() }
}
"throw exceptions when chain fails" in {
val stage = new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) ctx.fail(new ArithmeticException())
else ctx.push(elem)
}
}
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
itr.next() should be(1)
itr.hasNext should be(true)
a[ArithmeticException] should be thrownBy { itr.next() }
itr.hasNext should be(false)
}
"throw exceptions when op in chain throws" in {
val stage = new PushStage[Int, Int] {
override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = {
if (elem == 2) throw new ArithmeticException()
else ctx.push(elem)
}
}
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
Map((n: Int) if (n == 2) throw new ArithmeticException() else n, stoppingDecider).toGS)).iterator
itr.next() should be(1)
itr.hasNext should be(true)
@ -90,7 +71,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
def newItr(threshold: Int) =
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(
ByteStringBatcher(threshold).toGS)).iterator
ByteStringBatcher(threshold))).iterator
val itr1 = newItr(20)
itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
@ -110,7 +91,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
val itr4 =
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(
ByteStringBatcher(10).toGS)).iterator
ByteStringBatcher(10))).iterator
itr4.hasNext should be(false)
}
@ -118,48 +99,65 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
}
// This op needs an extra pull round to finish
case class NaiveTake[T](count: Int) extends PushPullStage[T, T] {
private var left: Int = count
case class NaiveTake[T](count: Int) extends SimpleLinearGraphStage[T] {
override def onPush(elem: T, ctx: Context[T]): SyncDirective = {
left -= 1
ctx.push(elem)
}
override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private var left: Int = count
override def onPull(ctx: Context[T]): SyncDirective = {
if (left == 0) ctx.finish()
else ctx.pull()
}
override def onPush(): Unit = {
left -= 1
push(out, grab(in))
}
override def onPull(): Unit = {
if (left == 0) completeStage()
else pull(in)
}
setHandlers(in, out, this)
}
override def toString = "NaiveTake"
}
case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends PushPullStage[ByteString, ByteString] {
case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends SimpleLinearGraphStage[ByteString] {
require(threshold > 0, "Threshold must be positive")
private var buf = ByteString.empty
private var passthrough = false
override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private var buf: ByteString = ByteString.empty
private var passthrough: Boolean = false
override def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
if (passthrough) ctx.push(elem)
else {
buf = buf ++ elem
if (buf.size >= threshold) {
val batch = if (compact) buf.compact else buf
passthrough = true
buf = ByteString.empty
ctx.push(batch)
} else ctx.pull()
override def onPush(): Unit = {
val elem = grab(in)
if (passthrough) push(out, elem)
else {
buf = buf ++ elem
if (buf.size >= threshold) {
val batch = if (compact) buf.compact else buf
passthrough = true
buf = ByteString.empty
push(out, batch)
} else pull(in)
}
}
override def onPull(): Unit = {
if (isClosed(in)) {
push(out, buf)
completeStage()
} else pull(in)
}
override def onUpstreamFinish(): Unit = {
if (passthrough || buf.isEmpty) completeStage()
else if (isAvailable(out)) onPull()
}
setHandlers(in, out, this)
}
}
override def onPull(ctx: Context[ByteString]): SyncDirective = {
if (ctx.isFinishing) ctx.pushAndFinish(buf)
else ctx.pull()
}
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
if (passthrough || buf.isEmpty) ctx.finish()
else ctx.absorbTermination()
}
override def toString = "ByteStringBatcher"
}
}