converts LengthFieldFramingStage to GraphStage
This commit is contained in:
parent
db83775c6d
commit
11584f12eb
3 changed files with 94 additions and 34 deletions
|
|
@ -7,6 +7,7 @@ import java.nio.ByteOrder
|
|||
|
||||
import akka.stream.scaladsl.Framing.FramingException
|
||||
import akka.stream.stage.{ Context, PushPullStage, SyncDirective, TerminationDirective }
|
||||
import akka.stream.testkit.{ TestSubscriber, TestPublisher }
|
||||
import akka.testkit.AkkaSpec
|
||||
import akka.stream.{ ActorMaterializer, ActorMaterializerSettings }
|
||||
import akka.util.{ ByteString, ByteStringBuilder }
|
||||
|
|
@ -180,6 +181,33 @@ class FramingSpec extends AkkaSpec {
|
|||
3.seconds) should ===(Vector.empty)
|
||||
}
|
||||
|
||||
"work with grouped frames" in {
|
||||
val groupSize = 5
|
||||
val single = encode(referenceChunk.take(100), 0, 1, ByteOrder.BIG_ENDIAN)
|
||||
val groupedFrames = (1 to groupSize)
|
||||
.map(_ ⇒ single)
|
||||
.fold(ByteString.empty)((result, bs) ⇒ result ++ bs)
|
||||
|
||||
val publisher = TestPublisher.probe[ByteString]()
|
||||
val substriber = TestSubscriber.manualProbe[ByteString]()
|
||||
|
||||
Source.fromPublisher(publisher)
|
||||
.via(Framing.lengthField(1, 0, Int.MaxValue, ByteOrder.BIG_ENDIAN))
|
||||
.to(Sink.fromSubscriber(substriber))
|
||||
.run()
|
||||
|
||||
val subscription = substriber.expectSubscription()
|
||||
|
||||
publisher.sendNext(groupedFrames)
|
||||
publisher.sendComplete()
|
||||
for (_ ← 1 to groupSize) {
|
||||
subscription.request(1)
|
||||
substriber.expectNext(single)
|
||||
}
|
||||
substriber.expectComplete()
|
||||
subscription.cancel()
|
||||
}
|
||||
|
||||
"report oversized frames" in {
|
||||
an[FramingException] should be thrownBy {
|
||||
Await.result(
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ object Framing {
|
|||
maximumFrameLength: Int,
|
||||
byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN): Flow[ByteString, ByteString, NotUsed] = {
|
||||
require(fieldLength >= 1 && fieldLength <= 4, "Length field length must be 1, 2, 3 or 4.")
|
||||
Flow[ByteString].transform(() ⇒ new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder))
|
||||
Flow[ByteString].via(new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder))
|
||||
.named("lengthFieldFraming")
|
||||
}
|
||||
|
||||
|
|
@ -212,52 +212,77 @@ object Framing {
|
|||
val lengthFieldLength: Int,
|
||||
val lengthFieldOffset: Int,
|
||||
val maximumFrameLength: Int,
|
||||
val byteOrder: ByteOrder) extends PushPullStage[ByteString, ByteString] {
|
||||
private var buffer = ByteString.empty
|
||||
private var frameSize = Int.MaxValue
|
||||
val byteOrder: ByteOrder) extends GraphStage[FlowShape[ByteString, ByteString]] {
|
||||
private val minimumChunkSize = lengthFieldOffset + lengthFieldLength
|
||||
private val intDecoder = byteOrder match {
|
||||
case ByteOrder.BIG_ENDIAN ⇒ bigEndianDecoder
|
||||
case ByteOrder.LITTLE_ENDIAN ⇒ littleEndianDecoder
|
||||
}
|
||||
|
||||
private def tryPull(ctx: Context[ByteString]): SyncDirective =
|
||||
if (ctx.isFinishing) ctx.fail(new FramingException("Stream finished but there was a truncated final frame in the buffer"))
|
||||
else ctx.pull()
|
||||
val in = Inlet[ByteString]("LengthFieldFramingStage.in")
|
||||
val out = Outlet[ByteString]("LengthFieldFramingStage.out")
|
||||
override val shape: FlowShape[ByteString, ByteString] = FlowShape(in, out)
|
||||
|
||||
override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
buffer ++= chunk
|
||||
doParse(ctx)
|
||||
}
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
private var buffer = ByteString.empty
|
||||
private var frameSize = Int.MaxValue
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): SyncDirective = doParse(ctx)
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
|
||||
if (buffer.nonEmpty) ctx.absorbTermination()
|
||||
else ctx.finish()
|
||||
|
||||
private def doParse(ctx: Context[ByteString]): SyncDirective = {
|
||||
def emitFrame(ctx: Context[ByteString]): SyncDirective = {
|
||||
val parsedFrame = buffer.take(frameSize).compact
|
||||
/**
|
||||
* push, and reset frameSize and buffer
|
||||
*
|
||||
*/
|
||||
private def pushFrame() = {
|
||||
val emit = buffer.take(frameSize).compact
|
||||
buffer = buffer.drop(frameSize)
|
||||
frameSize = Int.MaxValue
|
||||
if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame)
|
||||
else ctx.push(parsedFrame)
|
||||
push(out, emit)
|
||||
if (buffer.isEmpty && isClosed(in)) {
|
||||
completeStage()
|
||||
}
|
||||
}
|
||||
|
||||
val bufSize = buffer.size
|
||||
if (bufSize >= frameSize) emitFrame(ctx)
|
||||
else if (bufSize >= minimumChunkSize) {
|
||||
val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength)
|
||||
frameSize = parsedLength + minimumChunkSize
|
||||
if (frameSize > maximumFrameLength)
|
||||
ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize"))
|
||||
else if (bufSize >= frameSize) emitFrame(ctx)
|
||||
else tryPull(ctx)
|
||||
} else tryPull(ctx)
|
||||
}
|
||||
/**
|
||||
* try to push downstream, if failed then try to pull upstream
|
||||
*
|
||||
*/
|
||||
private def tryPushFrame() = {
|
||||
val buffSize = buffer.size
|
||||
if (buffSize >= frameSize) {
|
||||
pushFrame()
|
||||
} else if (buffSize >= minimumChunkSize) {
|
||||
val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength)
|
||||
frameSize = parsedLength + minimumChunkSize
|
||||
if (frameSize > maximumFrameLength) {
|
||||
failStage(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize"))
|
||||
} else if (buffSize >= frameSize) {
|
||||
pushFrame()
|
||||
} else tryPull()
|
||||
} else tryPull()
|
||||
}
|
||||
|
||||
override def postStop(): Unit = buffer = null
|
||||
private def tryPull() = {
|
||||
if (isClosed(in)) {
|
||||
failStage(new FramingException("Stream finished but there was a truncated final frame in the buffer"))
|
||||
} else pull(in)
|
||||
}
|
||||
|
||||
override def onPush(): Unit = {
|
||||
buffer ++= grab(in)
|
||||
tryPushFrame()
|
||||
}
|
||||
|
||||
override def onPull() = tryPushFrame()
|
||||
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
if (buffer.isEmpty) {
|
||||
completeStage()
|
||||
} else if (isAvailable(out)) {
|
||||
tryPushFrame()
|
||||
} // else swallow the termination and wait for pull
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -845,6 +845,13 @@ object MiMa extends AutoPlugin {
|
|||
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.saveStateSnapshot"),
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout"),
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout_="),
|
||||
|
||||
// #20345 converts LengthFieldFramingStage to GraphStage
|
||||
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.scaladsl.Framing$LengthFieldFramingStage"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPush"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onUpstreamFinish"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPull"),
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.postStop")
|
||||
|
||||
// #19834
|
||||
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.extra.Timed$StartTimed"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue