diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FramingSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FramingSpec.scala index ce5ee6e9de..58b57fca3f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FramingSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FramingSpec.scala @@ -7,6 +7,7 @@ import java.nio.ByteOrder import akka.stream.scaladsl.Framing.FramingException import akka.stream.stage.{ Context, PushPullStage, SyncDirective, TerminationDirective } +import akka.stream.testkit.{ TestSubscriber, TestPublisher } import akka.testkit.AkkaSpec import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } import akka.util.{ ByteString, ByteStringBuilder } @@ -180,6 +181,33 @@ class FramingSpec extends AkkaSpec { 3.seconds) should ===(Vector.empty) } + "work with grouped frames" in { + val groupSize = 5 + val single = encode(referenceChunk.take(100), 0, 1, ByteOrder.BIG_ENDIAN) + val groupedFrames = (1 to groupSize) + .map(_ ⇒ single) + .fold(ByteString.empty)((result, bs) ⇒ result ++ bs) + + val publisher = TestPublisher.probe[ByteString]() + val substriber = TestSubscriber.manualProbe[ByteString]() + + Source.fromPublisher(publisher) + .via(Framing.lengthField(1, 0, Int.MaxValue, ByteOrder.BIG_ENDIAN)) + .to(Sink.fromSubscriber(substriber)) + .run() + + val subscription = substriber.expectSubscription() + + publisher.sendNext(groupedFrames) + publisher.sendComplete() + for (_ ← 1 to groupSize) { + subscription.request(1) + substriber.expectNext(single) + } + substriber.expectComplete() + subscription.cancel() + } + "report oversized frames" in { an[FramingException] should be thrownBy { Await.result( diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Framing.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Framing.scala index bb99e02666..b6527ad59f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Framing.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Framing.scala @@ -53,7 +53,7 @@ object Framing { maximumFrameLength: Int, byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN): Flow[ByteString, ByteString, NotUsed] = { require(fieldLength >= 1 && fieldLength <= 4, "Length field length must be 1, 2, 3 or 4.") - Flow[ByteString].transform(() ⇒ new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder)) + Flow[ByteString].via(new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder)) .named("lengthFieldFraming") } @@ -212,52 +212,77 @@ object Framing { val lengthFieldLength: Int, val lengthFieldOffset: Int, val maximumFrameLength: Int, - val byteOrder: ByteOrder) extends PushPullStage[ByteString, ByteString] { - private var buffer = ByteString.empty - private var frameSize = Int.MaxValue + val byteOrder: ByteOrder) extends GraphStage[FlowShape[ByteString, ByteString]] { private val minimumChunkSize = lengthFieldOffset + lengthFieldLength private val intDecoder = byteOrder match { case ByteOrder.BIG_ENDIAN ⇒ bigEndianDecoder case ByteOrder.LITTLE_ENDIAN ⇒ littleEndianDecoder } - private def tryPull(ctx: Context[ByteString]): SyncDirective = - if (ctx.isFinishing) ctx.fail(new FramingException("Stream finished but there was a truncated final frame in the buffer")) - else ctx.pull() + val in = Inlet[ByteString]("LengthFieldFramingStage.in") + val out = Outlet[ByteString]("LengthFieldFramingStage.out") + override val shape: FlowShape[ByteString, ByteString] = FlowShape(in, out) - override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = { - buffer ++= chunk - doParse(ctx) - } + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + private var buffer = ByteString.empty + private var frameSize = Int.MaxValue - override def onPull(ctx: Context[ByteString]): SyncDirective = doParse(ctx) - - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = - if (buffer.nonEmpty) ctx.absorbTermination() - else ctx.finish() - - private def doParse(ctx: Context[ByteString]): SyncDirective = { - def emitFrame(ctx: Context[ByteString]): SyncDirective = { - val parsedFrame = buffer.take(frameSize).compact + /** + * push, and reset frameSize and buffer + * + */ + private def pushFrame() = { + val emit = buffer.take(frameSize).compact buffer = buffer.drop(frameSize) frameSize = Int.MaxValue - if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame) - else ctx.push(parsedFrame) + push(out, emit) + if (buffer.isEmpty && isClosed(in)) { + completeStage() + } } - val bufSize = buffer.size - if (bufSize >= frameSize) emitFrame(ctx) - else if (bufSize >= minimumChunkSize) { - val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength) - frameSize = parsedLength + minimumChunkSize - if (frameSize > maximumFrameLength) - ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize")) - else if (bufSize >= frameSize) emitFrame(ctx) - else tryPull(ctx) - } else tryPull(ctx) - } + /** + * try to push downstream, if failed then try to pull upstream + * + */ + private def tryPushFrame() = { + val buffSize = buffer.size + if (buffSize >= frameSize) { + pushFrame() + } else if (buffSize >= minimumChunkSize) { + val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength) + frameSize = parsedLength + minimumChunkSize + if (frameSize > maximumFrameLength) { + failStage(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize")) + } else if (buffSize >= frameSize) { + pushFrame() + } else tryPull() + } else tryPull() + } - override def postStop(): Unit = buffer = null + private def tryPull() = { + if (isClosed(in)) { + failStage(new FramingException("Stream finished but there was a truncated final frame in the buffer")) + } else pull(in) + } + + override def onPush(): Unit = { + buffer ++= grab(in) + tryPushFrame() + } + + override def onPull() = tryPushFrame() + + override def onUpstreamFinish(): Unit = { + if (buffer.isEmpty) { + completeStage() + } else if (isAvailable(out)) { + tryPushFrame() + } // else swallow the termination and wait for pull + } + + setHandlers(in, out, this) + } } } diff --git a/project/MiMa.scala b/project/MiMa.scala index 6281cee41f..c9b0e291be 100644 --- a/project/MiMa.scala +++ b/project/MiMa.scala @@ -845,6 +845,13 @@ object MiMa extends AutoPlugin { ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.saveStateSnapshot"), ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout"), ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout_="), + + // #20345 converts LengthFieldFramingStage to GraphStage + ProblemFilters.exclude[MissingTypesProblem]("akka.stream.scaladsl.Framing$LengthFieldFramingStage"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPush"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onUpstreamFinish"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPull"), + ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.postStop") // #19834 ProblemFilters.exclude[MissingTypesProblem]("akka.stream.extra.Timed$StartTimed"),