diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTakeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTakeSpec.scala index 04d6023e9e..d5e293ec74 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTakeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTakeSpec.scala @@ -3,6 +3,8 @@ */ package akka.stream.scaladsl +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } import akka.stream.ActorMaterializer @@ -38,6 +40,11 @@ class FlowTakeSpec extends AkkaSpec with ScriptedTest { probe.expectComplete() } + "complete eagerly when zero or less is taken independently of upstream completion" in { + Await.result(Source.lazyEmpty.take(0).runWith(Sink.ignore), 3.second) + Await.result(Source.lazyEmpty.take(-1).runWith(Sink.ignore), 3.second) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 679f7a643a..5893adf535 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -151,7 +151,7 @@ private[akka] final case class MapConcat[In, Out](f: In ⇒ immutable.Iterable[O /** * INTERNAL API */ -private[akka] final case class Take[T](count: Long) extends PushStage[T, T] { +private[akka] final case class Take[T](count: Long) extends PushPullStage[T, T] { private var left: Long = count override def onPush(elem: T, ctx: Context[T]): SyncDirective = { @@ -160,6 +160,10 @@ private[akka] final case class Take[T](count: Long) extends PushStage[T, T] { else if (left == 0) ctx.pushAndFinish(elem) else ctx.finish() //Handle negative take counts } + + override def onPull(ctx: Context[T]): SyncDirective = + if (left <= 0) ctx.finish() + else ctx.pull() } /**