htc #19834 convert streamutils to GraphStages graphstage
* =htc #19834 replace byteStringTransformer with GraphStage implementation * =htc #19834 replace mapErrorTransformer with GraphStage implementation * htc #19834 replace sliceBytesTransformer with GraphStage implementation * htc #19834 inline mapErrorTransformer with simple recover Flow
This commit is contained in:
parent
ff47c854dc
commit
d9f82d71eb
4 changed files with 47 additions and 61 deletions
|
|
@ -6,7 +6,6 @@ package akka.http.impl.util
|
|||
|
||||
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
|
||||
import akka.NotUsed
|
||||
import akka.http.scaladsl.model.RequestEntity
|
||||
import akka.stream._
|
||||
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||
import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
|
||||
|
|
@ -27,40 +26,29 @@ private[http] object StreamUtils {
|
|||
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
|
||||
* Empty ByteStrings are discarded.
|
||||
*/
|
||||
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = {
|
||||
new PushPullStage[ByteString, ByteString] {
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
val data = f(element)
|
||||
if (data.nonEmpty) ctx.push(data)
|
||||
else ctx.pull()
|
||||
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): GraphStage[FlowShape[ByteString, ByteString]] = new SimpleLinearGraphStage[ByteString] {
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
override def onPush(): Unit = {
|
||||
val data = f(grab(in))
|
||||
if (data.nonEmpty) push(out, data)
|
||||
else pull(in)
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): SyncDirective =
|
||||
if (ctx.isFinishing) {
|
||||
val data = finish()
|
||||
if (data.nonEmpty) ctx.pushAndFinish(data)
|
||||
else ctx.finish()
|
||||
} else ctx.pull()
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
val data = finish()
|
||||
if (data.nonEmpty) emit(out, data)
|
||||
completeStage()
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
|
||||
def failedPublisher[T](ex: Throwable): Publisher[T] =
|
||||
impl.ErrorPublisher(ex, "failed").asInstanceOf[Publisher[T]]
|
||||
|
||||
def mapErrorTransformer(f: Throwable ⇒ Throwable): Flow[ByteString, ByteString, NotUsed] = {
|
||||
val transformer = new PushStage[ByteString, ByteString] {
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
|
||||
ctx.push(element)
|
||||
|
||||
override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective =
|
||||
ctx.fail(f(cause))
|
||||
}
|
||||
|
||||
Flow[ByteString].transform(() ⇒ transformer).named("transformError")
|
||||
}
|
||||
|
||||
def captureTermination[T, Mat](source: Source[T, Mat]): (Source[T, Mat], Future[Unit]) = {
|
||||
val promise = Promise[Unit]()
|
||||
val transformer = new SimpleLinearGraphStage[T] {
|
||||
|
|
@ -85,37 +73,32 @@ private[http] object StreamUtils {
|
|||
}
|
||||
|
||||
def sliceBytesTransformer(start: Long, length: Long): Flow[ByteString, ByteString, NotUsed] = {
|
||||
val transformer = new StatefulStage[ByteString, ByteString] {
|
||||
val transformer = new SimpleLinearGraphStage[ByteString] {
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
override def onPull() = pull(in)
|
||||
|
||||
def skipping = new State {
|
||||
var toSkip = start
|
||||
var remaining = length
|
||||
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
|
||||
if (element.length < toSkip) {
|
||||
override def onPush(): Unit = {
|
||||
val element = grab(in)
|
||||
if (toSkip > 0 && element.length < toSkip) {
|
||||
// keep skipping
|
||||
toSkip -= element.length
|
||||
ctx.pull()
|
||||
pull(in)
|
||||
} else {
|
||||
become(taking(length))
|
||||
// toSkip <= element.length <= Int.MaxValue
|
||||
current.onPush(element.drop(toSkip.toInt), ctx)
|
||||
val data = element.drop(toSkip.toInt).take(math.min(remaining, Int.MaxValue).toInt)
|
||||
remaining -= data.size
|
||||
push(out, data)
|
||||
if (remaining <= 0) completeStage()
|
||||
}
|
||||
}
|
||||
|
||||
def taking(initiallyRemaining: Long) = new State {
|
||||
var remaining: Long = initiallyRemaining
|
||||
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
val data = element.take(math.min(remaining, Int.MaxValue).toInt)
|
||||
remaining -= data.size
|
||||
if (remaining <= 0) ctx.pushAndFinish(data)
|
||||
else ctx.push(data)
|
||||
}
|
||||
}
|
||||
|
||||
override def initial: State = if (start > 0) skipping else taking(length)
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
}
|
||||
Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes")
|
||||
Flow[ByteString].via(transformer).named("sliceBytes")
|
||||
}
|
||||
|
||||
def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
|
||||
|
|
@ -162,9 +145,6 @@ private[http] object StreamUtils {
|
|||
}
|
||||
}
|
||||
|
||||
def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity =
|
||||
_.transformDataBytes(mapErrorTransformer(f))
|
||||
|
||||
/**
|
||||
* Returns a source that can only be used once for testing purposes.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll {
|
|||
}
|
||||
|
||||
def duplicateBytesTransformer(): Flow[ByteString, ByteString, NotUsed] =
|
||||
Flow[ByteString].transform(() ⇒ StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer))
|
||||
Flow[ByteString].via(StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer))
|
||||
|
||||
def trailer: ByteString = ByteString("--dup")
|
||||
def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ package akka.http.scaladsl.coding
|
|||
import akka.NotUsed
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.impl.util.StreamUtils
|
||||
import akka.stream.stage.Stage
|
||||
import akka.stream.FlowShape
|
||||
import akka.stream.stage.GraphStage
|
||||
import akka.util.ByteString
|
||||
import headers._
|
||||
import akka.stream.scaladsl.Flow
|
||||
|
|
@ -23,15 +24,15 @@ trait Encoder {
|
|||
else message.self
|
||||
|
||||
def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
|
||||
mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer))
|
||||
mapper.transformDataBytes(t, Flow[ByteString].via(newEncodeTransformer))
|
||||
|
||||
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
|
||||
|
||||
def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].transform(newEncodeTransformer)
|
||||
def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].via(newEncodeTransformer)
|
||||
|
||||
def newCompressor: Compressor
|
||||
|
||||
def newEncodeTransformer(): Stage[ByteString, ByteString] = {
|
||||
def newEncodeTransformer(): GraphStage[FlowShape[ByteString, ByteString]] = {
|
||||
val compressor = newCompressor
|
||||
|
||||
def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,13 @@ package directives
|
|||
|
||||
import scala.collection.immutable
|
||||
import scala.util.control.NonFatal
|
||||
import akka.http.scaladsl.model.headers.{ HttpEncodings, HttpEncoding }
|
||||
import akka.http.scaladsl.model.headers.{ HttpEncoding, HttpEncodings }
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.coding._
|
||||
import akka.http.impl.util._
|
||||
import akka.stream.impl.fusing.GraphStages
|
||||
import akka.stream.scaladsl.Flow
|
||||
import akka.util.ByteString
|
||||
|
||||
/**
|
||||
* @groupname coding Coding directives
|
||||
|
|
@ -80,12 +83,14 @@ trait CodingDirectives {
|
|||
extractSettings flatMap { settings ⇒
|
||||
val effectiveDecoder = decoder.withMaxBytesPerChunk(settings.decodeMaxBytesPerChunk)
|
||||
mapRequest { request ⇒
|
||||
effectiveDecoder.decode(request).mapEntity(StreamUtils.mapEntityError {
|
||||
case NonFatal(e) ⇒
|
||||
IllegalRequestException(
|
||||
StatusCodes.BadRequest,
|
||||
ErrorInfo("The request's encoding is corrupt", e.getMessage))
|
||||
})
|
||||
effectiveDecoder.decode(request).mapEntity { entity ⇒
|
||||
entity.transformDataBytes(Flow[ByteString].recover {
|
||||
case NonFatal(e) ⇒
|
||||
throw IllegalRequestException(
|
||||
StatusCodes.BadRequest,
|
||||
ErrorInfo("The request's encoding is corrupt", e.getMessage))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue