diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index e2b535c812..9aa9127b91 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -14,11 +14,14 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider /* - * These tests were writtern for the previous veryion of the interpreter, the so called OneBoundedInterpreter. + * These tests were written for the previous version of the interpreter, the so called OneBoundedInterpreter. * These stages are now properly emulated by the GraphInterpreter and many of the edge cases were relevant to * the execution model of the old one. Still, these tests are very valuable, so please do not remove. */ + val takeOne = Take(1) + val takeTwo = Take(2) + "Interpreter" must { "implement map correctly" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -126,7 +129,7 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement take" in new OneBoundedSetup[Int](Seq(Take(2))) { + "implement take" in new OneBoundedSetup[Int](takeTwo) { lastEvents() should be(Set.empty) @@ -143,10 +146,10 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(1), Cancel, OnComplete)) } - "implement take inside a chain" in new OneBoundedSetup[Int](Seq( - Filter((x: Int) ⇒ x != 0, stoppingDecider), - Take(2), - Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "implement take inside a chain" in new OneBoundedSetup[Int]( + Filter((x: Int) ⇒ x != 0, stoppingDecider).toGS, + takeTwo, + Map((x: Int) ⇒ x + 1, stoppingDecider).toGS) { lastEvents() should be(Set.empty) @@ -521,9 +524,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { } should be(true) } - "implement take-take" in new OneBoundedSetup[Int](Seq( - Take(1), - Take(1))) { + "implement take-take" in new OneBoundedSetup[Int]( + takeOne, + takeOne) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -534,9 +537,9 @@ class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { } - "implement take-take with pushAndFinish from upstream" in new OneBoundedSetup[Int](Seq( - Take(1), - Take(1))) { + "implement take-take with pushAndFinish from upstream" in new OneBoundedSetup[Int]( + takeOne, + takeOne) { lastEvents() should be(Set.empty) downstream.requestOne() diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala index 99f2a20b33..93bd741b05 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala @@ -16,8 +16,10 @@ class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { val map = Map((x: Int) ⇒ x + 1, stoppingDecider).toGS - // GraphStage can be reused + // GraphStages can be reused val dropOne = Drop(1) + val takeOne = Take(1) + val takeHalfOfRepetition = Take(repetition / 2) "Interpreter" must { @@ -45,7 +47,7 @@ class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { "work with a massive chain of maps with early complete" in new OneBoundedSetup[Int]( Vector.fill(halfLength)(map) ++ - Seq(Take(repetition / 2).toGS) ++ + Seq(takeHalfOfRepetition) ++ Vector.fill(halfLength)(map): _*) { lastEvents() should be(Set.empty) @@ -72,7 +74,7 @@ class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { info(s"Chain finished in $time seconds ${(chainLength * repetition) / (time * 1000 * 1000)} million maps/s") } - "work with a massive chain of takes" in new OneBoundedSetup[Int](Vector.fill(chainLength / 10)(Take(1))) { + "work with a massive chain of takes" in new OneBoundedSetup[Int](Vector.fill(chainLength / 10)(takeOne): _*) { lastEvents() should be(Set.empty) downstream.requestOne() diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index ac9e344abd..d94114801d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -179,10 +179,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) } - final case class Take[T](n: Long, attributes: Attributes = take) extends SymbolicStage[T, T] { - override def create(attr: Attributes): Stage[T, T] = fusing.Take(n) - } - final case class TakeWhile[T](p: T ⇒ Boolean, attributes: Attributes = takeWhile) extends SymbolicStage[T, T] { override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 01fbc59a9a..6fb58721ed 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -124,19 +124,30 @@ private[akka] final case class Recover[T](pf: PartialFunction[Throwable, T]) ext /** * INTERNAL API */ -private[akka] final case class Take[T](count: Long) extends PushPullStage[T, T] { - private var left: Long = count +private[akka] final case class Take[T](count: Long) extends SimpleLinearGraphStage[T] { + override def initialAttributes: Attributes = DefaultAttributes.take - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - left -= 1 - if (left > 0) ctx.push(elem) - else if (left == 0) ctx.pushAndFinish(elem) - else ctx.finish() //Handle negative take counts + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + private var left: Long = count + + override def onPush(): Unit = { + val leftBefore = left + if (leftBefore >= 1) { + left = leftBefore - 1 + push(out, grab(in)) + } + if (leftBefore <= 1) completeStage() + } + + override def onPull(): Unit = { + if (left > 0) pull(in) + else completeStage() + } + + setHandlers(in, out, this) } - override def onPull(ctx: Context[T]): SyncDirective = - if (left <= 0) ctx.finish() - else ctx.pull() + override def toString: String = "Take" } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 8da794ede0..f362f10a67 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -923,7 +923,8 @@ trait FlowOps[+Out, +Mat] { * * See also [[FlowOps.limit]], [[FlowOps.limitWeighted]] */ - def take(n: Long): Repr[Out] = andThen(Take(n)) + def take(n: Long): Repr[Out] = + via(Take[Out](n)) /** * Terminate processing (and cancel the upstream publisher) after the given diff --git a/project/MiMa.scala b/project/MiMa.scala index adba135c98..732ce2e257 100644 --- a/project/MiMa.scala +++ b/project/MiMa.scala @@ -659,6 +659,13 @@ object MiMa extends AutoPlugin { ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.Drop.onPush"), ProblemFilters.exclude[FinalClassProblem]("akka.stream.stage.GraphStageLogic$Reading"), // this class is private + // #19908 Take is private + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Take$"), + ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Take"), + ProblemFilters.exclude[MissingTypesProblem]("akka.stream.impl.fusing.Take"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.Take.onPush"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.fusing.Take.onPull"), + // #19815 make HTTP compile under Scala 2.12.0-M3 ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.http.scaladsl.model.headers.CacheDirectives#private.apply"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.http.scaladsl.model.headers.CacheDirectives#no-cache.apply"),