diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index 07327b37d3..071fa4cbf0 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -341,13 +341,11 @@ class GraphInterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { "implement buffer" in new TestSetup { val source = new UpstreamProbe[String]("source") val sink = new DownstreamProbe[String]("sink") - val buffer = new PushPullGraphStage[String, String, NotUsed]( - (_) ⇒ new Buffer[String](2, OverflowStrategy.backpressure), - Attributes.none) + val buffer = Buffer[String](2, OverflowStrategy.backpressure) builder(buffer) - .connect(source, buffer.shape.in) - .connect(buffer.shape.out, sink) + .connect(source, buffer.in) + .connect(buffer.out, sink) .init() stepAll() diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index fc85edb97c..9eb341a33a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -155,9 +155,4 @@ object Stages { } - final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] { - require(size > 0, s"Buffer size must be larger than zero but was [$size]") - override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy) - } - } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 734b1fdd22..fdb668e365 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -843,61 +843,74 @@ final case class Sliding[T](val n: Int, val step: Int) extends GraphStage[FlowSh /** * INTERNAL API */ -final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy) extends DetachedStage[T, T] { +final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy) extends SimpleLinearGraphStage[T] { - private var buffer: BufferImpl[T] = _ + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { - override def preStart(ctx: LifecycleContext): Unit = { - buffer = BufferImpl(size, ctx.materializer) - } + private var buffer: BufferImpl[T] = _ - override def onPush(elem: T, ctx: DetachedContext[T]): UpstreamDirective = - if (ctx.isHoldingDownstream) ctx.pushAndPull(elem) - else enqueueAction(ctx, elem) - - override def onPull(ctx: DetachedContext[T]): DownstreamDirective = { - if (ctx.isFinishing) { - val elem = buffer.dequeue() - if (buffer.isEmpty) ctx.pushAndFinish(elem) - else ctx.push(elem) - } else if (ctx.isHoldingUpstream) ctx.pushAndPull(buffer.dequeue()) - else if (buffer.isEmpty) ctx.holdDownstream() - else ctx.push(buffer.dequeue()) - } - - override def onUpstreamFinish(ctx: DetachedContext[T]): TerminationDirective = - if (buffer.isEmpty) ctx.finish() - else ctx.absorbTermination() - - val enqueueAction: (DetachedContext[T], T) ⇒ UpstreamDirective = - overflowStrategy match { - case DropHead ⇒ (ctx, elem) ⇒ - if (buffer.isFull) buffer.dropHead() - buffer.enqueue(elem) - ctx.pull() - case DropTail ⇒ (ctx, elem) ⇒ - if (buffer.isFull) buffer.dropTail() - buffer.enqueue(elem) - ctx.pull() - case DropBuffer ⇒ (ctx, elem) ⇒ - if (buffer.isFull) buffer.clear() - buffer.enqueue(elem) - ctx.pull() - case DropNew ⇒ (ctx, elem) ⇒ - if (!buffer.isFull) buffer.enqueue(elem) - ctx.pull() - case Backpressure ⇒ (ctx, elem) ⇒ - buffer.enqueue(elem) - if (buffer.isFull) ctx.holdUpstream() - else ctx.pull() - case Fail ⇒ (ctx, elem) ⇒ - if (buffer.isFull) ctx.fail(new BufferOverflowException(s"Buffer overflow (max capacity was: $size)!")) - else { + val enqueueAction: T ⇒ Unit = + overflowStrategy match { + case DropHead ⇒ elem ⇒ + if (buffer.isFull) buffer.dropHead() buffer.enqueue(elem) - ctx.pull() - } + pull(in) + case DropTail ⇒ elem ⇒ + if (buffer.isFull) buffer.dropTail() + buffer.enqueue(elem) + pull(in) + case DropBuffer ⇒ elem ⇒ + if (buffer.isFull) buffer.clear() + buffer.enqueue(elem) + pull(in) + case DropNew ⇒ elem ⇒ + if (!buffer.isFull) buffer.enqueue(elem) + pull(in) + case Backpressure ⇒ elem ⇒ + buffer.enqueue(elem) + if (!buffer.isFull) pull(in) + case Fail ⇒ elem ⇒ + if (buffer.isFull) failStage(new BufferOverflowException(s"Buffer overflow (max capacity was: $size)!")) + else { + buffer.enqueue(elem) + pull(in) + } + } + + override def preStart(): Unit = { + buffer = BufferImpl(size, materializer) + pull(in) } + override def onPush(): Unit = { + val elem = grab(in) + // If out is available, then it has been pulled but no dequeued element has been delivered. + // It means the buffer at this moment is definitely empty, + // so we just push the current element to out, then pull. + if (isAvailable(out)) { + push(out, elem) + pull(in) + } else { + enqueueAction(elem) + } + } + + override def onPull(): Unit = { + if (buffer.nonEmpty) push(out, buffer.dequeue()) + if (isClosed(in)) { + if (buffer.isEmpty) completeStage() + } else if (!hasBeenPulled(in)) { + pull(in) + } + } + + override def onUpstreamFinish(): Unit = { + if (buffer.isEmpty) completeStage() + } + + setHandlers(in, out, this) + } + } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 6d16734c75..af82e0077b 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -1208,7 +1208,7 @@ trait FlowOps[+Out, +Mat] { * @param size The size of the buffer in element count * @param overflowStrategy Strategy that is used when incoming elements cannot fit inside the buffer */ - def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out] = andThen(Buffer(size, overflowStrategy)) + def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out] = via(fusing.Buffer(size, overflowStrategy)) /** * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]].