diff --git a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala index 015d8b9095..b199624d59 100644 --- a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala @@ -106,6 +106,13 @@ sealed trait ResponseEntity extends HttpEntity with japi.ResponseEntity { /* An entity that can be used for requests, responses, and body parts */ sealed trait UniversalEntity extends japi.UniversalEntity with MessageEntity with BodyPartEntity { def withContentType(contentType: ContentType): UniversalEntity + def contentLength: Long + + /** + * Transforms this entity's data bytes with a transformer that will produce exactly the number of bytes given as + * ``newContentLength``. + */ + def transformDataBytes(newContentLength: Long, transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity } object HttpEntity { @@ -141,6 +148,7 @@ object HttpEntity { */ final case class Strict(contentType: ContentType, data: ByteString) extends japi.HttpEntityStrict with UniversalEntity { + def contentLength: Long = data.length def isKnownEmpty: Boolean = data.isEmpty @@ -159,6 +167,12 @@ object HttpEntity { Chunked(contentType, Source.failed(ex)) } } + override def transformDataBytes(newContentLength: Long, transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity = { + val t = transformer() + val newData = (t.onNext(data) ++ t.onTermination(None)).join + assert(newData.length.toLong == newContentLength, s"Transformer didn't produce as much bytes (${newData.length}:'${newData.utf8String}') as claimed ($newContentLength)") + copy(data = newData) + } def withContentType(contentType: ContentType): Strict = if (contentType == this.contentType) this else copy(contentType = contentType) @@ -185,6 +199,8 @@ object HttpEntity { HttpEntity.Chunked(contentType, chunks) } + override def transformDataBytes(newContentLength: Long, 
transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity = + Default(contentType, newContentLength, data.transform("transformDataBytes-with-new-length-Default", transformer)) def withContentType(contentType: ContentType): Default = if (contentType == this.contentType) this else copy(contentType = contentType) diff --git a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala index d9a8cbba18..432d13740b 100644 --- a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala @@ -58,6 +58,46 @@ private[http] object StreamUtils { override def onError(cause: scala.Throwable): Unit = throw f(cause) } + def sliceBytesTransformer(start: Long, length: Long): Transformer[ByteString, ByteString] = + new Transformer[ByteString, ByteString] { + type State = Transformer[ByteString, ByteString] + + def skipping = new State { + var toSkip = start + def onNext(element: ByteString): immutable.Seq[ByteString] = + if (element.length < toSkip) { + // keep skipping + toSkip -= element.length + Nil + } else { + become(taking(length)) + // toSkip <= element.length <= Int.MaxValue + currentState.onNext(element.drop(toSkip.toInt)) + } + } + def taking(initiallyRemaining: Long) = new State { + var remaining: Long = initiallyRemaining + def onNext(element: ByteString): immutable.Seq[ByteString] = { + val data = element.take(math.min(remaining, Int.MaxValue).toInt) + remaining -= data.size + if (remaining <= 0) become(finishing) + data :: Nil + } + } + def finishing = new State { + override def isComplete: Boolean = true + def onNext(element: ByteString): immutable.Seq[ByteString] = + throw new IllegalStateException("onNext called on complete stream") + } + + var currentState: State = if (start > 0) skipping else taking(length) + def become(state: State): Unit = currentState = state + + override def isComplete: Boolean = 
currentState.isComplete + def onNext(element: ByteString): immutable.Seq[ByteString] = currentState.onNext(element) + override def onTermination(e: Option[Throwable]): immutable.Seq[ByteString] = currentState.onTermination(e) + } + def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity = _.transformDataBytes(() ⇒ mapErrorTransformer(f)) } diff --git a/akka-http-core/src/main/scala/akka/http/util/package.scala b/akka-http-core/src/main/scala/akka/http/util/package.scala index 82e52c7070..9e427bbcf1 100644 --- a/akka-http-core/src/main/scala/akka/http/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/util/package.scala @@ -8,13 +8,14 @@ import language.implicitConversions import language.higherKinds import java.nio.charset.Charset import com.typesafe.config.Config -import akka.stream.FlattenStrategy +import akka.stream.{ FlowMaterializer, FlattenStrategy, Transformer } import akka.stream.scaladsl.{ Flow, Source } +import scala.concurrent.Future import scala.util.matching.Regex import akka.event.LoggingAdapter import akka.util.ByteString import akka.actor._ -import akka.stream.Transformer +import scala.collection.immutable package object util { private[http] val UTF8 = Charset.forName("UTF8") @@ -53,7 +54,7 @@ package object util { .flatten(FlattenStrategy.concat) } - private[http] implicit class SourceWithPrintEvent[T](val underlying: Source[T]) { + private[http] implicit class EnhancedSource[T](val underlying: Source[T]) { def printEvent(marker: String): Source[T] = underlying.transform("transform", () ⇒ new Transformer[T, T] { @@ -66,6 +67,14 @@ package object util { Nil } }) + + /** + * Drain this stream into a Vector and provide it as a future value. 
+ * + * FIXME: Should be part of akka-streams + */ + def collectAll(implicit materializer: FlowMaterializer): Future[immutable.Seq[T]] = + underlying.fold(Vector.empty[T])(_ :+ _) } private[http] def errorLogger(log: LoggingAdapter, msg: String): Transformer[ByteString, ByteString] = diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala new file mode 100644 index 0000000000..809133ddf9 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model.StatusCodes._ +import akka.http.model._ +import akka.http.model.headers._ +import akka.http.util._ +import org.scalatest.{ Inside, Inspectors } + +import scala.concurrent.Await +import scala.concurrent.duration._ + +class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { + lazy val wrs = + mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) & + withRangeSupport + + def bytes(length: Byte) = Array.tabulate[Byte](length)(_.toByte) + + "The `withRangeSupport` directive" should { + def completeWithRangedBytes(length: Byte) = wrs(complete(bytes(length))) + + "return an Accept-Ranges(bytes) header for GET requests" in { + Get() ~> { wrs { complete("any") } } ~> check { + headers should contain(`Accept-Ranges`(RangeUnits.Bytes)) + } + } + + "not return an Accept-Ranges(bytes) header for non-GET requests" in { + Put() ~> { wrs { complete("any") } } ~> check { + headers should not contain `Accept-Ranges`(RangeUnits.Bytes) + } + } + + "return a Content-Range header for a ranged request with a single range" in { + Get() ~> addHeader(Range(ByteRange(0, 1))) ~> completeWithRangedBytes(10) ~> check { + headers should contain(`Content-Range`(ContentRange(0, 1, 10))) + status shouldEqual 
PartialContent + responseAs[Array[Byte]] shouldEqual bytes(2) + } + } + + "return a partial response for a ranged request with a single range with undefined lastBytePosition" in { + Get() ~> addHeader(Range(ByteRange.fromOffset(5))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual Array[Byte](5, 6, 7, 8, 9) + } + } + + "return a partial response for a ranged request with a single suffix range" in { + Get() ~> addHeader(Range(ByteRange.suffix(1))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual Array[Byte](9) + } + } + + "return a partial response for a ranged request with a overlapping suffix range" in { + Get() ~> addHeader(Range(ByteRange.suffix(100))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual bytes(10) + } + } + + "be transparent to non-GET requests" in { + Post() ~> addHeader(Range(ByteRange(1, 2))) ~> completeWithRangedBytes(5) ~> check { + responseAs[Array[Byte]] shouldEqual bytes(5) + } + } + + "be transparent to non-200 responses" in { + Get() ~> addHeader(Range(ByteRange(1, 2))) ~> sealRoute(wrs(reject())) ~> check { + status shouldEqual NotFound + headers.exists { case `Content-Range`(_, _) ⇒ true; case _ ⇒ false } shouldEqual false + } + } + + "reject an unsatisfiable single range" in { + Get() ~> addHeader(Range(ByteRange(100, 200))) ~> completeWithRangedBytes(10) ~> check { + rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange(100, 200)), 10) + } + } + + "reject an unsatisfiable single suffix range with length 0" in { + Get() ~> addHeader(Range(ByteRange.suffix(0))) ~> completeWithRangedBytes(42) ~> check { + rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange.suffix(0)), 42) + } + } + + "return a mediaType of 'multipart/byteranges' for a ranged request with multiple ranges" in { + Get() ~> addHeader(Range(ByteRange(0, 10), ByteRange(0, 10))) ~> completeWithRangedBytes(10) ~> check { + mediaType.withParams(Map.empty) shouldEqual 
MediaTypes.`multipart/byteranges` + } + } + + "return a 'multipart/byteranges' for a ranged request with multiple coalesced ranges with preserved order" in { + Get() ~> addHeader(Range(ByteRange(5, 10), ByteRange(0, 1), ByteRange(1, 2))) ~> { + wrs { complete("Some random and not super short entity.") } + } ~> check { + header[`Content-Range`] should be(None) + val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second) + parts.size shouldEqual 2 + inside(parts(0)) { + case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒ + range shouldEqual ContentRange.Default(5, 10, Some(39)) + unit shouldEqual RangeUnits.Bytes + Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "random" + } + inside(parts(1)) { + case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒ + range shouldEqual ContentRange.Default(0, 2, Some(39)) + unit shouldEqual RangeUnits.Bytes + Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "Som" + } + } + } + + "reject a request with too many requested ranges" in { + val ranges = (1 to 20).map(a ⇒ ByteRange.fromOffset(a)) + Get() ~> addHeader(Range(ranges)) ~> completeWithRangedBytes(100) ~> check { + rejection shouldEqual TooManyRangesRejection(10) + } + } + } +} diff --git a/akka-http/src/main/resources/reference.conf b/akka-http/src/main/resources/reference.conf index 4277502f06..927aa7a5c9 100644 --- a/akka-http/src/main/resources/reference.conf +++ b/akka-http/src/main/resources/reference.conf @@ -12,4 +12,18 @@ akka.http.routing { # and (probably) enabled for internal or non-browser APIs # (Note that akka-http will always produce log messages containing the full error details) verbose-error-messages = off + + # The maximum gap (in bytes) between two requested ranges. Ranges separated by at most this gap will be coalesced. 
+ # + # When multiple ranges are requested, a server may coalesce any of the ranges that overlap or that are separated + # by a gap that is smaller than the overhead of sending multiple parts, regardless of the order in which the + # corresponding byte-range-spec appeared in the received Range header field. Since the typical overhead between + # parts of a multipart/byteranges payload is around 80 bytes, depending on the selected representation's + # media type and the chosen boundary parameter length, it can be less efficient to transfer many small + # disjoint parts than it is to transfer the entire selected representation. + range-coalescing-threshold = 80 + + # The maximum number of allowed ranges per request. + # Requests with more ranges will be rejected as a suspected denial-of-service (DoS) attempt. + range-count-limit = 16 } diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index 58eadaa825..a025af73ab 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -27,7 +27,7 @@ trait Directives extends RouteConcatenation with MiscDirectives with ParameterDirectives with PathDirectives - //with RangeDirectives + with RangeDirectives with RespondWithDirectives with RouteDirectives with SchemeDirectives diff --git a/akka-http/src/main/scala/akka/http/server/RoutingSettings.scala b/akka-http/src/main/scala/akka/http/server/RoutingSettings.scala index 84cf40cfc8..e65a2baea0 100644 --- a/akka-http/src/main/scala/akka/http/server/RoutingSettings.scala +++ b/akka-http/src/main/scala/akka/http/server/RoutingSettings.scala @@ -8,11 +8,16 @@ import com.typesafe.config.Config import akka.actor.ActorRefFactory import akka.http.util._ -case class RoutingSettings(verboseErrorMessages: Boolean) +case class RoutingSettings( + verboseErrorMessages: Boolean, + rangeCountLimit: Int, + rangeCoalescingThreshold: Long) object RoutingSettings 
extends SettingsCompanion[RoutingSettings]("akka.http.routing") { def fromSubConfig(c: Config) = apply( - c getBoolean "verbose-error-messages") + c getBoolean "verbose-error-messages", + c getInt "range-count-limit", + c getBytes "range-coalescing-threshold") implicit def default(implicit refFactory: ActorRefFactory) = apply(actorSystem) diff --git a/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala new file mode 100644 index 0000000000..4990f0f083 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model.StatusCodes._ +import akka.http.model._ +import akka.http.model.headers._ +import akka.http.server.RouteResult.Complete +import akka.http.util._ +import akka.stream.scaladsl.Source + +import scala.collection.immutable + +trait RangeDirectives { + import akka.http.server.directives.BasicDirectives._ + import akka.http.server.directives.RouteDirectives._ + + /** + * Answers GET requests with an `Accept-Ranges: bytes` header and converts HttpResponses coming back from its inner + * route into partial responses if the initial request contained a valid `Range` request header. The requested + * byte-ranges may be coalesced. + * This directive is transparent to non-GET requests. + * Rejects requests in which none of the requested ranges is satisfiable with an `UnsatisfiableRangeRejection`. + * Rejects requests with too many requested ranges with a `TooManyRangesRejection`. + * + * Note: if you want to combine this directive with `conditional(...)` you need to put + * it on the *inside* of the `conditional(...)` directive, i.e. `conditional(...)` must be + * on a higher level in your route structure in order to function correctly. 
+ * + * @see https://tools.ietf.org/html/rfc7233 + */ + def withRangeSupport: Directive0 = + extractRequestContext.flatMap { ctx ⇒ + val settings = ctx.settings + implicit val log = ctx.log + import settings.{ rangeCountLimit, rangeCoalescingThreshold } + + class IndexRange(val start: Long, val end: Long) { + def length = end - start + def apply(entity: UniversalEntity): UniversalEntity = entity.transformDataBytes(length, () ⇒ StreamUtils.sliceBytesTransformer(start, length)) + def distance(other: IndexRange) = mergedEnd(other) - mergedStart(other) - (length + other.length) + def mergeWith(other: IndexRange) = new IndexRange(mergedStart(other), mergedEnd(other)) + def contentRange(entityLength: Long) = ContentRange(start, end - 1, entityLength) + private def mergedStart(other: IndexRange) = math.min(start, other.start) + private def mergedEnd(other: IndexRange) = math.max(end, other.end) + } + + def indexRange(entityLength: Long)(range: ByteRange): IndexRange = + range match { + case ByteRange.Slice(start, end) ⇒ new IndexRange(start, math.min(end + 1, entityLength)) + case ByteRange.FromOffset(first) ⇒ new IndexRange(first, entityLength) + case ByteRange.Suffix(suffixLength) ⇒ new IndexRange(math.max(0, entityLength - suffixLength), entityLength) + } + + // See comment of the `range-coalescing-threshold` setting in `reference.conf` for the rationale of this behavior. 
+ def coalesceRanges(iRanges: Seq[IndexRange]): Seq[IndexRange] = + iRanges.foldLeft(Seq.empty[IndexRange]) { (acc, iRange) ⇒ + val (mergeCandidates, otherCandidates) = acc.partition(_.distance(iRange) <= rangeCoalescingThreshold) + val merged = mergeCandidates.foldLeft(iRange)(_ mergeWith _) + otherCandidates :+ merged + } + + def multipartRanges(ranges: Seq[ByteRange], entity: UniversalEntity): Multipart.ByteRanges = { + val length = entity.contentLength + val iRanges: Seq[IndexRange] = ranges.map(indexRange(length)) + val bodyParts = coalesceRanges(iRanges).map(ir ⇒ Multipart.ByteRanges.BodyPart(ir.contentRange(length), ir(entity))) + Multipart.ByteRanges(Source(bodyParts.toVector)) + } + + def rangeResponse(range: ByteRange, entity: UniversalEntity, length: Long, headers: immutable.Seq[HttpHeader]) = { + val aiRange = indexRange(length)(range) + HttpResponse(PartialContent, `Content-Range`(aiRange.contentRange(length)) +: headers, aiRange(entity)) + } + + def satisfiable(entityLength: Long)(range: ByteRange): Boolean = + range match { + case ByteRange.Slice(firstPos, _) ⇒ firstPos < entityLength + case ByteRange.FromOffset(firstPos) ⇒ firstPos < entityLength + case ByteRange.Suffix(length) ⇒ length > 0 + } + def universal(entity: HttpEntity): Option[UniversalEntity] = entity match { + case u: UniversalEntity ⇒ Some(u) + case _ ⇒ None + } + + def applyRanges(ranges: Seq[ByteRange]): Directive0 = + extractRequestContext.flatMap { ctx ⇒ + import ctx.executionContext + mapRouteResultWithPF { + case Complete(HttpResponse(OK, headers, entity, protocol)) ⇒ + universal(entity) match { + case Some(entity) ⇒ + val length = entity.contentLength + ranges.filter(satisfiable(length)) match { + case Nil ⇒ ctx.reject(UnsatisfiableRangeRejection(ranges, length)) + case Seq(satisfiableRange) ⇒ ctx.complete(rangeResponse(satisfiableRange, entity, length, headers)) + case satisfiableRanges ⇒ + ctx.complete(PartialContent, headers, multipartRanges(satisfiableRanges, entity)) + } + 
case None ⇒ + // Ranges not supported for Chunked or CloseDelimited responses + ctx.reject(UnsatisfiableRangeRejection(ranges, -1)) // FIXME: provide better error + } + } + } + + def rangeHeaderOfGetRequests(ctx: RequestContext): Option[Range] = + if (ctx.request.method == HttpMethods.GET) ctx.request.header[Range] else None + + extract(rangeHeaderOfGetRequests).flatMap { + case Some(Range(RangeUnits.Bytes, ranges)) ⇒ + if (ranges.size <= rangeCountLimit) applyRanges(ranges) & RangeDirectives.respondWithAcceptByteRangesHeader + else reject(TooManyRangesRejection(rangeCountLimit)) + case _ ⇒ MethodDirectives.get & RangeDirectives.respondWithAcceptByteRangesHeader | pass + } + } +} + +object RangeDirectives extends RangeDirectives { + private val respondWithAcceptByteRangesHeader: Directive0 = + RespondWithDirectives.respondWithHeader(`Accept-Ranges`(RangeUnits.Bytes)) +} +