converts LengthFieldFramingStage to GraphStage

This commit is contained in:
zhxiaog 2016-04-17 23:01:46 +08:00 committed by Konrad Malawski
parent db83775c6d
commit 11584f12eb
3 changed files with 94 additions and 34 deletions

View file

@ -7,6 +7,7 @@ import java.nio.ByteOrder
import akka.stream.scaladsl.Framing.FramingException
import akka.stream.stage.{ Context, PushPullStage, SyncDirective, TerminationDirective }
import akka.stream.testkit.{ TestSubscriber, TestPublisher }
import akka.testkit.AkkaSpec
import akka.stream.{ ActorMaterializer, ActorMaterializerSettings }
import akka.util.{ ByteString, ByteStringBuilder }
@ -180,6 +181,33 @@ class FramingSpec extends AkkaSpec {
3.seconds) should ===(Vector.empty)
}
"work with grouped frames" in {
val groupSize = 5
val single = encode(referenceChunk.take(100), 0, 1, ByteOrder.BIG_ENDIAN)
val groupedFrames = (1 to groupSize)
.map(_ single)
.fold(ByteString.empty)((result, bs) result ++ bs)
val publisher = TestPublisher.probe[ByteString]()
val substriber = TestSubscriber.manualProbe[ByteString]()
Source.fromPublisher(publisher)
.via(Framing.lengthField(1, 0, Int.MaxValue, ByteOrder.BIG_ENDIAN))
.to(Sink.fromSubscriber(substriber))
.run()
val subscription = substriber.expectSubscription()
publisher.sendNext(groupedFrames)
publisher.sendComplete()
for (_ 1 to groupSize) {
subscription.request(1)
substriber.expectNext(single)
}
substriber.expectComplete()
subscription.cancel()
}
"report oversized frames" in {
an[FramingException] should be thrownBy {
Await.result(

View file

@ -53,7 +53,7 @@ object Framing {
maximumFrameLength: Int,
byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN): Flow[ByteString, ByteString, NotUsed] = {
require(fieldLength >= 1 && fieldLength <= 4, "Length field length must be 1, 2, 3 or 4.")
Flow[ByteString].transform(() new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder))
Flow[ByteString].via(new LengthFieldFramingStage(fieldLength, fieldOffset, maximumFrameLength, byteOrder))
.named("lengthFieldFraming")
}
@ -212,52 +212,77 @@ object Framing {
val lengthFieldLength: Int,
val lengthFieldOffset: Int,
val maximumFrameLength: Int,
val byteOrder: ByteOrder) extends PushPullStage[ByteString, ByteString] {
private var buffer = ByteString.empty
private var frameSize = Int.MaxValue
val byteOrder: ByteOrder) extends GraphStage[FlowShape[ByteString, ByteString]] {
private val minimumChunkSize = lengthFieldOffset + lengthFieldLength
private val intDecoder = byteOrder match {
case ByteOrder.BIG_ENDIAN bigEndianDecoder
case ByteOrder.LITTLE_ENDIAN littleEndianDecoder
}
private def tryPull(ctx: Context[ByteString]): SyncDirective =
if (ctx.isFinishing) ctx.fail(new FramingException("Stream finished but there was a truncated final frame in the buffer"))
else ctx.pull()
val in = Inlet[ByteString]("LengthFieldFramingStage.in")
val out = Outlet[ByteString]("LengthFieldFramingStage.out")
override val shape: FlowShape[ByteString, ByteString] = FlowShape(in, out)
override def onPush(chunk: ByteString, ctx: Context[ByteString]): SyncDirective = {
buffer ++= chunk
doParse(ctx)
}
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
private var buffer = ByteString.empty
private var frameSize = Int.MaxValue
override def onPull(ctx: Context[ByteString]): SyncDirective = doParse(ctx)
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
if (buffer.nonEmpty) ctx.absorbTermination()
else ctx.finish()
private def doParse(ctx: Context[ByteString]): SyncDirective = {
def emitFrame(ctx: Context[ByteString]): SyncDirective = {
val parsedFrame = buffer.take(frameSize).compact
/**
* push, and reset frameSize and buffer
*
*/
private def pushFrame() = {
val emit = buffer.take(frameSize).compact
buffer = buffer.drop(frameSize)
frameSize = Int.MaxValue
if (ctx.isFinishing && buffer.isEmpty) ctx.pushAndFinish(parsedFrame)
else ctx.push(parsedFrame)
push(out, emit)
if (buffer.isEmpty && isClosed(in)) {
completeStage()
}
}
val bufSize = buffer.size
if (bufSize >= frameSize) emitFrame(ctx)
else if (bufSize >= minimumChunkSize) {
val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength)
frameSize = parsedLength + minimumChunkSize
if (frameSize > maximumFrameLength)
ctx.fail(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize"))
else if (bufSize >= frameSize) emitFrame(ctx)
else tryPull(ctx)
} else tryPull(ctx)
}
/**
* try to push downstream, if failed then try to pull upstream
*
*/
private def tryPushFrame() = {
val buffSize = buffer.size
if (buffSize >= frameSize) {
pushFrame()
} else if (buffSize >= minimumChunkSize) {
val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength)
frameSize = parsedLength + minimumChunkSize
if (frameSize > maximumFrameLength) {
failStage(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $frameSize"))
} else if (buffSize >= frameSize) {
pushFrame()
} else tryPull()
} else tryPull()
}
override def postStop(): Unit = buffer = null
private def tryPull() = {
if (isClosed(in)) {
failStage(new FramingException("Stream finished but there was a truncated final frame in the buffer"))
} else pull(in)
}
override def onPush(): Unit = {
buffer ++= grab(in)
tryPushFrame()
}
override def onPull() = tryPushFrame()
override def onUpstreamFinish(): Unit = {
if (buffer.isEmpty) {
completeStage()
} else if (isAvailable(out)) {
tryPushFrame()
} // else swallow the termination and wait for pull
}
setHandlers(in, out, this)
}
}
}

View file

@ -845,6 +845,13 @@ object MiMa extends AutoPlugin {
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.saveStateSnapshot"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.persistence.fsm.PersistentFSM.akka$persistence$fsm$PersistentFSM$$currentStateTimeout_="),
// #20345 converts LengthFieldFramingStage to GraphStage
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.scaladsl.Framing$LengthFieldFramingStage"),
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPush"),
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onUpstreamFinish"),
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.onPull"),
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.scaladsl.Framing#LengthFieldFramingStage.postStop")
// #19834
ProblemFilters.exclude[MissingTypesProblem]("akka.stream.extra.Timed$StartTimed"),