diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala index 629c5c3ba5..3b9b8465b1 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala @@ -6,7 +6,6 @@ package akka.http.impl.util import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import akka.NotUsed -import akka.http.scaladsl.model.RequestEntity import akka.stream._ import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule } @@ -27,40 +26,29 @@ private[http] object StreamUtils { * input has been read it will call `finish` once to determine the final ByteString to post to the output. * Empty ByteStrings are discarded. */ - def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = { - new PushPullStage[ByteString, ByteString] { - override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = { - val data = f(element) - if (data.nonEmpty) ctx.push(data) - else ctx.pull() + def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): GraphStage[FlowShape[ByteString, ByteString]] = new SimpleLinearGraphStage[ByteString] { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPush(): Unit = { + val data = f(grab(in)) + if (data.nonEmpty) push(out, data) + else pull(in) } - override def onPull(ctx: Context[ByteString]): SyncDirective = - if (ctx.isFinishing) { - val data = finish() - if (data.nonEmpty) ctx.pushAndFinish(data) - else ctx.finish() - } else ctx.pull() + override def onPull(): Unit = pull(in) - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination() + override def onUpstreamFinish(): Unit = { + val data = finish() + if (data.nonEmpty) emit(out, data) + completeStage() + } + + setHandlers(in, out, this) } } def failedPublisher[T](ex: Throwable): Publisher[T] = impl.ErrorPublisher(ex, "failed").asInstanceOf[Publisher[T]] - def mapErrorTransformer(f: Throwable ⇒ Throwable): Flow[ByteString, ByteString, NotUsed] = { - val transformer = new PushStage[ByteString, ByteString] { - override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = - ctx.push(element) - - override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = - ctx.fail(f(cause)) - } - - Flow[ByteString].transform(() ⇒ transformer).named("transformError") - } - def captureTermination[T, Mat](source: Source[T, Mat]): (Source[T, Mat], Future[Unit]) = { val promise = Promise[Unit]() val transformer = new SimpleLinearGraphStage[T] { @@ -85,37 +73,32 @@ private[http] object StreamUtils { } def sliceBytesTransformer(start: Long, length: Long): Flow[ByteString, ByteString, NotUsed] = { - val transformer = new StatefulStage[ByteString, ByteString] { + val transformer = new SimpleLinearGraphStage[ByteString] { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPull() = pull(in) - def skipping = new State { var toSkip = start + var remaining = length - override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = - if (element.length < toSkip) { + override def onPush(): Unit = { + val element = grab(in) + if (toSkip > 0 && element.length < toSkip) { // keep skipping toSkip -= element.length - ctx.pull() + pull(in) } else { - become(taking(length)) // toSkip <= element.length <= Int.MaxValue - current.onPush(element.drop(toSkip.toInt), ctx) + val data = element.drop(toSkip.toInt).take(math.min(remaining, Int.MaxValue).toInt) + remaining -= data.size + push(out, data) + if (remaining <= 0) completeStage() } - } - - def taking(initiallyRemaining: Long) = new State { - var remaining: Long = initiallyRemaining - - override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = { - val data = element.take(math.min(remaining, Int.MaxValue).toInt) - remaining -= data.size - if (remaining <= 0) ctx.pushAndFinish(data) - else ctx.push(data) } - } - override def initial: State = if (start > 0) skipping else taking(length) + setHandlers(in, out, this) + } } - Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes") + Flow[ByteString].via(transformer).named("sliceBytes") } def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] = @@ -162,9 +145,6 @@ private[http] object StreamUtils { } } - def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity = - _.transformDataBytes(mapErrorTransformer(f)) - /** * Returns a source that can only be used once for testing purposes. */ diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/model/HttpEntitySpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/model/HttpEntitySpec.scala index c1653657a4..9290f71655 100755 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/model/HttpEntitySpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/model/HttpEntitySpec.scala @@ -204,7 +204,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll { } def duplicateBytesTransformer(): Flow[ByteString, ByteString, NotUsed] = - Flow[ByteString].transform(() ⇒ StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer)) + Flow[ByteString].via(StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer)) def trailer: ByteString = ByteString("--dup") def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*) diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala index f897ff0db4..4018c8a4e6 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala @@ -7,7 +7,8 @@ package akka.http.scaladsl.coding import akka.NotUsed import akka.http.scaladsl.model._ import akka.http.impl.util.StreamUtils -import akka.stream.stage.Stage +import akka.stream.FlowShape +import akka.stream.stage.GraphStage import akka.util.ByteString import headers._ import akka.stream.scaladsl.Flow @@ -23,15 +24,15 @@ trait Encoder { else message.self def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T = - mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer)) + mapper.transformDataBytes(t, Flow[ByteString].via(newEncodeTransformer)) def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input) - def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].transform(newEncodeTransformer) + def encoderFlow: Flow[ByteString, ByteString, NotUsed] = Flow[ByteString].via(newEncodeTransformer) def newCompressor: Compressor - def newEncodeTransformer(): Stage[ByteString, ByteString] = { + def newEncodeTransformer(): GraphStage[FlowShape[ByteString, ByteString]] = { val compressor = newCompressor def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes) diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala index 1b538e4047..ed9eb58657 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala @@ -7,10 +7,13 @@ package directives import scala.collection.immutable import scala.util.control.NonFatal -import akka.http.scaladsl.model.headers.{ HttpEncodings, HttpEncoding } +import akka.http.scaladsl.model.headers.{ HttpEncoding, HttpEncodings } import akka.http.scaladsl.model._ import akka.http.scaladsl.coding._ import akka.http.impl.util._ +import akka.stream.impl.fusing.GraphStages +import akka.stream.scaladsl.Flow +import akka.util.ByteString /** * @groupname coding Coding directives @@ -80,12 +83,14 @@ trait CodingDirectives { extractSettings flatMap { settings ⇒ val effectiveDecoder = decoder.withMaxBytesPerChunk(settings.decodeMaxBytesPerChunk) mapRequest { request ⇒ - effectiveDecoder.decode(request).mapEntity(StreamUtils.mapEntityError { - case NonFatal(e) ⇒ - IllegalRequestException( - StatusCodes.BadRequest, - ErrorInfo("The request's encoding is corrupt", e.getMessage)) - }) + effectiveDecoder.decode(request).mapEntity { entity ⇒ + entity.transformDataBytes(Flow[ByteString].recover { + case NonFatal(e) ⇒ + throw IllegalRequestException( + StatusCodes.BadRequest, + ErrorInfo("The request's encoding is corrupt", e.getMessage)) + }) + } } }