htc #19834 convert streamutils to GraphStages graphstage

* =htc #19834 replace byteStringTransformer with GraphStage implementation

* =htc #19834 replace mapErrorTransformer with GraphStage implementation

* htc #19834 replace sliceBytesTransformer with GraphStage implementation

* htc #19834 inline mapErrorTransformer with simple recover Flow
This commit is contained in:
Bernard Leach 2016-04-05 22:02:40 +10:00 committed by Konrad Malawski
parent ff47c854dc
commit d9f82d71eb
4 changed files with 47 additions and 61 deletions

View file

@ -6,7 +6,6 @@ package akka.http.impl.util
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import akka.NotUsed
import akka.http.scaladsl.model.RequestEntity
import akka.stream._
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
@ -27,40 +26,29 @@ private[http] object StreamUtils {
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
* Empty ByteStrings are discarded.
*/
def byteStringTransformer(f: ByteString ByteString, finish: () ByteString): Stage[ByteString, ByteString] = {
new PushPullStage[ByteString, ByteString] {
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
val data = f(element)
if (data.nonEmpty) ctx.push(data)
else ctx.pull()
def byteStringTransformer(f: ByteString ByteString, finish: () ByteString): GraphStage[FlowShape[ByteString, ByteString]] = new SimpleLinearGraphStage[ByteString] {
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
override def onPush(): Unit = {
val data = f(grab(in))
if (data.nonEmpty) push(out, data)
else pull(in)
}
override def onPull(ctx: Context[ByteString]): SyncDirective =
if (ctx.isFinishing) {
val data = finish()
if (data.nonEmpty) ctx.pushAndFinish(data)
else ctx.finish()
} else ctx.pull()
override def onPull(): Unit = pull(in)
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
override def onUpstreamFinish(): Unit = {
val data = finish()
if (data.nonEmpty) emit(out, data)
completeStage()
}
setHandlers(in, out, this)
}
}
def failedPublisher[T](ex: Throwable): Publisher[T] =
impl.ErrorPublisher(ex, "failed").asInstanceOf[Publisher[T]]
def mapErrorTransformer(f: Throwable Throwable): Flow[ByteString, ByteString, NotUsed] = {
val transformer = new PushStage[ByteString, ByteString] {
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
ctx.push(element)
override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective =
ctx.fail(f(cause))
}
Flow[ByteString].transform(() transformer).named("transformError")
}
def captureTermination[T, Mat](source: Source[T, Mat]): (Source[T, Mat], Future[Unit]) = {
val promise = Promise[Unit]()
val transformer = new SimpleLinearGraphStage[T] {
@ -85,37 +73,32 @@ private[http] object StreamUtils {
}
def sliceBytesTransformer(start: Long, length: Long): Flow[ByteString, ByteString, NotUsed] = {
val transformer = new StatefulStage[ByteString, ByteString] {
val transformer = new SimpleLinearGraphStage[ByteString] {
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
override def onPull() = pull(in)
def skipping = new State {
var toSkip = start
var remaining = length
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
if (element.length < toSkip) {
override def onPush(): Unit = {
val element = grab(in)
if (toSkip > 0 && element.length < toSkip) {
// keep skipping
toSkip -= element.length
ctx.pull()
pull(in)
} else {
become(taking(length))
// toSkip <= element.length <= Int.MaxValue
current.onPush(element.drop(toSkip.toInt), ctx)
val data = element.drop(toSkip.toInt).take(math.min(remaining, Int.MaxValue).toInt)
remaining -= data.size
push(out, data)
if (remaining <= 0) completeStage()
}
}
def taking(initiallyRemaining: Long) = new State {
var remaining: Long = initiallyRemaining
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
val data = element.take(math.min(remaining, Int.MaxValue).toInt)
remaining -= data.size
if (remaining <= 0) ctx.pushAndFinish(data)
else ctx.push(data)
}
}
override def initial: State = if (start > 0) skipping else taking(length)
setHandlers(in, out, this)
}
}
Flow[ByteString].transform(() transformer).named("sliceBytes")
Flow[ByteString].via(transformer).named("sliceBytes")
}
def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
@ -162,9 +145,6 @@ private[http] object StreamUtils {
}
}
def mapEntityError(f: Throwable Throwable): RequestEntity RequestEntity =
_.transformDataBytes(mapErrorTransformer(f))
/**
* Returns a source that can only be used once for testing purposes.
*/

View file

@ -204,7 +204,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll {
}
def duplicateBytesTransformer(): Flow[ByteString, ByteString, NotUsed] =
Flow[ByteString].transform(() StreamUtils.byteStringTransformer(doubleChars, () trailer))
Flow[ByteString].via(StreamUtils.byteStringTransformer(doubleChars, () trailer))
def trailer: ByteString = ByteString("--dup")
def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b Seq(b, b)): _*)

View file

@ -7,7 +7,8 @@ package akka.http.scaladsl.coding
import akka.NotUsed
import akka.http.scaladsl.model._
import akka.http.impl.util.StreamUtils
import akka.stream.stage.Stage
import akka.stream.FlowShape
import akka.stream.stage.GraphStage
import akka.util.ByteString
import headers._
import akka.stream.scaladsl.Flow
@ -23,15 +24,15 @@ trait Encoder {
else message.self
def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T =
mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer))
mapper.transformDataBytes(t, Flow[ByteString].via(newEncodeTransformer))
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].transform(newEncodeTransformer)
def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].via(newEncodeTransformer)
def newCompressor: Compressor
def newEncodeTransformer(): Stage[ByteString, ByteString] = {
def newEncodeTransformer(): GraphStage[FlowShape[ByteString, ByteString]] = {
val compressor = newCompressor
def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes)

View file

@ -7,10 +7,13 @@ package directives
import scala.collection.immutable
import scala.util.control.NonFatal
import akka.http.scaladsl.model.headers.{ HttpEncodings, HttpEncoding }
import akka.http.scaladsl.model.headers.{ HttpEncoding, HttpEncodings }
import akka.http.scaladsl.model._
import akka.http.scaladsl.coding._
import akka.http.impl.util._
import akka.stream.impl.fusing.GraphStages
import akka.stream.scaladsl.Flow
import akka.util.ByteString
/**
* @groupname coding Coding directives
@ -80,12 +83,14 @@ trait CodingDirectives {
extractSettings flatMap { settings
val effectiveDecoder = decoder.withMaxBytesPerChunk(settings.decodeMaxBytesPerChunk)
mapRequest { request
effectiveDecoder.decode(request).mapEntity(StreamUtils.mapEntityError {
case NonFatal(e)
IllegalRequestException(
StatusCodes.BadRequest,
ErrorInfo("The request's encoding is corrupt", e.getMessage))
})
effectiveDecoder.decode(request).mapEntity { entity
entity.transformDataBytes(Flow[ByteString].recover {
case NonFatal(e)
throw IllegalRequestException(
StatusCodes.BadRequest,
ErrorInfo("The request's encoding is corrupt", e.getMessage))
})
}
}
}