diff --git a/akka-stream/src/main/scala/akka/stream/Stream.scala b/akka-stream/src/main/scala/akka/stream/Stream.scala index c978f13254..a001bed364 100644 --- a/akka-stream/src/main/scala/akka/stream/Stream.scala +++ b/akka-stream/src/main/scala/akka/stream/Stream.scala @@ -22,9 +22,13 @@ trait Stream[T] { def foreach(c: T ⇒ Unit): Stream[Unit] def fold[U](zero: U)(f: (U, T) ⇒ U): Stream[U] def drop(n: Int): Stream[T] + def take(n: Int): Stream[T] def grouped(n: Int): Stream[immutable.Seq[T]] def mapConcat[U](f: T ⇒ immutable.Seq[U]): Stream[U] - def transform[S, U](zero: S)(f: (S, T) ⇒ (S, immutable.Seq[U]), onComplete: S ⇒ immutable.Seq[U] = (_: S) ⇒ Nil): Stream[U] + def transform[S, U](zero: S)( + f: (S, T) ⇒ (S, immutable.Seq[U]), + onComplete: S ⇒ immutable.Seq[U] = (_: S) ⇒ Nil, + isComplete: S ⇒ Boolean = (_: S) ⇒ false): Stream[U] def toProducer(generator: ProcessorGenerator): Producer[T] } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamImpl.scala index 89ed2ab2d6..6d0efd7e74 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamImpl.scala @@ -12,6 +12,7 @@ import scala.util.control.NonFatal import akka.stream.{ Stream, GeneratorSettings, ProcessorGenerator } import scala.collection.immutable import akka.actor.ActorLogging +import java.util.Arrays /** * INTERNAL API @@ -19,7 +20,7 @@ import akka.actor.ActorLogging private[akka] object Ast { trait AstNode - case class Transform(zero: Any, f: (Any, Any) ⇒ (Any, immutable.Seq[Any]), onComplete: Any ⇒ immutable.Seq[Any]) extends AstNode + case class Transform(zero: Any, f: (Any, Any) ⇒ (Any, immutable.Seq[Any]), onComplete: Any ⇒ immutable.Seq[Any], isComplete: Any ⇒ Boolean) extends AstNode } /** @@ -40,6 +41,8 @@ private[akka] case class StreamImpl[I, O](producer: Producer[I], ops: List[Ast.A def drop(n: Int): Stream[O] = transform(n)((x, in) ⇒ if (x == 0) 0 -> List(in) else (x - 1) -> Nil) + def take(n: Int): Stream[O] = transform(n)((x, in) ⇒ if (x == 0) 0 -> Nil else (x - 1) -> List(in), isComplete = _ == 0) + def grouped(n: Int): Stream[immutable.Seq[O]] = transform(immutable.Seq.empty[O])((buf, in) ⇒ { val group = buf :+ in @@ -49,11 +52,15 @@ private[akka] case class StreamImpl[I, O](producer: Producer[I], ops: List[Ast.A def mapConcat[U](f: O ⇒ immutable.Seq[U]): Stream[U] = transform(())((_, in) ⇒ ((), f(in))) - def transform[S, U](zero: S)(f: (S, O) ⇒ (S, immutable.Seq[U]), onComplete: S ⇒ immutable.Seq[U] = (_: S) ⇒ Nil): Stream[U] = + def transform[S, U](zero: S)( + f: (S, O) ⇒ (S, immutable.Seq[U]), + onComplete: S ⇒ immutable.Seq[U] = (_: S) ⇒ Nil, + isComplete: S ⇒ Boolean = (_: S) ⇒ false): Stream[U] = andThen(Transform( zero, f.asInstanceOf[(Any, Any) ⇒ (Any, immutable.Seq[Any])], - onComplete.asInstanceOf[Any ⇒ immutable.Seq[Any]])) + onComplete.asInstanceOf[Any ⇒ immutable.Seq[Any]], + isComplete.asInstanceOf[Any ⇒ Boolean])) def toProducer(generator: ProcessorGenerator): Producer[O] = generator.toProducer(producer, ops) } @@ -266,7 +273,7 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings) } private var downstreamBufferSpace = 0 - private var inputBuffer = Array.ofDim[Any](settings.initialInputBufferSize) + private var inputBuffer = Array.ofDim[AnyRef](settings.initialInputBufferSize) private var inputBufferElements = 0 private var nextInputElementCursor = 0 val IndexMask = settings.initialInputBufferSize - 1 @@ -276,7 +283,8 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings) private var batchRemaining = requestBatchSize def dequeueInputElement(): Any = { - val elem = inputBuffer(nextInputElementCursor & IndexMask) + val elem = inputBuffer(nextInputElementCursor) + inputBuffer(nextInputElementCursor) = null batchRemaining -= 1 if (batchRemaining == 0 && !upstreamCompleted) { @@ -291,7 +299,7 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings) } def enqueueInputElement(elem: Any): Unit = { - inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem + inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef] inputBufferElements += 1 } @@ -336,7 +344,13 @@ private[akka] abstract class ActorProcessorImpl(val settings: GeneratorSettings) } catch { case NonFatal(e) ⇒ fail(e) } if (transferState.isCompleted) { - isShuttingDown = true + if (!isShuttingDown) { + if (!upstreamCompleted) upstream.cancel() + Arrays.fill(inputBuffer, nextInputElementCursor, nextInputElementCursor + inputBufferElements, null) + inputBufferElements = 0 + context.become(flushing) + isShuttingDown = true + } completeDownstream() } } @@ -400,13 +414,15 @@ private[akka] object TransformProcessorImpl { */ private[akka] class TransformProcessorImpl(_settings: GeneratorSettings, op: Ast.Transform) extends ActorProcessorImpl(_settings) { var state = op.zero + var isComplete = false // TODO performance improvement: mutable buffer? var emits = immutable.Seq.empty[Any] override def transfer(current: TransferState): TransferState = { val depleted = current.inputsDepleted if (emits.isEmpty) { - if (depleted) { + isComplete = op.isComplete(state) + if (depleted || isComplete) { emits = op.onComplete(state) } else { val e = dequeueInputElement() @@ -420,7 +436,7 @@ private[akka] class TransformProcessorImpl(_settings: GeneratorSettings, op: Ast } if (emits.nonEmpty) NeedsDemand - else if (depleted) Completed + else if (depleted || isComplete) Completed else NeedsInputAndDemand } } \ No newline at end of file diff --git a/akka-stream/src/test/scala/akka/stream/IdentityProcessorTest.scala b/akka-stream/src/test/scala/akka/stream/IdentityProcessorTest.scala index 26a2b792e6..2875483276 100644 --- a/akka-stream/src/test/scala/akka/stream/IdentityProcessorTest.scala +++ b/akka-stream/src/test/scala/akka/stream/IdentityProcessorTest.scala @@ -25,7 +25,7 @@ class IdentityProcessorTest extends IdentityProcessorVerification[Int] with With // FIXME can we use API to create the IdentityProcessor instead? def identityProps(settings: GeneratorSettings): Props = - Props(new TransformProcessorImpl(settings, Ast.Transform(Unit, (_, in: Any) ⇒ (Unit, List(in)), (_: Any) ⇒ Nil))) + Props(new TransformProcessorImpl(settings, Ast.Transform(Unit, (_, in: Any) ⇒ (Unit, List(in)), (_: Any) ⇒ Nil, (_: Any) ⇒ false))) val actor = system.actorOf(identityProps( GeneratorSettings( diff --git a/akka-stream/src/test/scala/akka/stream/StreamTakeSpec.scala b/akka-stream/src/test/scala/akka/stream/StreamTakeSpec.scala new file mode 100644 index 0000000000..c0c325d011 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/StreamTakeSpec.scala @@ -0,0 +1,30 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream + +import akka.testkit.AkkaSpec +import akka.stream.testkit.ScriptedTest +import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } + +class StreamTakeSpec extends AkkaSpec with ScriptedTest { + + val genSettings = GeneratorSettings( + initialInputBufferSize = 2, + maximumInputBufferSize = 16, + initialFanOutBufferSize = 1, + maxFanOutBufferSize = 16) + + "A Take" must { + + "take" in { + def script(d: Int) = Script((1 to 50) map { n ⇒ Seq(n) -> (if (n > d) Nil else Seq(n)) }: _*) + (1 to 50) foreach { _ ⇒ + val d = Math.min(Math.max(random.nextInt(-10, 60), 0), 50) + runScript(script(d), genSettings)(_.take(d)) + } + } + + } + +} \ No newline at end of file diff --git a/akka-stream/src/test/scala/akka/stream/StreamTransformSpec.scala b/akka-stream/src/test/scala/akka/stream/StreamTransformSpec.scala index 298335eeb5..56f25d22ad 100644 --- a/akka-stream/src/test/scala/akka/stream/StreamTransformSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/StreamTransformSpec.scala @@ -115,6 +115,41 @@ class StreamTransformSpec extends AkkaSpec { c.expectComplete() } + "allow cancellation using isComplete" in { + val p = StreamTestKit.producerProbe[Int] + val p2 = Stream(p).transform("")((s, in) ⇒ (s + in, List(in)), isComplete = _ == "1").toProducer(gen) + val proc = p.expectSubscription + val c = StreamTestKit.consumerProbe[Int] + p2.produceTo(c) + val s = c.expectSubscription() + s.requestMore(10) + proc.expectRequestMore(2) + proc.sendNext(1) + proc.sendNext(2) + proc.expectRequestMore(1) + c.expectNext(1) + c.expectComplete() + proc.expectCancellation() + } + + "call onComplete after isComplete signaled completion" in { + val p = StreamTestKit.producerProbe[Int] + val p2 = Stream(p).transform("")((s, in) ⇒ (s + in, List(in)), onComplete = x ⇒ List(x.size + 10), isComplete = _ == "1").toProducer(gen) + val proc = p.expectSubscription + val c = StreamTestKit.consumerProbe[Int] + p2.produceTo(c) + val s = c.expectSubscription() + s.requestMore(10) + proc.expectRequestMore(2) + proc.sendNext(1) + proc.sendNext(2) + proc.expectRequestMore(1) + c.expectNext(1) + c.expectNext(11) + c.expectComplete() + proc.expectCancellation() + } + "report error when exception is thrown" in { val p = new IteratorProducer(List(1, 2, 3).iterator) val p2 = Stream(p).