!str add transform’s isComplete and use it for Take

This commit is contained in:
Roland Kuhn 2014-03-29 17:52:48 +01:00
parent f70a1b2b4a
commit bcd0941ff2
5 changed files with 96 additions and 11 deletions

View file

@ -22,9 +22,13 @@ trait Stream[T] {
def foreach(c: T Unit): Stream[Unit]
def fold[U](zero: U)(f: (U, T) U): Stream[U]
def drop(n: Int): Stream[T]
def take(n: Int): Stream[T]
def grouped(n: Int): Stream[immutable.Seq[T]]
def mapConcat[U](f: T immutable.Seq[U]): Stream[U]
def transform[S, U](zero: S)(f: (S, T) (S, immutable.Seq[U]), onComplete: S immutable.Seq[U] = (_: S) Nil): Stream[U]
def transform[S, U](zero: S)(
f: (S, T) (S, immutable.Seq[U]),
onComplete: S immutable.Seq[U] = (_: S) Nil,
isComplete: S Boolean = (_: S) false): Stream[U]
def toProducer(generator: ProcessorGenerator): Producer[T]
}

View file

@ -12,6 +12,7 @@ import scala.util.control.NonFatal
import akka.stream.{ Stream, GeneratorSettings, ProcessorGenerator }
import scala.collection.immutable
import akka.actor.ActorLogging
import java.util.Arrays
/**
* INTERNAL API
@ -19,7 +20,7 @@ import akka.actor.ActorLogging
private[akka] object Ast {
trait AstNode
case class Transform(zero: Any, f: (Any, Any) (Any, immutable.Seq[Any]), onComplete: Any immutable.Seq[Any]) extends AstNode
case class Transform(zero: Any, f: (Any, Any) (Any, immutable.Seq[Any]), onComplete: Any immutable.Seq[Any], isComplete: Any Boolean) extends AstNode
}
/**
@ -40,6 +41,8 @@ private[akka] case class StreamImpl[I, O](producer: Producer[I], ops: List[Ast.A
def drop(n: Int): Stream[O] = transform(n)((x, in) if (x == 0) 0 -> List(in) else (x - 1) -> Nil)
def take(n: Int): Stream[O] = transform(n)((x, in) if (x == 0) 0 -> Nil else (x - 1) -> List(in), isComplete = _ == 0)
def grouped(n: Int): Stream[immutable.Seq[O]] =
transform(immutable.Seq.empty[O])((buf, in) {
val group = buf :+ in
@ -49,11 +52,15 @@ private[akka] case class StreamImpl[I, O](producer: Producer[I], ops: List[Ast.A
def mapConcat[U](f: O immutable.Seq[U]): Stream[U] = transform(())((_, in) ((), f(in)))
def transform[S, U](zero: S)(f: (S, O) (S, immutable.Seq[U]), onComplete: S immutable.Seq[U] = (_: S) Nil): Stream[U] =
def transform[S, U](zero: S)(
f: (S, O) (S, immutable.Seq[U]),
onComplete: S immutable.Seq[U] = (_: S) Nil,
isComplete: S Boolean = (_: S) false): Stream[U] =
andThen(Transform(
zero,
f.asInstanceOf[(Any, Any) (Any, immutable.Seq[Any])],
onComplete.asInstanceOf[Any immutable.Seq[Any]]))
onComplete.asInstanceOf[Any immutable.Seq[Any]],
isComplete.asInstanceOf[Any Boolean]))
def toProducer(generator: ProcessorGenerator): Producer[O] = generator.toProducer(producer, ops)
}
@ -266,7 +273,7 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings)
}
private var downstreamBufferSpace = 0
private var inputBuffer = Array.ofDim[Any](settings.initialInputBufferSize)
private var inputBuffer = Array.ofDim[AnyRef](settings.initialInputBufferSize)
private var inputBufferElements = 0
private var nextInputElementCursor = 0
val IndexMask = settings.initialInputBufferSize - 1
@ -276,7 +283,8 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings)
private var batchRemaining = requestBatchSize
def dequeueInputElement(): Any = {
val elem = inputBuffer(nextInputElementCursor & IndexMask)
val elem = inputBuffer(nextInputElementCursor)
inputBuffer(nextInputElementCursor) = null
batchRemaining -= 1
if (batchRemaining == 0 && !upstreamCompleted) {
@ -291,7 +299,7 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings)
}
def enqueueInputElement(elem: Any): Unit = {
inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem
inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef]
inputBufferElements += 1
}
@ -336,7 +344,13 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings)
} catch { case NonFatal(e) fail(e) }
if (transferState.isCompleted) {
isShuttingDown = true
if (!isShuttingDown) {
if (!upstreamCompleted) upstream.cancel()
Arrays.fill(inputBuffer, nextInputElementCursor, nextInputElementCursor + inputBufferElements, null)
inputBufferElements = 0
context.become(flushing)
isShuttingDown = true
}
completeDownstream()
}
}
@ -400,13 +414,15 @@ private[akka] object TransformProcessorImpl {
*/
private[akka] class TransformProcessorImpl(_settings: GeneratorSettings, op: Ast.Transform) extends ActorProcessorImpl(_settings) {
var state = op.zero
var isComplete = false
// TODO performance improvement: mutable buffer?
var emits = immutable.Seq.empty[Any]
override def transfer(current: TransferState): TransferState = {
val depleted = current.inputsDepleted
if (emits.isEmpty) {
if (depleted) {
isComplete = op.isComplete(state)
if (depleted || isComplete) {
emits = op.onComplete(state)
} else {
val e = dequeueInputElement()
@ -420,7 +436,7 @@ private[akka] class TransformProcessorImpl(_settings: GeneratorSettings, op: Ast
}
if (emits.nonEmpty) NeedsDemand
else if (depleted) Completed
else if (depleted || isComplete) Completed
else NeedsInputAndDemand
}
}

View file

@ -25,7 +25,7 @@ class IdentityProcessorTest extends IdentityProcessorVerification[Int] with With
// FIXME can we use API to create the IdentityProcessor instead?
def identityProps(settings: GeneratorSettings): Props =
Props(new TransformProcessorImpl(settings, Ast.Transform(Unit, (_, in: Any) (Unit, List(in)), (_: Any) Nil)))
Props(new TransformProcessorImpl(settings, Ast.Transform(Unit, (_, in: Any) (Unit, List(in)), (_: Any) Nil, (_: Any) false)))
val actor = system.actorOf(identityProps(
GeneratorSettings(

View file

@ -0,0 +1,30 @@
/**
* Copyright (C) 2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream
import akka.testkit.AkkaSpec
import akka.stream.testkit.ScriptedTest
import scala.concurrent.forkjoin.ThreadLocalRandom.{ current random }
class StreamTakeSpec extends AkkaSpec with ScriptedTest {
val genSettings = GeneratorSettings(
initialInputBufferSize = 2,
maximumInputBufferSize = 16,
initialFanOutBufferSize = 1,
maxFanOutBufferSize = 16)
"A Take" must {
"take" in {
def script(d: Int) = Script((1 to 50) map { n Seq(n) -> (if (n > d) Nil else Seq(n)) }: _*)
(1 to 50) foreach { _
val d = Math.min(Math.max(random.nextInt(-10, 60), 0), 50)
runScript(script(d), genSettings)(_.take(d))
}
}
}
}

View file

@ -115,6 +115,41 @@ class StreamTransformSpec extends AkkaSpec {
c.expectComplete()
}
"allow cancellation using isComplete" in {
val p = StreamTestKit.producerProbe[Int]
val p2 = Stream(p).transform("")((s, in) (s + in, List(in)), isComplete = _ == "1").toProducer(gen)
val proc = p.expectSubscription
val c = StreamTestKit.consumerProbe[Int]
p2.produceTo(c)
val s = c.expectSubscription()
s.requestMore(10)
proc.expectRequestMore(2)
proc.sendNext(1)
proc.sendNext(2)
proc.expectRequestMore(1)
c.expectNext(1)
c.expectComplete()
proc.expectCancellation()
}
"call onComplete after isComplete signaled completion" in {
val p = StreamTestKit.producerProbe[Int]
val p2 = Stream(p).transform("")((s, in) (s + in, List(in)), onComplete = x List(x.size + 10), isComplete = _ == "1").toProducer(gen)
val proc = p.expectSubscription
val c = StreamTestKit.consumerProbe[Int]
p2.produceTo(c)
val s = c.expectSubscription()
s.requestMore(10)
proc.expectRequestMore(2)
proc.sendNext(1)
proc.sendNext(2)
proc.expectRequestMore(1)
c.expectNext(1)
c.expectNext(11)
c.expectComplete()
proc.expectCancellation()
}
"report error when exception is thrown" in {
val p = new IteratorProducer(List(1, 2, 3).iterator)
val p2 = Stream(p).