diff --git a/akka-docs-dev/rst/scala/http/common/http-model.rst b/akka-docs-dev/rst/scala/http/common/http-model.rst index 152a0b764b..d2d0e5c283 100644 --- a/akka-docs-dev/rst/scala/http/common/http-model.rst +++ b/akka-docs-dev/rst/scala/http/common/http-model.rst @@ -156,6 +156,43 @@ concrete subtype. Therefore you must make sure that you always consume the entity data, even in the case that you are not actually interested in it! + +Limiting message entity length +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +All message entities that Akka HTTP reads from the network automatically get a length verification check attached to +them. This check makes sure that the total entity size is less than or equal to the configured +``max-content-length`` [#]_, which is an important defense against certain Denial-of-Service attacks. +However, a single global limit for all requests (or responses) is often too inflexible for applications that need to +allow large limits for *some* requests (or responses) but want to clamp down on all messages not belonging into that +group. +In order to give you maximum flexibility in defining entity size limits according to your needs the ``HttpEntity`` +features a ``withSizeLimit`` method, which lets you adjust the globally configured maximum size for this particular +entity, be it to increase or decrease any previously set value. +This means that your application will receive all requests (or responses) from the HTTP layer, even the ones whose +``Content-Length`` exceeds the configured limit (because you might want to increase the limit yourself). +Only when the actual data stream ``Source`` contained in the entity is materialized will the boundary checks be +actually applied. In case the length verification fails the respective stream will be terminated with an +``EntityStreamException`` either directly at materialization time (if the ``Content-Length`` is known) or whenever more +data bytes than allowed have been read. + +When called on ``Strict`` entities the ``withSizeLimit`` method will return the entity itself if the length is within +the bound, otherwise a ``Default`` entity with a single element data stream. This allows for potential refinement of the +entity size limit at a later point (before materialization of the data stream). + +By default all message entities produced by the HTTP layer automatically carry the limit that is defined in the +application's ``max-content-length`` config setting. If the entity is transformed in a way that changes the +content-length and then another limit is applied then this new limit will be evaluated against the new +content-length. If the entity is transformed in a way that changes the content-length and no new limit is applied +then the previous limit will be applied against the previous content-length. +Generally this behavior should be in line with your expectations. + +.. [#] `akka.http.parsing.max-content-length` (applying to server- as well as client-side), + `akka.http.server.parsing.max-content-length` (server-side only), + `akka.http.client.parsing.max-content-length` (client-side only) or + `akka.http.host-connection-pool.client.parsing.max-content-length` (only host-connection-pools) + + Special processing for HEAD requests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index ccd0af28a3..12b94bfb90 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -7,7 +7,7 @@ package akka.http.impl.engine.client import language.existentials import scala.annotation.tailrec import scala.collection.mutable.ListBuffer -import akka.stream.io.{ SessionBytes, SslTlsInbound, SendBytes, SslTlsOutbound } +import akka.stream.io.{ SessionBytes, SslTlsInbound, SendBytes } import akka.util.ByteString import akka.event.LoggingAdapter import akka.stream._ @@ -71,7 +71,8 @@ private[http] object OutgoingConnectionBlueprint { .via(headAndTailFlow) .collect { case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts) ⇒ - HttpResponse(statusCode, headers, createEntity(entityParts), protocol) + val entity = createEntity(entityParts) withSizeLimit parserSettings.maxContentLength + HttpResponse(statusCode, headers, entity, protocol) case (MessageStartError(_, info), _) ⇒ throw IllegalResponseException(info) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala index 9aa3dbde5d..1947f58c7a 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala @@ -208,9 +208,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser } def parseChunkBody(chunkSize: Int, extension: String, cursor: Int): StateResult = - if (totalBytesRead + chunkSize > settings.maxContentLength) - failWithChunkedEntityTooLong(totalBytesRead + chunkSize) - else if (chunkSize > 0) { + if (chunkSize > 0) { val chunkBodyEnd = cursor + chunkSize def result(terminatorLen: Int) = { emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd).compact, extension))) @@ -285,7 +283,6 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser setCompletionHandling(CompletionOk) terminate() } - def failWithChunkedEntityTooLong(totalBytesRead: Long): StateResult def terminate(): StateResult = { terminated = true @@ -317,7 +314,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser case EntityPart(bytes) ⇒ bytes case EntityStreamError(info) ⇒ throw EntityStreamException(info) } - HttpEntity.Default(contentType(cth), contentLength, transformData(data)) + HttpEntity.Default(contentType(cth), contentLength, HttpEntity.limitableByteSource(transformData(data))) } def chunkedEntity(cth: Option[`Content-Type`], @@ -326,7 +323,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser case EntityChunk(chunk) ⇒ chunk case EntityStreamError(info) ⇒ throw EntityStreamException(info) } - HttpEntity.Chunked(contentType(cth), transformChunks(chunks)) + HttpEntity.Chunked(contentType(cth), HttpEntity.limitableChunkSource(transformChunks(chunks))) } def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index f69aebea9e..3301b13a68 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -160,11 +160,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, case Some(`Content-Length`(len)) ⇒ len case None ⇒ 0 } - if (contentLength > maxContentLength) - failMessageStart(RequestEntityTooLarge, - summary = s"Request Content-Length of $contentLength bytes exceeds the configured limit of $maxContentLength bytes", - detail = "Consider increasing the value of akka.http.server.parsing.max-content-length") - else if (contentLength == 0) { + if (contentLength == 0) { emitRequestStart(emptyEntity(cth)) setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart) @@ -194,10 +190,4 @@ private[http] class HttpRequestParser(_settings: ParserSettings, expect100continue, hostHeaderPresent, closeAfterResponseCompletion) } } else failMessageStart("Request is missing required `Host` header") - - def failWithChunkedEntityTooLong(totalBytesRead: Long): StateResult = - failEntityStream( - summary = s"Aggregated data length of chunked request entity of $totalBytesRead " + - s"bytes exceeds the configured limit of $maxContentLength bytes", - detail = "Consider increasing the value of akka.http.server.parsing.max-content-length") } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala index 1c65e7a93f..bbdf4ce020 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala @@ -98,11 +98,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: teh match { case None ⇒ clh match { case Some(`Content-Length`(contentLength)) ⇒ - if (contentLength > maxContentLength) - failMessageStart( - summary = s"Response Content-Length of $contentLength bytes exceeds the configured limit of $maxContentLength bytes", - detail = "Consider increasing the value of akka.http.client.parsing.max-content-length") - else if (contentLength == 0) finishEmptyResponse() + if (contentLength == 0) finishEmptyResponse() else if (contentLength <= input.size - bodyStart) { val cl = contentLength.toInt emitResponseStart(strictEntity(cth, input, bodyStart, cl)) @@ -116,7 +112,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: case None ⇒ emitResponseStart { entityParts ⇒ val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes } - HttpEntity.CloseDelimited(contentType(cth), data) + HttpEntity.CloseDelimited(contentType(cth), HttpEntity.limitableByteSource(data)) } setCompletionHandling(HttpMessageParser.CompletionOk) parseToCloseBody(input, bodyStart, totalBytesRead = 0) @@ -135,25 +131,10 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: } else finishEmptyResponse() } - // currently we do not check for `settings.maxContentLength` overflow def parseToCloseBody(input: ByteString, bodyStart: Int, totalBytesRead: Long): StateResult = { val newTotalBytes = totalBytesRead + math.max(0, input.length - bodyStart) - if (newTotalBytes > settings.maxContentLength) - failEntityStream( - summary = s"Aggregated data length of close-delimited response entity of $newTotalBytes " + - s"bytes exceeds the configured limit of $maxContentLength bytes", - detail = "Consider increasing the value of akka.http.client.parsing.max-content-length") - else { - if (input.length > bodyStart) - emit(EntityPart(input.drop(bodyStart).compact)) - continue(parseToCloseBody(_, _, newTotalBytes)) - } + if (input.length > bodyStart) + emit(EntityPart(input.drop(bodyStart).compact)) + continue(parseToCloseBody(_, _, newTotalBytes)) } - - def failWithChunkedEntityTooLong(totalBytesRead: Long): StateResult = - failEntityStream( - summary = s"Aggregated data length of chunked response entity of $totalBytesRead " + - s"bytes exceeds the configured limit of $maxContentLength bytes", - detail = "Consider increasing the value of akka.http.client.parsing.max-content-length") - } \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 71346f4b4c..cf1f36a537 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -6,7 +6,8 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress import java.util.Random - +import org.reactivestreams.{ Publisher, Subscriber } +import scala.util.control.NonFatal import akka.actor.{ ActorRef, Deploy, Props } import akka.event.LoggingAdapter import akka.http.ServerSettings @@ -25,9 +26,6 @@ import akka.stream.io._ import akka.stream.scaladsl._ import akka.stream.stage._ import akka.util.ByteString -import org.reactivestreams.{ Publisher, Subscriber } - -import scala.util.control.NonFatal /** * INTERNAL API @@ -86,7 +84,8 @@ private[http] object HttpServerBluePrint { headers.`Remote-Address`(RemoteAddress(remoteAddress.get)) +: hdrs else hdrs - HttpRequest(effectiveMethod, uri, effectiveHeaders, createEntity(entityParts), protocol) + val entity = createEntity(entityParts) withSizeLimit parserSettings.maxContentLength + HttpRequest(effectiveMethod, uri, effectiveHeaders, entity, protocol) case (_, src) ⇒ src.runWith(Sink.ignore) }.collect { case r: HttpRequest ⇒ r @@ -107,7 +106,7 @@ private[http] object HttpServerBluePrint { val rendererPipeline = Flow[ResponseRenderingContext] - .via(Flow[ResponseRenderingContext].transform(() ⇒ new ErrorsTo500ResponseRecovery(log)).named("recover")) // FIXME: simplify after #16394 is closed + .recover(errorResponseRecovery(log, settings)) .via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer")) .flatMapConcat(ConstantFun.scalaIdentityFunction) .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger")) @@ -245,6 +244,25 @@ private[http] object HttpServerBluePrint { } } + def errorResponseRecovery(log: LoggingAdapter, + settings: ServerSettings): PartialFunction[Throwable, ResponseRenderingContext] = { + case EntityStreamSizeException(limit, contentLength) ⇒ + val status = StatusCodes.RequestEntityTooLarge + val summary = contentLength match { + case Some(cl) ⇒ s"Request Content-Length of $cl bytes exceeds the configured limit of $limit bytes" + case None ⇒ s"Aggregated data length of request entity exceeds the configured limit of $limit bytes" + } + val info = ErrorInfo(summary, "Consider increasing the value of akka.http.server.parsing.max-content-length") + logParsingError(info withSummaryPrepended s"Illegal request, responding with status '$status'", + log, settings.parserSettings.errorLoggingVerbosity) + val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary + ResponseRenderingContext(HttpResponse(status, entity = msg), closeRequested = true) + + case NonFatal(e) ⇒ + log.error(e, "Internal server error, sending 500 response") + ResponseRenderingContext(HttpResponse(StatusCodes.InternalServerError), closeRequested = true) + } + /** * The `Expect: 100-continue` header has a special status in HTTP. * It allows the client to send an `Expect: 100-continue` header with the request and then pause request sending @@ -280,30 +298,6 @@ private[http] object HttpServerBluePrint { */ case object OneHundredContinue - final class ErrorsTo500ResponseRecovery(log: LoggingAdapter) - extends PushPullStage[ResponseRenderingContext, ResponseRenderingContext] { - - import akka.stream.stage.Context - - private[this] var errorResponse: ResponseRenderingContext = _ - - override def onPush(elem: ResponseRenderingContext, ctx: Context[ResponseRenderingContext]) = ctx.push(elem) - - override def onPull(ctx: Context[ResponseRenderingContext]) = - if (ctx.isFinishing) ctx.pushAndFinish(errorResponse) - else ctx.pull() - - override def onUpstreamFailure(error: Throwable, ctx: Context[ResponseRenderingContext]) = - error match { - case NonFatal(e) ⇒ - log.error(e, "Internal server error, sending 500 response") - errorResponse = ResponseRenderingContext(HttpResponse(StatusCodes.InternalServerError), - closeRequested = true) - ctx.absorbTermination() - case _ ⇒ ctx.fail(error) - } - } - private trait WebsocketSetup { def websocketFlow: Flow[ByteString, ByteString, Any] def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: Materializer): Unit diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/ErrorInfo.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ErrorInfo.scala index 0a4d331a9b..019510c581 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/ErrorInfo.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ErrorInfo.scala @@ -72,4 +72,8 @@ object EntityStreamException { def apply(summary: String, detail: String = ""): EntityStreamException = apply(ErrorInfo(summary, detail)) } +case class EntityStreamSizeException(limit: Long, actualSize: Option[Long] = None) extends RuntimeException { + override def toString = s"EntityStreamSizeException($limit, $actualSize)" +} + case class RequestTimeoutException(request: HttpRequest, message: String) extends RuntimeException(message) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala index 3137663624..7d7a9bbdf2 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala @@ -4,8 +4,6 @@ package akka.http.scaladsl.model -import akka.http.javadsl.model.HttpEntityStrict - import language.implicitConversions import java.io.File import java.lang.{ Iterable ⇒ JIterable, Long ⇒ JLong } @@ -13,10 +11,12 @@ import scala.concurrent.Future import scala.concurrent.duration._ import scala.collection.immutable import akka.util.ByteString -import akka.stream.Materializer import akka.stream.scaladsl._ import akka.stream.io.SynchronousFileSource +import akka.stream.stage._ +import akka.stream._ import akka.{ japi, stream } +import akka.http.javadsl.model.HttpEntityStrict import akka.http.scaladsl.util.FastFuture import akka.http.javadsl.{ model ⇒ jm } import akka.http.impl.util.JavaMapping.Implicits._ @@ -74,6 +74,29 @@ sealed trait HttpEntity extends jm.HttpEntity { */ def withContentType(contentType: ContentType): HttpEntity + /** + * Apply the given size limit to this entity by returning a new entity instance which automatically verifies that the + * data stream encapsulated by this instance produces at most `maxBytes` data bytes. In case this verification fails + * the respective stream will be terminated with an `EntityStreamException` either directly at materialization + * time (if the Content-Length is known) or whenever more data bytes than allowed have been read. + * + * When called on `Strict` entities the method will return the entity itself if the length is within the bound, + * otherwise a `Default` entity with a single element data stream. This allows for potential refinement of the + * entity size limit at a later point (before materialization of the data stream). + * + * By default all message entities produced by the HTTP layer automatically carry the limit that is defined in the + * application's `max-content-length` config setting. If the entity is transformed in a way that changes the + * Content-Length and then another limit is applied then this new limit will be evaluated against the new + * Content-Length. If the entity is transformed in a way that changes the Content-Length and no new limit is applied + * then the previous limit will be applied against the previous Content-Length. + * + * Note that the size limit applied via this method will only have any effect if the `Source` instance contained + * in this entity has been appropriately modified via the `HttpEntity.limitable` method. For all entities created + * by the HTTP layer itself this is always the case, but if you create entities yourself and would like them to + * properly respect limits defined via this method you need to make sure to apply `HttpEntity.limitable` yourself. + */ + def withSizeLimit(maxBytes: Long): HttpEntity + /** Java API */ def getDataBytes: stream.javadsl.Source[ByteString, AnyRef] = stream.javadsl.Source.fromGraph(dataBytes.asInstanceOf[Source[ByteString, AnyRef]]) @@ -95,6 +118,11 @@ sealed trait HttpEntity extends jm.HttpEntity { /* An entity that can be used for body parts */ sealed trait BodyPartEntity extends HttpEntity with jm.BodyPartEntity { def withContentType(contentType: ContentType): BodyPartEntity + + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): BodyPartEntity } /** @@ -105,7 +133,12 @@ sealed trait BodyPartEntity extends HttpEntity with jm.BodyPartEntity { sealed trait RequestEntity extends HttpEntity with jm.RequestEntity with ResponseEntity { def withContentType(contentType: ContentType): RequestEntity - override def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): RequestEntity + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): RequestEntity + + def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): RequestEntity } /** @@ -116,11 +149,22 @@ sealed trait RequestEntity extends HttpEntity with jm.RequestEntity with Respons sealed trait ResponseEntity extends HttpEntity with jm.ResponseEntity { def withContentType(contentType: ContentType): ResponseEntity - override def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): ResponseEntity + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): ResponseEntity + + def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): ResponseEntity } /* An entity that can be used for requests, responses, and body parts */ sealed trait UniversalEntity extends jm.UniversalEntity with MessageEntity with BodyPartEntity { def withContentType(contentType: ContentType): UniversalEntity + + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): UniversalEntity + def contentLength: Long def contentLengthOption: Option[Long] = Some(contentLength) @@ -188,6 +232,13 @@ object HttpEntity { def withContentType(contentType: ContentType): Strict = if (contentType == this.contentType) this else copy(contentType = contentType) + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): UniversalEntity = + if (data.length <= maxBytes) this + else Default(contentType, data.length, limitableByteSource(Source.single(data))) withSizeLimit maxBytes + override def productPrefix = "HttpEntity.Strict" } @@ -213,6 +264,12 @@ object HttpEntity { def withContentType(contentType: ContentType): Default = if (contentType == this.contentType) this else copy(contentType = contentType) + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): Default = + copy(data = data withAttributes Attributes(SizeLimit(maxBytes, Some(contentLength)))) + override def productPrefix = "HttpEntity.Default" } @@ -222,13 +279,23 @@ object HttpEntity { * INTERNAL API */ private[http] sealed trait WithoutKnownLength extends HttpEntity { + type Self <: WithoutKnownLength def contentType: ContentType def data: Source[ByteString, Any] def contentLengthOption: Option[Long] = None - def isKnownEmpty = data eq Source.empty - def dataBytes: Source[ByteString, Any] = data + + /** + * See [[HttpEntity#withSizeLimit]]. + */ + def withSizeLimit(maxBytes: Long): Self = + withData(data withAttributes Attributes(SizeLimit(maxBytes))) + + def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): Self = + withData(data via transformer) + + def withData(data: Source[ByteString, Any]): Self } /** @@ -244,8 +311,7 @@ object HttpEntity { def withContentType(contentType: ContentType): CloseDelimited = if (contentType == this.contentType) this else copy(contentType = contentType) - override def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): CloseDelimited = - HttpEntity.CloseDelimited(contentType, data via transformer) + def withData(data: Source[ByteString, Any]): CloseDelimited = copy(data = data) override def productPrefix = "HttpEntity.CloseDelimited" } @@ -256,13 +322,13 @@ object HttpEntity { */ final case class IndefiniteLength(contentType: ContentType, data: Source[ByteString, Any]) extends jm.HttpEntityIndefiniteLength with BodyPartEntity with WithoutKnownLength { + type Self = IndefiniteLength override def isIndefiniteLength: Boolean = true def withContentType(contentType: ContentType): IndefiniteLength = if (contentType == this.contentType) this else copy(contentType = contentType) - override def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): IndefiniteLength = - HttpEntity.IndefiniteLength(contentType, data via transformer) + def withData(data: Source[ByteString, Any]): IndefiniteLength = copy(data = data) override def productPrefix = "HttpEntity.IndefiniteLength" } @@ -280,6 +346,9 @@ object HttpEntity { def dataBytes: Source[ByteString, Any] = chunks.map(_.data).filter(_.nonEmpty) + def withSizeLimit(maxBytes: Long): Chunked = + copy(chunks = chunks withAttributes Attributes(SizeLimit(maxBytes))) + override def transformDataBytes(transformer: Flow[ByteString, ByteString, Any]): Chunked = { val newData = chunks.map { @@ -354,4 +423,57 @@ object HttpEntity { def getTrailerHeaders: JIterable[jm.HttpHeader] = trailer.asJava } object LastChunk extends LastChunk("", Nil) + + /** + * Turns the given source into one that respects the `withSizeLimit` calls when used as a parameter + * to entity constructors. + */ + def limitableByteSource[Mat](source: Source[ByteString, Mat]): Source[ByteString, Mat] = + limitable(source, sizeOfByteString) + + /** + * Turns the given source into one that respects the `withSizeLimit` calls when used as a parameter + * to entity constructors. + */ + def limitableChunkSource[Mat](source: Source[ChunkStreamPart, Mat]): Source[ChunkStreamPart, Mat] = + limitable(source, sizeOfChunkStreamPart) + + /** + * INTERNAL API + */ + private val sizeOfByteString: ByteString ⇒ Int = _.size + private val sizeOfChunkStreamPart: ChunkStreamPart ⇒ Int = _.data.size + + /** + * INTERNAL API + */ + private def limitable[Out, Mat](source: Source[Out, Mat], sizeOf: Out ⇒ Int): Source[Out, Mat] = + source.via(Flow[Out].transform { () ⇒ + new PushStage[Out, Out] { + var maxBytes = -1L + var bytesLeft = Long.MaxValue + + override def preStart(ctx: LifecycleContext) = + ctx.attributes.getFirst[SizeLimit] match { + case Some(SizeLimit(bytes, cl @ Some(contentLength))) ⇒ + if (contentLength > bytes) throw EntityStreamSizeException(bytes, cl) + // else we still count but never throw an error + case Some(SizeLimit(bytes, None)) ⇒ + maxBytes = bytes + bytesLeft = bytes + case None ⇒ + } + + def onPush(elem: Out, ctx: stage.Context[Out]): stage.SyncDirective = { + bytesLeft -= sizeOf(elem) + if (bytesLeft >= 0) ctx.push(elem) + else ctx.fail(EntityStreamSizeException(maxBytes)) + } + } + }.named("limitable")) + + /** + * INTERNAL API + */ + private case class SizeLimit(maxBytes: Long, contentLength: Option[Long] = None) extends Attributes.Attribute } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala index ddfb94032f..ef8838599b 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala @@ -67,9 +67,6 @@ sealed trait HttpMessage extends jm.HttpMessage { /** Returns a copy of this message with the list of headers transformed by the given function */ def mapHeaders(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): Self = withHeaders(f(headers)) - /** Returns a copy of this message with the entity transformed by the given function */ - def mapEntity(f: HttpEntity ⇒ MessageEntity): Self = withEntity(f(entity)) - /** * The content encoding as specified by the Content-Encoding header. If no Content-Encoding header is present the * default value 'identity' is returned. diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index 9150e1540d..639c20c8f2 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -4,9 +4,11 @@ package akka.http.impl.engine.client +import scala.concurrent.duration._ +import scala.reflect.ClassTag +import org.scalatest.Inside import akka.http.ClientConnectionSettings import akka.stream.io.{ SessionBytes, SslTlsOutbound, SendBytes } -import org.scalatest.Inside import akka.util.ByteString import akka.event.NoLogging import akka.stream.{ ClosedShape, ActorMaterializer } @@ -26,23 +28,14 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. "handle a request/response round-trip" which { "has a request with empty entity" in new TestSetup { - requestsSub.sendNext(HttpRequest()) - expectWireData( - """GET / HTTP/1.1 - |Host: example.com - |User-Agent: akka-http/test - | - |""") - - netInSub.expectRequest() + sendStandardRequest() sendWireData( """HTTP/1.1 200 OK |Content-Length: 0 | |""") - responsesSub.request(1) - responses.expectNext(HttpResponse()) + expectResponse() shouldEqual HttpResponse() requestsSub.sendComplete() netOut.expectComplete() @@ -71,15 +64,13 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. expectWireData("XY") sub.sendComplete() - netInSub.expectRequest() sendWireData( """HTTP/1.1 200 OK |Content-Length: 0 | |""") - responsesSub.request(1) - responses.expectNext(HttpResponse()) + expectResponse() shouldEqual HttpResponse() requestsSub.sendComplete() netOut.expectComplete() @@ -88,23 +79,14 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. } "has a response with a chunked entity" in new TestSetup { - requestsSub.sendNext(HttpRequest()) - expectWireData( - """GET / HTTP/1.1 - |Host: example.com - |User-Agent: akka-http/test - | - |""") - - netInSub.expectRequest() + sendStandardRequest() sendWireData( """HTTP/1.1 200 OK |Transfer-Encoding: chunked | |""") - responsesSub.request(1) - val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext() + val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = expectResponse() ct shouldEqual ContentTypes.`application/octet-stream` val probe = TestSubscriber.manualProbe[ChunkStreamPart]() @@ -140,15 +122,13 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") - netInSub.expectRequest() sendWireData( """HTTP/1.1 200 OK |Content-Length: 0 | |""") - responsesSub.request(1) - responses.expectNext(HttpResponse()) + expectResponse() shouldEqual HttpResponse() netOut.expectComplete() netInSub.sendComplete() @@ -166,15 +146,13 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") - netInSub.expectRequest() sendWireData( """HTTP/1.1 200 OK |Transfer-Encoding: chunked | |""") - responsesSub.request(1) - val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext() + val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = expectResponse() val probe = TestSubscriber.manualProbe[ChunkStreamPart]() chunks.runWith(Sink(probe)) @@ -216,7 +194,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. "produce proper errors" which { - "catch the entity stream being shorter than the Content-Length" in new TestSetup { + "catch the request entity stream being shorter than the Content-Length" in new TestSetup { val probe = TestPublisher.manualProbe[ByteString]() requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe)))) expectWireData( @@ -242,7 +220,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. responses.expectError(One2OneBidiFlow.OutputTruncationException) } - "catch the entity stream being longer than the Content-Length" in new TestSetup { + "catch the request entity stream being longer than the Content-Length" in new TestSetup { val probe = TestPublisher.manualProbe[ByteString]() requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe)))) expectWireData( @@ -269,15 +247,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. } "catch illegal response starts" in new TestSetup { - requestsSub.sendNext(HttpRequest()) - expectWireData( - """GET / HTTP/1.1 - |Host: example.com - |User-Agent: akka-http/test - | - |""") - - netInSub.expectRequest() + sendStandardRequest() sendWireData( """HTTP/1.2 200 OK | @@ -287,18 +257,11 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. info.summary shouldEqual "The server-side HTTP version is not supported" netOut.expectError(error) requestsSub.expectCancellation() + netInSub.expectCancellation() } "catch illegal response chunks" in new TestSetup { - requestsSub.sendNext(HttpRequest()) - expectWireData( - """GET / HTTP/1.1 - |Host: example.com - |User-Agent: akka-http/test - | - |""") - - netInSub.expectRequest() + sendStandardRequest() sendWireData( """HTTP/1.1 200 OK |Transfer-Encoding: chunked @@ -325,18 +288,11 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. responses.expectComplete() netOut.expectComplete() requestsSub.expectCancellation() + netInSub.expectCancellation() } "catch a response start truncation" in new TestSetup { - requestsSub.sendNext(HttpRequest()) - expectWireData( - """GET / HTTP/1.1 - |Host: example.com - |User-Agent: akka-http/test - | - |""") - - netInSub.expectRequest() + sendStandardRequest() sendWireData("HTTP/1.1 200 OK") netInSub.sendComplete() @@ -346,14 +302,225 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. requestsSub.expectCancellation() } } + + def isDefinedVia = afterWord("is defined via") + "support response length verification" which isDefinedVia { + import HttpEntity._ + + class LengthVerificationTest(maxContentLength: Int) extends TestSetup(maxContentLength) { + val entityBase = "0123456789ABCD" + + def sendStrictResponseWithLength(bytes: Int) = + sendWireData( + s"""HTTP/1.1 200 OK + |Content-Length: $bytes + | + |${entityBase take bytes}""") + def sendDefaultResponseWithLength(bytes: Int) = { + sendWireData( + s"""HTTP/1.1 200 OK + |Content-Length: $bytes + | + |${entityBase take 3}""") + sendWireData(entityBase.slice(3, 7)) + sendWireData(entityBase.slice(7, bytes)) + } + def sendChunkedResponseWithLength(bytes: Int) = + sendWireData( + s"""HTTP/1.1 200 OK + |Transfer-Encoding: chunked + | + |3 + |${entityBase take 3} + |4 + |${entityBase.slice(3, 7)} + |${bytes - 7} + |${entityBase.slice(7, bytes)} + |0 + | + |""") + def sendCloseDelimitedResponseWithLength(bytes: Int) = { + sendWireData( + s"""HTTP/1.1 200 OK + | + |${entityBase take 3}""") + sendWireData(entityBase.slice(3, 7)) + sendWireData(entityBase.slice(7, bytes)) + netInSub.sendComplete() + } + + implicit class XResponse(response: HttpResponse) { + def expectStrictEntityWithLength(bytes: Int) = + response shouldEqual HttpResponse( + entity = Strict(ContentTypes.`application/octet-stream`, ByteString(entityBase take bytes))) + + def expectEntity[T <: HttpEntity: ClassTag](bytes: Int) = + inside(response) { + case HttpResponse(_, _, entity: T, _) ⇒ + entity.toStrict(100.millis).awaitResult(100.millis).data.utf8String shouldEqual entityBase.take(bytes) + } + + def expectSizeErrorInEntityOfType[T <: HttpEntity: ClassTag](limit: Int, actualSize: Option[Long] = None) = + inside(response) { + case HttpResponse(_, _, entity: T, _) ⇒ + def gatherBytes = entity.dataBytes.runFold(ByteString.empty)(_ ++ _).awaitResult(100.millis) + (the[Exception] thrownBy gatherBytes).getCause shouldEqual EntityStreamSizeException(limit, actualSize) + } + } + } + + "the config setting (strict entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendStandardRequest() + sendStrictResponseWithLength(10) + expectResponse().expectStrictEntityWithLength(10) + + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStandardRequest() + sendStrictResponseWithLength(11) + expectResponse().expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "the config setting (default entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendStandardRequest() + sendDefaultResponseWithLength(10) + expectResponse().expectEntity[Default](10) + + sendStandardRequest() + sendDefaultResponseWithLength(11) + expectResponse().expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "the config setting (chunked entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendStandardRequest() + sendChunkedResponseWithLength(10) + expectResponse().expectEntity[Chunked](10) + + sendStandardRequest() + sendChunkedResponseWithLength(11) + expectResponse().expectSizeErrorInEntityOfType[Chunked](limit = 10) + } + + "the config setting (close-delimited entity)" in { + new LengthVerificationTest(maxContentLength = 10) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(10) + expectResponse().expectEntity[CloseDelimited](10) + } + new LengthVerificationTest(maxContentLength = 10) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(11) + expectResponse().expectSizeErrorInEntityOfType[CloseDelimited](limit = 10) + } + } + + "a smaller programmatically-set limit (strict entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendStandardRequest() + sendStrictResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectStrictEntityWithLength(10) + + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStandardRequest() + sendStrictResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10) + .expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "a smaller programmatically-set limit (default entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendStandardRequest() + sendDefaultResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[Default](10) + + sendStandardRequest() + sendDefaultResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10) + .expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "a smaller programmatically-set limit (chunked entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendStandardRequest() + sendChunkedResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[Chunked](10) + + sendStandardRequest() + sendChunkedResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10).expectSizeErrorInEntityOfType[Chunked](limit = 10) + } + + "a smaller programmatically-set limit (close-delimited entity)" in { + new LengthVerificationTest(maxContentLength = 12) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[CloseDelimited](10) + } + new LengthVerificationTest(maxContentLength = 12) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10).expectSizeErrorInEntityOfType[CloseDelimited](limit = 10) + } + } + + "a larger programmatically-set limit (strict entity)" in new LengthVerificationTest(maxContentLength = 8) { + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStandardRequest() + sendStrictResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[Default](10) + + sendStandardRequest() + sendStrictResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10) + .expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "a larger programmatically-set limit (default entity)" in new LengthVerificationTest(maxContentLength = 8) { + sendStandardRequest() + sendDefaultResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[Default](10) + + sendStandardRequest() + sendDefaultResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10) + .expectSizeErrorInEntityOfType[Default](limit = 10, actualSize = Some(11)) + } + + "a larger programmatically-set limit (chunked entity)" in new LengthVerificationTest(maxContentLength = 8) { + sendStandardRequest() + sendChunkedResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[Chunked](10) + + sendStandardRequest() + sendChunkedResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10) + .expectSizeErrorInEntityOfType[Chunked](limit = 10) + } + + "a larger programmatically-set limit (close-delimited entity)" in { + new LengthVerificationTest(maxContentLength = 8) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(10) + expectResponse().mapEntity(_ withSizeLimit 10).expectEntity[CloseDelimited](10) + } + new LengthVerificationTest(maxContentLength = 8) { + sendStandardRequest() + sendCloseDelimitedResponseWithLength(11) + expectResponse().mapEntity(_ withSizeLimit 10).expectSizeErrorInEntityOfType[CloseDelimited](limit = 10) + } + } + } } - class TestSetup { + class TestSetup(maxResponseContentLength: Int = -1) { val requests = TestPublisher.manualProbe[HttpRequest]() val responses = TestSubscriber.manualProbe[HttpResponse]() - def settings = ClientConnectionSettings(system) - .copy(userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test"))))) + def settings = { + val s = ClientConnectionSettings(system) + .copy(userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test"))))) + if (maxResponseContentLength < 0) s + else s.copy(parserSettings = s.parserSettings.copy(maxContentLength = maxResponseContentLength)) + } val (netOut, netIn) = { val netOut = TestSubscriber.manualProbe[ByteString] @@ -383,6 +550,9 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. val requestsSub = requests.expectSubscription() val responsesSub = responses.expectSubscription() + requestsSub.expectRequest(16) + netInSub.expectRequest(16) + def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII")) def sendWireData(data: ByteString): Unit = netInSub.sendNext(data) @@ -392,5 +562,20 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. } def closeNetworkInput(): Unit = netInSub.sendComplete() + + def sendStandardRequest() = { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com + |User-Agent: akka-http/test + | + |""") + } + + def expectResponse() = { + responsesSub.request(1) + responses.expectNext() + } } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala index 311cbb2ea3..63b9225e18 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala @@ -4,6 +4,11 @@ package akka.http.impl.engine.parsing +import com.typesafe.config.{ Config, ConfigFactory } +import scala.concurrent.Future +import scala.concurrent.duration._ +import org.scalatest.matchers.Matcher +import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers } import akka.actor.ActorSystem import akka.http.ParserSettings import akka.http.impl.engine.parsing.ParserOutput._ @@ -21,12 +26,6 @@ import akka.http.scaladsl.util.FastFuture._ import akka.stream.ActorMaterializer import akka.stream.scaladsl._ import akka.util.ByteString -import com.typesafe.config.{ Config, ConfigFactory } -import org.scalatest.matchers.Matcher -import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers } - -import scala.concurrent.Future -import scala.concurrent.duration._ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" @@ -165,311 +164,263 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { |""" should parseTo(HttpRequest(GET, Uri("http://x//foo").toHttpRequestTargetOriginForm, protocol = `HTTP/1.0`)) closeAfterResponseCompletion shouldEqual Seq(true) } + } - "properly parse a chunked request" - { - val start = - """PATCH /data HTTP/1.1 - |Transfer-Encoding: chunked - |Connection: lalelu - |Content-Type: application/pdf - |Host: ping - | - |""" - val baseRequest = HttpRequest(PATCH, "/data", List(Connection("lalelu"), Host("ping"))) - - "request start" in new Test { - Seq(start, "rest") should generalMultiParseTo( - Right(baseRequest.withEntity(HttpEntity.Chunked(`application/pdf`, source()))), - Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "message chunk with and without extension" in new Test { - Seq(start + - """3 - |abc - |10;some=stuff;bla - |0123456789ABCDEF - |""", - "10;foo=", - """bar - |0123456789ABCDEF - |A - |0123456789""", - """ - |0 - | - |""") should generalMultiParseTo( - Right(baseRequest.withEntity(Chunked(`application/pdf`, source( - Chunk(ByteString("abc")), - Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), - Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), - Chunk(ByteString("0123456789"), ""), - LastChunk))))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "message end" in new Test { - Seq(start, - """0 - | - |""") should generalMultiParseTo( - Right(baseRequest.withEntity(Chunked(`application/pdf`, source(LastChunk))))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "message end with extension and trailer" in new Test { - Seq(start, - """000;nice=true - |Foo: pip - | apo - |Bar: xyz - | - |""") should generalMultiParseTo( - Right(baseRequest.withEntity(Chunked(`application/pdf`, - source(LastChunk("nice=true", List(RawHeader("Foo", "pip apo"), RawHeader("Bar", "xyz")))))))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "don't overflow the stack for large buffers of chunks" in new Test { - override val awaitAtMost = 3000.millis - - val x = NotEnoughDataException - val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m` - val oneChunk = "1\r\nz\n" - val manyChunks = (oneChunk * numChunks) + "0\r\n" - - val parser = newParser - val result = multiParse(newParser)(Seq(prep(start + manyChunks))) - val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity - val strictChunks = chunks.grouped(100000).runWith(Sink.head).awaitResult(awaitAtMost) - strictChunks.size shouldEqual numChunks - } - } - - "properly parse a chunked request with additional transfer encodings" in new Test { + "properly parse a chunked request" - { + val start = """PATCH /data HTTP/1.1 - |Transfer-Encoding: fancy, chunked + |Transfer-Encoding: chunked + |Connection: lalelu |Content-Type: application/pdf |Host: ping | - |0 - | - |""" should parseTo(HttpRequest(PATCH, "/data", List(`Transfer-Encoding`(TransferEncodings.Extension("fancy")), - Host("ping")), HttpEntity.Chunked(`application/pdf`, source(LastChunk)))) + |""" + val baseRequest = HttpRequest(PATCH, "/data", List(Connection("lalelu"), Host("ping"))) + + "request start" in new Test { + Seq(start, "rest") should generalMultiParseTo( + Right(baseRequest.withEntity(HttpEntity.Chunked(`application/pdf`, source()))), + Left(EntityStreamError(ErrorInfo("Illegal character 'r' in chunk start")))) closeAfterResponseCompletion shouldEqual Seq(false) } - "support `rawRequestUriHeader` setting" in new Test { - override protected def newParser: HttpRequestParser = - new HttpRequestParser(parserSettings, rawRequestUriHeader = true, _headerParser = HttpHeaderParser(parserSettings)()) - - """GET /f%6f%6fbar?q=b%61z HTTP/1.1 - |Host: ping - |Content-Type: application/pdf - | - |""" should parseTo( - HttpRequest( - GET, - "/foobar?q=b%61z", - List( - `Raw-Request-URI`("/f%6f%6fbar?q=b%61z"), - Host("ping")), - HttpEntity.empty(`application/pdf`))) - } - - "reject a message chunk with" - { - val start = - """PATCH /data HTTP/1.1 - |Transfer-Encoding: chunked - |Connection: lalelu - |Host: ping + "message chunk with and without extension" in new Test { + Seq(start + + """3 + |abc + |10;some=stuff;bla + |0123456789ABCDEF + |""", + "10;foo=", + """bar + |0123456789ABCDEF + |A + |0123456789""", + """ + |0 | - |""" - val baseRequest = HttpRequest(PATCH, "/data", List(Connection("lalelu"), Host("ping")), - HttpEntity.Chunked(`application/octet-stream`, source())) - - "an illegal char after chunk size" in new Test { - Seq(start, - """15 ; - |""") should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("Illegal character ' ' in chunk start")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "an illegal char in chunk size" in new Test { - Seq(start, "bla") should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("Illegal character 'l' in chunk start")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "too-long chunk extension" in new Test { - Seq(start, "3;" + ("x" * 257)) should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "too-large chunk size" in new Test { - Seq(start, - """1a2b3c4d5e - |""") should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "an illegal chunk termination" in new Test { - Seq(start, - """3 - |abcde""") should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("Illegal chunk termination")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } - - "an illegal header in the trailer" in new Test { - Seq(start, - """0 - |F@oo: pip""") should generalMultiParseTo(Right(baseRequest), - Left(EntityStreamError(ErrorInfo("Illegal character '@' in header name")))) - closeAfterResponseCompletion shouldEqual Seq(false) - } + |""") should generalMultiParseTo( + Right(baseRequest.withEntity(Chunked(`application/pdf`, source( + Chunk(ByteString("abc")), + Chunk(ByteString("0123456789ABCDEF"), "some=stuff;bla"), + Chunk(ByteString("0123456789ABCDEF"), "foo=bar"), + Chunk(ByteString("0123456789"), ""), + LastChunk))))) + closeAfterResponseCompletion shouldEqual Seq(false) } - "reject a request with" - { - "an illegal HTTP method" in new Test { - "get " should parseToError(NotImplemented, ErrorInfo("Unsupported HTTP method", "get")) - "GETX " should parseToError(NotImplemented, ErrorInfo("Unsupported HTTP method", "GETX")) - } + "message end" in new Test { + Seq(start, + """0 + | + |""") should generalMultiParseTo( + Right(baseRequest.withEntity(Chunked(`application/pdf`, source(LastChunk))))) + closeAfterResponseCompletion shouldEqual Seq(false) + } - "a too long HTTP method" in new Test { - "ABCDEFGHIJKLMNOPQ " should - parseToError(BadRequest, - ErrorInfo( - "Unsupported HTTP method", - "HTTP method too long (started with 'ABCDEFGHIJKLMNOP'). Increase `akka.http.server.parsing.max-method-length` to support HTTP methods with more characters.")) - } + "message end with extension and trailer" in new Test { + Seq(start, + """000;nice=true + |Foo: pip + | apo + |Bar: xyz + | + |""") should generalMultiParseTo( + Right(baseRequest.withEntity(Chunked(`application/pdf`, + source(LastChunk("nice=true", List(RawHeader("Foo", "pip apo"), RawHeader("Bar", "xyz")))))))) + closeAfterResponseCompletion shouldEqual Seq(false) + } - "two Content-Length headers" in new Test { - """GET / HTTP/1.1 + "don't overflow the stack for large buffers of chunks" in new Test { + override val awaitAtMost = 3000.millis + + val x = NotEnoughDataException + val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m` + val oneChunk = "1\r\nz\n" + val manyChunks = (oneChunk * numChunks) + "0\r\n" + + val parser = newParser + val result = multiParse(newParser)(Seq(prep(start + manyChunks))) + val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity + val strictChunks = chunks.grouped(100000).runWith(Sink.head).awaitResult(awaitAtMost) + strictChunks.size shouldEqual numChunks + } + } + + "properly parse a chunked request with additional transfer encodings" in new Test { + """PATCH /data HTTP/1.1 + |Transfer-Encoding: fancy, chunked + |Content-Type: application/pdf + |Host: ping + | + |0 + | + |""" should parseTo(HttpRequest(PATCH, "/data", List(`Transfer-Encoding`(TransferEncodings.Extension("fancy")), + Host("ping")), HttpEntity.Chunked(`application/pdf`, source(LastChunk)))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "support `rawRequestUriHeader` setting" in new Test { + override protected def newParser: HttpRequestParser = + new HttpRequestParser(parserSettings, rawRequestUriHeader = true, _headerParser = HttpHeaderParser(parserSettings)()) + + """GET /f%6f%6fbar?q=b%61z HTTP/1.1 + |Host: ping + |Content-Type: application/pdf + | + |""" should parseTo( + HttpRequest( + GET, + "/foobar?q=b%61z", + List( + `Raw-Request-URI`("/f%6f%6fbar?q=b%61z"), + Host("ping")), + HttpEntity.empty(`application/pdf`))) + } + + "reject a message chunk with" - { + val start = + """PATCH /data HTTP/1.1 + |Transfer-Encoding: chunked + |Connection: lalelu + |Host: ping + | + |""" + val baseRequest = HttpRequest(PATCH, "/data", List(Connection("lalelu"), Host("ping")), + HttpEntity.Chunked(`application/octet-stream`, source())) + + "an illegal char after chunk size" in new Test { + Seq(start, + """15 ; + |""") should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("Illegal character ' ' in chunk start")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "an illegal char in chunk size" in new Test { + Seq(start, "bla") should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("Illegal character 'l' in chunk start")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "too-long chunk extension" in new Test { + Seq(start, "3;" + ("x" * 257)) should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("HTTP chunk extension length exceeds configured limit of 256 characters")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "too-large chunk size" in new Test { + Seq(start, + """1a2b3c4d5e + |""") should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("HTTP chunk size exceeds the configured limit of 1048576 bytes")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "an illegal chunk termination" in new Test { + Seq(start, + """3 + |abcde""") should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("Illegal chunk termination")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + + "an illegal header in the trailer" in new Test { + Seq(start, + """0 + |F@oo: pip""") should generalMultiParseTo(Right(baseRequest), + Left(EntityStreamError(ErrorInfo("Illegal character '@' in header name")))) + closeAfterResponseCompletion shouldEqual Seq(false) + } + } + + "reject a request with" - { + "an illegal HTTP method" in new Test { + "get " should parseToError(NotImplemented, ErrorInfo("Unsupported HTTP method", "get")) + "GETX " should parseToError(NotImplemented, ErrorInfo("Unsupported HTTP method", "GETX")) + } + + "a too long HTTP method" in new Test { + "ABCDEFGHIJKLMNOPQ " should + parseToError(BadRequest, + ErrorInfo( + "Unsupported HTTP method", + "HTTP method too long (started with 'ABCDEFGHIJKLMNOP'). Increase `akka.http.server.parsing.max-method-length` to support HTTP methods with more characters.")) + } + + "two Content-Length headers" in new Test { + """GET / HTTP/1.1 |Content-Length: 3 |Content-Length: 4 | |foo""" should parseToError(BadRequest, - ErrorInfo("HTTP message must not contain more than one Content-Length header")) - } + ErrorInfo("HTTP message must not contain more than one Content-Length header")) + } - "a too-long URI" in new Test { - "GET /23456789012345678901 HTTP/1.1" should parseToError(RequestUriTooLong, - ErrorInfo("URI length exceeds the configured limit of 20 characters")) - } + "a too-long URI" in new Test { + "GET /23456789012345678901 HTTP/1.1" should parseToError(RequestUriTooLong, + ErrorInfo("URI length exceeds the configured limit of 20 characters")) + } - "HTTP version 1.2" in new Test { - """GET / HTTP/1.2 + "HTTP version 1.2" in new Test { + """GET / HTTP/1.2 |""" should parseToError(HTTPVersionNotSupported, - ErrorInfo("The server does not support the HTTP protocol version used in the request.")) - } + ErrorInfo("The server does not support the HTTP protocol version used in the request.")) + } - "with an illegal char in a header name" in new Test { - """GET / HTTP/1.1 + "with an illegal char in a header name" in new Test { + """GET / HTTP/1.1 |User@Agent: curl/7.19.7""" should parseToError(BadRequest, ErrorInfo("Illegal character '@' in header name")) - } + } - "with a too-long header name" in new Test { - """|GET / HTTP/1.1 + "with a too-long header name" in new Test { + """|GET / HTTP/1.1 |UserxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxAgent: curl/7.19.7""" should parseToError( - BadRequest, ErrorInfo("HTTP header name exceeds the configured limit of 64 characters")) - } + BadRequest, ErrorInfo("HTTP header name exceeds the configured limit of 64 characters")) + } - "with a too-long header-value" in new Test { - """|GET / HTTP/1.1 + "with a too-long header-value" in new Test { + """|GET / HTTP/1.1 |Fancy: 123456789012345678901234567890123""" should parseToError(BadRequest, - ErrorInfo("HTTP header value exceeds the configured limit of 32 characters")) - } + ErrorInfo("HTTP header value exceeds the configured limit of 32 characters")) + } - "with an invalid Content-Length header value" in new Test { - """GET / HTTP/1.0 + "with an invalid Content-Length header value" in new Test { + """GET / HTTP/1.0 |Content-Length: 1.5 | |abc""" should parseToError(BadRequest, ErrorInfo("Illegal `Content-Length` header value")) - } + } - "with Content-Length > Long.MaxSize" in new Test { - // content-length = (Long.MaxValue + 1) * 10, which is 0 when calculated overflow - """PUT /resource/yes HTTP/1.1 + "with Content-Length > Long.MaxSize" in new Test { + // content-length = (Long.MaxValue + 1) * 10, which is 0 when calculated overflow + """PUT /resource/yes HTTP/1.1 |Content-length: 92233720368547758080 |Host: x | |""" should parseToError(400: StatusCode, ErrorInfo("`Content-Length` header value must not exceed 63-bit integer range")) - } + } - "with entity length > max-content-length" - { - "for Default entity" in new Test { - """PUT /resource/yes HTTP/1.1 - |Content-length: 101 - |Host: x - | - |""" should parseToError(413: StatusCode, - ErrorInfo("Request Content-Length of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.server.parsing.max-content-length")) - - override protected def parserSettings: ParserSettings = super.parserSettings.copy(maxContentLength = 100) - } - - "for Chunked entity" in new Test { - def request(dataElements: ByteString*) = HttpRequest(PUT, "/", List(Host("x")), - HttpEntity.Chunked(`application/octet-stream`, source(dataElements.map(ChunkStreamPart(_)): _*))) - - Seq( - """PUT / HTTP/1.1 - |Transfer-Encoding: chunked - |Host: x - | - |65 - |abc""") should generalMultiParseTo(Right(request()), - Left( - EntityStreamError( - ErrorInfo("Aggregated data length of chunked request entity of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.server.parsing.max-content-length")))) - - Seq( - """PUT / HTTP/1.1 - |Transfer-Encoding: chunked - |Host: x - | - |1 - |a - |""", - """64 - |a""") should generalMultiParseTo(Right(request(ByteString("a"))), - Left(EntityStreamError( - ErrorInfo("Aggregated data length of chunked request entity of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.server.parsing.max-content-length")))) - - override protected def parserSettings: ParserSettings = super.parserSettings.copy(maxContentLength = - 100) - } - } - - "with an illegal entity using CONNECT" in new Test { - """CONNECT /resource/yes HTTP/1.1 + "with an illegal entity using CONNECT" in new Test { + """CONNECT /resource/yes HTTP/1.1 |Transfer-Encoding: chunked |Host: x | |""" should parseToError(422: StatusCode, ErrorInfo("CONNECT requests must not have an entity")) - } - "with an illegal entity using HEAD" in new Test { - """HEAD /resource/yes HTTP/1.1 + } + "with an illegal entity using HEAD" in new Test { + """HEAD /resource/yes HTTP/1.1 |Content-length: 3 |Host: x | |foo""" should parseToError(422: StatusCode, ErrorInfo("HEAD requests must not have an entity")) - } - "with an illegal entity using TRACE" in new Test { - """TRACE /resource/yes HTTP/1.1 + } + "with an illegal entity using TRACE" in new Test { + """TRACE /resource/yes HTTP/1.1 |Transfer-Encoding: chunked |Host: x | |""" should parseToError(422: StatusCode, ErrorInfo("TRACE requests must not have an entity")) - } } } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala index bd4adfded5..34b293f7ac 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala @@ -235,71 +235,6 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { Seq("HTTP/1.1 200\r\nContent-Length: 0\r\n\r\n") should generalMultiParseTo(Left(MessageStartError( 400: StatusCode, ErrorInfo("Status code misses trailing space")))) } - - "with entity length > max-content-length" - { - def response(dataElements: ByteString*) = HttpResponse(200, Nil, - HttpEntity.Chunked(`application/octet-stream`, Source(dataElements.map(ChunkStreamPart(_)).toVector))) - - "for Default entity" in new Test { - Seq("""HTTP/1.1 200 OK - |Content-length: 101 - | - |""") should generalMultiParseTo(Left( - MessageStartError(400: StatusCode, - ErrorInfo( - "Response Content-Length of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.client.parsing.max-content-length")))) - - override protected def parserSettings: ParserSettings = super.parserSettings.copy(maxContentLength = 100) - } - - "for CloseDelimited entity" in new Test { - Seq( - """HTTP/1.1 200 OK - | - |abcdef""") should generalMultiParseTo(Right(response()), - Left(EntityStreamError( - ErrorInfo("Aggregated data length of close-delimited response entity of 6 bytes exceeds the configured limit of 5 bytes", - "Consider increasing the value of akka.http.client.parsing.max-content-length")))) - - Seq( - """HTTP/1.1 200 OK - | - |a""", "bcdef") should generalMultiParseTo(Right(response(ByteString("a"))), - Left(EntityStreamError( - ErrorInfo("Aggregated data length of close-delimited response entity of 6 bytes exceeds the configured limit of 5 bytes", - "Consider increasing the value of akka.http.client.parsing.max-content-length")))) - - override protected def parserSettings: ParserSettings = super.parserSettings.copy(maxContentLength = 5) - } - - "for Chunked entity" in new Test { - Seq( - """HTTP/1.1 200 OK - |Transfer-Encoding: chunked - | - |65 - |abc""") should generalMultiParseTo(Right(response()), - Left(EntityStreamError( - ErrorInfo("Aggregated data length of chunked response entity of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.client.parsing.max-content-length")))) - - Seq( - """HTTP/1.1 200 OK - |Transfer-Encoding: chunked - | - |1 - |a - |""", - """64 - |a""") should generalMultiParseTo(Right(response(ByteString("a"))), - Left(EntityStreamError( - ErrorInfo("Aggregated data length of chunked response entity of 101 bytes exceeds the configured limit of 100 bytes", - "Consider increasing the value of akka.http.client.parsing.max-content-length")))) - - override protected def parserSettings: ParserSettings = super.parserSettings.copy(maxContentLength = 100) - } - } } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index b1b3db4393..4e9d37f71c 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -5,9 +5,8 @@ package akka.http.impl.engine.server import java.net.{ InetAddress, InetSocketAddress } - import akka.http.ServerSettings - +import scala.reflect.ClassTag import scala.util.Random import scala.annotation.tailrec import scala.concurrent.duration._ @@ -34,7 +33,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") - expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest() shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) } "deliver a request as soon as all headers are received" in new TestSetup { @@ -44,7 +43,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -87,7 +86,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |abcdef |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -123,7 +122,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcdefghijkl""") - expectRequest shouldEqual + expectRequest() shouldEqual HttpRequest( method = POST, uri = "http://example.com/strict", @@ -138,7 +137,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcdef""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -161,7 +160,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |abcdef |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -182,7 +181,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcdefghijkl""") - expectRequest shouldEqual + expectRequest() shouldEqual HttpRequest( method = POST, uri = "http://example.com/strict", @@ -195,7 +194,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |mnopqrstuvwx""") - expectRequest shouldEqual + expectRequest() shouldEqual HttpRequest( method = POST, uri = "http://example.com/next-strict", @@ -210,7 +209,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcdef""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -232,7 +231,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcde""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ data shouldEqual ByteString("abcde") } @@ -247,7 +246,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |abcdef |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -270,7 +269,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcde""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ data shouldEqual ByteString("abcde") } @@ -283,7 +282,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |abcdef""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -306,7 +305,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |abcdef |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -327,7 +326,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Length: 12 | |abcdef""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -348,7 +347,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |6 |abcdef |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -367,7 +366,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""") - expectRequest shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest() shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com"))) } "keep HEAD request when transparent-head-requests are disabled" in new TestSetup { @@ -376,7 +375,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""") - expectRequest shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest() shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com"))) } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Strict)" in new TestSetup { @@ -384,7 +383,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com | |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(GET, _, _, _, _) ⇒ responses.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd")))) expectResponseWithWipedDate( @@ -404,7 +403,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") val data = TestPublisher.manualProbe[ByteString]() - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(GET, _, _, _, _) ⇒ responses.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) val dataSub = data.expectSubscription() @@ -426,7 +425,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") val data = TestPublisher.manualProbe[ByteString]() - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(GET, _, _, _, _) ⇒ responses.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() @@ -449,7 +448,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") val data = TestPublisher.manualProbe[ChunkStreamPart]() - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(GET, _, _, _, _) ⇒ responses.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() @@ -472,7 +471,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") val data = TestPublisher.manualProbe[ByteString]() - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(GET, _, _, _, _) ⇒ responses.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() @@ -489,7 +488,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Length: 16 | |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() @@ -525,7 +524,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Transfer-Encoding: chunked | |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, Chunked(ContentType(`application/octet-stream`, None), data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() @@ -566,7 +565,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Length: 16 | |""") - inside(expectRequest) { + inside(expectRequest()) { case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ responses.sendNext(HttpResponse(entity = "Yeah")) expectResponseWithWipedDate( @@ -587,7 +586,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""".stripMarginWithNewline("\r\n")) - expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) + expectRequest() shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) responses.expectRequest() responses.sendError(new RuntimeException("CRASH BOOM BANG")) @@ -609,7 +608,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") - val HttpRequest(POST, _, _, entity, _) = expectRequest + val HttpRequest(POST, _, _, entity, _) = expectRequest() responses.sendNext(HttpResponse(entity = entity)) responses.sendComplete() @@ -644,7 +643,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") - expectRequest shouldEqual HttpRequest(uri = "http://example.com//foo", headers = List(Host("example.com"))) + expectRequest() shouldEqual HttpRequest(uri = "http://example.com//foo", headers = List(Host("example.com"))) } "use default-host-header for HTTP/1.0 requests" in new TestSetup { @@ -652,10 +651,11 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""") - expectRequest shouldEqual HttpRequest(uri = "http://example.com/abc", protocol = HttpProtocols.`HTTP/1.0`) + expectRequest() shouldEqual HttpRequest(uri = "http://example.com/abc", protocol = HttpProtocols.`HTTP/1.0`) override def settings: ServerSettings = super.settings.copy(defaultHostHeader = Host("example.com")) } + "fail an HTTP/1.0 request with 400 if no default-host-header is set" in new TestSetup { send("""GET /abc HTTP/1.0 | @@ -686,12 +686,196 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |""".stripMarginWithNewline("\r\n")) - val request = expectRequest + val request = expectRequest() request.headers should contain(`Remote-Address`(RemoteAddress(theAddress, Some(8080)))) } + + def isDefinedVia = afterWord("is defined via") + "support request length verification" which isDefinedVia { + + class LengthVerificationTest(maxContentLength: Int) extends TestSetup(maxContentLength) { + val entityBase = "0123456789ABCD" + def sendStrictRequestWithLength(bytes: Int) = + send(s"""POST /foo HTTP/1.1 + |Host: example.com + |Content-Length: $bytes + | + |${entityBase take bytes}""") + def sendDefaultRequestWithLength(bytes: Int) = { + send(s"""POST /foo HTTP/1.1 + |Host: example.com + |Content-Length: $bytes + | + |${entityBase take 3}""") + send(entityBase.slice(3, 7)) + send(entityBase.slice(7, bytes)) + } + def sendChunkedRequestWithLength(bytes: Int) = + send(s"""POST /foo HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |3 + |${entityBase take 3} + |4 + |${entityBase.slice(3, 7)} + |${bytes - 7} + |${entityBase.slice(7, bytes)} + |0 + | + |""") + + implicit class XRequest(request: HttpRequest) { + def expectEntity[T <: HttpEntity: ClassTag](bytes: Int) = + inside(request) { + case HttpRequest(POST, _, _, entity: T, _) ⇒ + entity.toStrict(100.millis).awaitResult(100.millis).data.utf8String shouldEqual entityBase.take(bytes) + } + + def expectDefaultEntityWithSizeError(limit: Int, actualSize: Int) = + inside(request) { + case HttpRequest(POST, _, _, entity @ HttpEntity.Default(_, `actualSize`, _), _) ⇒ + val error = the[Exception] + .thrownBy(entity.dataBytes.runFold(ByteString.empty)(_ ++ _).awaitResult(100.millis)) + .getCause + error shouldEqual EntityStreamSizeException(limit, Some(actualSize)) + + responses.expectRequest() + responses.sendError(error.asInstanceOf[Exception]) + + expectResponseWithWipedDate( + s"""HTTP/1.1 413 Request Entity Too Large + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 75 + | + |Request Content-Length of $actualSize bytes exceeds the configured limit of $limit bytes""") + } + + def expectChunkedEntityWithSizeError(limit: Int) = + inside(request) { + case HttpRequest(POST, _, _, entity: HttpEntity.Chunked, _) ⇒ + val error = the[Exception] + .thrownBy(entity.dataBytes.runFold(ByteString.empty)(_ ++ _).awaitResult(100.millis)) + .getCause + error shouldEqual EntityStreamSizeException(limit, None) + + responses.expectRequest() + responses.sendError(error.asInstanceOf[Exception]) + + expectResponseWithWipedDate( + s"""HTTP/1.1 413 Request Entity Too Large + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Type: text/plain; charset=UTF-8 + |Content-Length: 81 + | + |Aggregated data length of request entity exceeds the configured limit of $limit bytes""") + } + } + } + + "the config setting (strict entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendStrictRequestWithLength(10) + expectRequest().expectEntity[HttpEntity.Strict](10) + + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStrictRequestWithLength(11) + expectRequest().expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "the config setting (default entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendDefaultRequestWithLength(10) + expectRequest().expectEntity[HttpEntity.Default](10) + + sendDefaultRequestWithLength(11) + expectRequest().expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "the config setting (chunked entity)" in new LengthVerificationTest(maxContentLength = 10) { + sendChunkedRequestWithLength(10) + expectRequest().expectEntity[HttpEntity.Chunked](10) + + sendChunkedRequestWithLength(11) + expectRequest().expectChunkedEntityWithSizeError(limit = 10) + } + + "a smaller programmatically-set limit (strict entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendStrictRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Strict](10) + + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStrictRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "a smaller programmatically-set limit (default entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendDefaultRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Default](10) + + sendDefaultRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "a smaller programmatically-set limit (chunked entity)" in new LengthVerificationTest(maxContentLength = 12) { + sendChunkedRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Chunked](10) + + sendChunkedRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectChunkedEntityWithSizeError(limit = 10) + } + + "a larger programmatically-set limit (strict entity)" in new LengthVerificationTest(maxContentLength = 8) { + // entities that would be strict but have a Content-Length > the configured maximum are delivered + // as single element Default entities! + sendStrictRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Default](10) + + sendStrictRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "a larger programmatically-set limit (default entity)" in new LengthVerificationTest(maxContentLength = 8) { + sendDefaultRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Default](10) + + sendDefaultRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + + "a larger programmatically-set limit (chunked entity)" in new LengthVerificationTest(maxContentLength = 8) { + sendChunkedRequestWithLength(10) + expectRequest().mapEntity(_ withSizeLimit 10).expectEntity[HttpEntity.Chunked](10) + + sendChunkedRequestWithLength(11) + expectRequest().mapEntity(_ withSizeLimit 10).expectChunkedEntityWithSizeError(limit = 10) + } + + "the config setting applied before another attribute (default entity)" in new LengthVerificationTest(maxContentLength = 10) { + def nameDataSource(name: String): RequestEntity ⇒ RequestEntity = { + case x: HttpEntity.Default ⇒ x.copy(data = x.data named name) + } + sendDefaultRequestWithLength(10) + expectRequest().mapEntity(nameDataSource("foo")).expectEntity[HttpEntity.Default](10) + + sendDefaultRequestWithLength(11) + expectRequest().mapEntity(nameDataSource("foo")).expectDefaultEntityWithSizeError(limit = 10, actualSize = 11) + } + } } - class TestSetup extends HttpServerTestSetupBase { + class TestSetup(maxContentLength: Int = -1) extends HttpServerTestSetupBase { implicit def system = spec.system implicit def materializer = spec.materializer + + override def settings = { + val s = super.settings + if (maxContentLength < 0) s + else s.copy(parserSettings = s.parserSettings.copy(maxContentLength = maxContentLength)) + } } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala index dd279f9717..9f88bcdb54 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala @@ -5,23 +5,16 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress - import akka.http.impl.engine.ws.ByteStringSinkProbe import akka.stream.io.{ SendBytes, SslTlsOutbound, SessionBytes } - import scala.concurrent.duration.FiniteDuration -import scala.concurrent.duration._ - import akka.actor.ActorSystem import akka.event.NoLogging import akka.util.ByteString - import akka.stream.{ ClosedShape, Materializer } import akka.stream.scaladsl._ import akka.stream.testkit.{ TestPublisher, TestSubscriber } - import akka.http.impl.util._ - import akka.http.ServerSettings import akka.http.scaladsl.model.headers.{ ProductVersion, Server } import akka.http.scaladsl.model.{ HttpResponse, HttpRequest } @@ -56,7 +49,7 @@ abstract class HttpServerTestSetupBase { def expectResponseWithWipedDate(expected: String): Unit = { val trimmed = expected.stripMarginWithNewline("\r\n") - // XXXX = 4 bytes, ISO Date Time String = 29 bytes => need to request 15 bytes more than expected string + // XXXX = 4 bytes, ISO Date Time String = 29 bytes => need to request 25 bytes more than expected string val expectedSize = ByteString(trimmed, "utf8").length + 25 val received = wipeDate(netOut.expectBytes(expectedSize).utf8String) assert(received == trimmed, s"Expected request '$trimmed' but got '$received'") @@ -68,7 +61,7 @@ abstract class HttpServerTestSetupBase { case s ⇒ s }.mkString("\n") - def expectRequest: HttpRequest = requests.requestNext() + def expectRequest(): HttpRequest = requests.requestNext() def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max) def expectSubscribe(): Unit = netOut.expectComplete() def expectSubscribeAndNetworkClose(): Unit = netOut.expectSubscriptionAndComplete() diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala index 458f72e75f..ce9b6ddc7e 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -32,7 +32,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp | |""") - val request = expectRequest + val request = expectRequest() val upgrade = request.header[UpgradeToWebsocket] upgrade.isDefined shouldBe true @@ -78,7 +78,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp | |""") - val request = expectRequest + val request = expectRequest() val upgrade = request.header[UpgradeToWebsocket] upgrade.isDefined shouldBe true diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index 0863a2b2cd..a6815c12db 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -5,9 +5,7 @@ package akka.stream import akka.event.Logging import scala.annotation.tailrec -import scala.collection.immutable import scala.reflect.{ classTag, ClassTag } -import akka.stream.impl.Stages.SymbolicStage import akka.japi.function /** @@ -53,12 +51,28 @@ final case class Attributes(attributeList: List[Attributes.Attribute] = Nil) { case None ⇒ default } + /** + * Java API: Get the first (least specific) attribute of a given `Class` or subclass thereof. + * If no such attribute exists the `default` value is returned. + */ + def getFirstAttribute[T <: Attribute](c: Class[T], default: T): T = + getFirstAttribute(c) match { + case Some(a) ⇒ a + case None ⇒ default + } + /** * Java API: Get the last (most specific) attribute of a given `Class` or subclass thereof. */ def getAttribute[T <: Attribute](c: Class[T]): Option[T] = Option(attributeList.foldLeft(null.asInstanceOf[T])((acc, attr) ⇒ if (c.isInstance(attr)) c.cast(attr) else acc)) + /** + * Java API: Get the first (least specific) attribute of a given `Class` or subclass thereof. + */ + def getFirstAttribute[T <: Attribute](c: Class[T]): Option[T] = + attributeList.find(c isInstance _).map(c cast _) + /** * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. * If no such attribute exists the `default` value is returned. @@ -66,12 +80,25 @@ final case class Attributes(attributeList: List[Attributes.Attribute] = Nil) { def get[T <: Attribute: ClassTag](default: T) = getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]], default) + /** + * Get the first (least specific) attribute of a given type parameter T `Class` or subclass thereof. + * If no such attribute exists the `default` value is returned. + */ + def getFirst[T <: Attribute: ClassTag](default: T) = + getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]], default) + /** * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. */ def get[T <: Attribute: ClassTag] = getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]]) + /** + * Get the first (least specific) attribute of a given type parameter T `Class` or subclass thereof. + */ + def getFirst[T <: Attribute: ClassTag] = + getFirstAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]]) + /** * Adds given attributes to the end of these attributes. */ diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index 9e631754f6..71c88bcf25 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -62,8 +62,7 @@ private[akka] case class ActorMaterializerImpl(system: ActorSystem, case InputBuffer(initial, max) ⇒ s.withInputBuffer(initial, max) case Dispatcher(dispatcher) ⇒ s.withDispatcher(dispatcher) case SupervisionStrategy(decider) ⇒ s.withSupervisionStrategy(decider) - case l: LogLevels ⇒ s - case Name(_) ⇒ s + case _ ⇒ s } } }