Merge pull request #16195 from spray/w/15930-RangeDirectives

+htp #15930 import RangeDirectives from spray
This commit is contained in:
Björn Antonsson 2014-11-05 10:54:33 +01:00
commit 04385f4d91
8 changed files with 341 additions and 6 deletions

View file

@ -106,6 +106,13 @@ sealed trait ResponseEntity extends HttpEntity with japi.ResponseEntity {
/* An entity that can be used for requests, responses, and body parts */ /* An entity that can be used for requests, responses, and body parts */
sealed trait UniversalEntity extends japi.UniversalEntity with MessageEntity with BodyPartEntity { sealed trait UniversalEntity extends japi.UniversalEntity with MessageEntity with BodyPartEntity {
def withContentType(contentType: ContentType): UniversalEntity def withContentType(contentType: ContentType): UniversalEntity
def contentLength: Long
/**
* Transforms this' entities data bytes with a transformer that will produce exactly the number of bytes given as
* ``newContentLength``.
*/
def transformDataBytes(newContentLength: Long, transformer: () Transformer[ByteString, ByteString]): UniversalEntity
} }
object HttpEntity { object HttpEntity {
@ -141,6 +148,7 @@ object HttpEntity {
*/ */
final case class Strict(contentType: ContentType, data: ByteString) final case class Strict(contentType: ContentType, data: ByteString)
extends japi.HttpEntityStrict with UniversalEntity { extends japi.HttpEntityStrict with UniversalEntity {
def contentLength: Long = data.length
def isKnownEmpty: Boolean = data.isEmpty def isKnownEmpty: Boolean = data.isEmpty
@ -159,6 +167,12 @@ object HttpEntity {
Chunked(contentType, Source.failed(ex)) Chunked(contentType, Source.failed(ex))
} }
} }
override def transformDataBytes(newContentLength: Long, transformer: () Transformer[ByteString, ByteString]): UniversalEntity = {
val t = transformer()
val newData = (t.onNext(data) ++ t.onTermination(None)).join
assert(newData.length.toLong == newContentLength, s"Transformer didn't produce as much bytes (${newData.length}:'${newData.utf8String}') as claimed ($newContentLength)")
copy(data = newData)
}
def withContentType(contentType: ContentType): Strict = def withContentType(contentType: ContentType): Strict =
if (contentType == this.contentType) this else copy(contentType = contentType) if (contentType == this.contentType) this else copy(contentType = contentType)
@ -185,6 +199,8 @@ object HttpEntity {
HttpEntity.Chunked(contentType, chunks) HttpEntity.Chunked(contentType, chunks)
} }
override def transformDataBytes(newContentLength: Long, transformer: () Transformer[ByteString, ByteString]): UniversalEntity =
Default(contentType, newContentLength, data.transform("transformDataBytes-with-new-length-Default", transformer))
def withContentType(contentType: ContentType): Default = def withContentType(contentType: ContentType): Default =
if (contentType == this.contentType) this else copy(contentType = contentType) if (contentType == this.contentType) this else copy(contentType = contentType)

View file

@ -58,6 +58,46 @@ private[http] object StreamUtils {
override def onError(cause: scala.Throwable): Unit = throw f(cause) override def onError(cause: scala.Throwable): Unit = throw f(cause)
} }
def sliceBytesTransformer(start: Long, length: Long): Transformer[ByteString, ByteString] =
new Transformer[ByteString, ByteString] {
type State = Transformer[ByteString, ByteString]
def skipping = new State {
var toSkip = start
def onNext(element: ByteString): immutable.Seq[ByteString] =
if (element.length < toSkip) {
// keep skipping
toSkip -= element.length
Nil
} else {
become(taking(length))
// toSkip <= element.length <= Int.MaxValue
currentState.onNext(element.drop(toSkip.toInt))
}
}
def taking(initiallyRemaining: Long) = new State {
var remaining: Long = initiallyRemaining
def onNext(element: ByteString): immutable.Seq[ByteString] = {
val data = element.take(math.min(remaining, Int.MaxValue).toInt)
remaining -= data.size
if (remaining <= 0) become(finishing)
data :: Nil
}
}
def finishing = new State {
override def isComplete: Boolean = true
def onNext(element: ByteString): immutable.Seq[ByteString] =
throw new IllegalStateException("onNext called on complete stream")
}
var currentState: State = if (start > 0) skipping else taking(length)
def become(state: State): Unit = currentState = state
override def isComplete: Boolean = currentState.isComplete
def onNext(element: ByteString): immutable.Seq[ByteString] = currentState.onNext(element)
override def onTermination(e: Option[Throwable]): immutable.Seq[ByteString] = currentState.onTermination(e)
}
def mapEntityError(f: Throwable Throwable): RequestEntity RequestEntity = def mapEntityError(f: Throwable Throwable): RequestEntity RequestEntity =
_.transformDataBytes(() mapErrorTransformer(f)) _.transformDataBytes(() mapErrorTransformer(f))
} }

View file

@ -8,13 +8,14 @@ import language.implicitConversions
import language.higherKinds import language.higherKinds
import java.nio.charset.Charset import java.nio.charset.Charset
import com.typesafe.config.Config import com.typesafe.config.Config
import akka.stream.FlattenStrategy import akka.stream.{ FlowMaterializer, FlattenStrategy, Transformer }
import akka.stream.scaladsl.{ Flow, Source } import akka.stream.scaladsl.{ Flow, Source }
import scala.concurrent.Future
import scala.util.matching.Regex import scala.util.matching.Regex
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.util.ByteString import akka.util.ByteString
import akka.actor._ import akka.actor._
import akka.stream.Transformer import scala.collection.immutable
package object util { package object util {
private[http] val UTF8 = Charset.forName("UTF8") private[http] val UTF8 = Charset.forName("UTF8")
@ -53,7 +54,7 @@ package object util {
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
} }
private[http] implicit class SourceWithPrintEvent[T](val underlying: Source[T]) { private[http] implicit class EnhancedSource[T](val underlying: Source[T]) {
def printEvent(marker: String): Source[T] = def printEvent(marker: String): Source[T] =
underlying.transform("transform", underlying.transform("transform",
() new Transformer[T, T] { () new Transformer[T, T] {
@ -66,6 +67,14 @@ package object util {
Nil Nil
} }
}) })
/**
* Drain this stream into a Vector and provide it as a future value.
*
* FIXME: Should be part of akka-streams
*/
def collectAll(implicit materializer: FlowMaterializer): Future[immutable.Seq[T]] =
underlying.fold(Vector.empty[T])(_ :+ _)
} }
private[http] def errorLogger(log: LoggingAdapter, msg: String): Transformer[ByteString, ByteString] = private[http] def errorLogger(log: LoggingAdapter, msg: String): Transformer[ByteString, ByteString] =

View file

@ -0,0 +1,125 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import akka.http.model.StatusCodes._
import akka.http.model._
import akka.http.model.headers._
import akka.http.util._
import org.scalatest.{ Inside, Inspectors }
import scala.concurrent.Await
import scala.concurrent.duration._
class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
lazy val wrs =
mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) &
withRangeSupport
def bytes(length: Byte) = Array.tabulate[Byte](length)(_.toByte)
"The `withRangeSupport` directive" should {
def completeWithRangedBytes(length: Byte) = wrs(complete(bytes(length)))
"return an Accept-Ranges(bytes) header for GET requests" in {
Get() ~> { wrs { complete("any") } } ~> check {
headers should contain(`Accept-Ranges`(RangeUnits.Bytes))
}
}
"not return an Accept-Ranges(bytes) header for non-GET requests" in {
Put() ~> { wrs { complete("any") } } ~> check {
headers should not contain `Accept-Ranges`(RangeUnits.Bytes)
}
}
"return a Content-Range header for a ranged request with a single range" in {
Get() ~> addHeader(Range(ByteRange(0, 1))) ~> completeWithRangedBytes(10) ~> check {
headers should contain(`Content-Range`(ContentRange(0, 1, 10)))
status shouldEqual PartialContent
responseAs[Array[Byte]] shouldEqual bytes(2)
}
}
"return a partial response for a ranged request with a single range with undefined lastBytePosition" in {
Get() ~> addHeader(Range(ByteRange.fromOffset(5))) ~> completeWithRangedBytes(10) ~> check {
responseAs[Array[Byte]] shouldEqual Array[Byte](5, 6, 7, 8, 9)
}
}
"return a partial response for a ranged request with a single suffix range" in {
Get() ~> addHeader(Range(ByteRange.suffix(1))) ~> completeWithRangedBytes(10) ~> check {
responseAs[Array[Byte]] shouldEqual Array[Byte](9)
}
}
"return a partial response for a ranged request with a overlapping suffix range" in {
Get() ~> addHeader(Range(ByteRange.suffix(100))) ~> completeWithRangedBytes(10) ~> check {
responseAs[Array[Byte]] shouldEqual bytes(10)
}
}
"be transparent to non-GET requests" in {
Post() ~> addHeader(Range(ByteRange(1, 2))) ~> completeWithRangedBytes(5) ~> check {
responseAs[Array[Byte]] shouldEqual bytes(5)
}
}
"be transparent to non-200 responses" in {
Get() ~> addHeader(Range(ByteRange(1, 2))) ~> sealRoute(wrs(reject())) ~> check {
status == NotFound
headers.exists { case `Content-Range`(_, _) true; case _ false } shouldEqual false
}
}
"reject an unsatisfiable single range" in {
Get() ~> addHeader(Range(ByteRange(100, 200))) ~> completeWithRangedBytes(10) ~> check {
rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange(100, 200)), 10)
}
}
"reject an unsatisfiable single suffix range with length 0" in {
Get() ~> addHeader(Range(ByteRange.suffix(0))) ~> completeWithRangedBytes(42) ~> check {
rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange.suffix(0)), 42)
}
}
"return a mediaType of 'multipart/byteranges' for a ranged request with multiple ranges" in {
Get() ~> addHeader(Range(ByteRange(0, 10), ByteRange(0, 10))) ~> completeWithRangedBytes(10) ~> check {
mediaType.withParams(Map.empty) shouldEqual MediaTypes.`multipart/byteranges`
}
}
"return a 'multipart/byteranges' for a ranged request with multiple coalesced ranges with preserved order" in {
Get() ~> addHeader(Range(ByteRange(5, 10), ByteRange(0, 1), ByteRange(1, 2))) ~> {
wrs { complete("Some random and not super short entity.") }
} ~> check {
header[`Content-Range`] should be(None)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second)
parts.size shouldEqual 2
inside(parts(0)) {
case Multipart.ByteRanges.BodyPart(range, entity, unit, headers)
range shouldEqual ContentRange.Default(5, 10, Some(39))
unit shouldEqual RangeUnits.Bytes
Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "random"
}
inside(parts(1)) {
case Multipart.ByteRanges.BodyPart(range, entity, unit, headers)
range shouldEqual ContentRange.Default(0, 2, Some(39))
unit shouldEqual RangeUnits.Bytes
Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "Som"
}
}
}
"reject a request with too many requested ranges" in {
val ranges = (1 to 20).map(a ByteRange.fromOffset(a))
Get() ~> addHeader(Range(ranges)) ~> completeWithRangedBytes(100) ~> check {
rejection shouldEqual TooManyRangesRejection(10)
}
}
}
}

View file

@ -12,4 +12,18 @@ akka.http.routing {
# and (probably) enabled for internal or non-browser APIs # and (probably) enabled for internal or non-browser APIs
# (Note that akka-http will always produce log messages containing the full error details) # (Note that akka-http will always produce log messages containing the full error details)
verbose-error-messages = off verbose-error-messages = off
# The maximum size between two requested ranges. Ranges with less space in between will be coalesced.
#
# When multiple ranges are requested, a server may coalesce any of the ranges that overlap or that are separated
# by a gap that is smaller than the overhead of sending multiple parts, regardless of the order in which the
# corresponding byte-range-spec appeared in the received Range header field. Since the typical overhead between
# parts of a multipart/byteranges payload is around 80 bytes, depending on the selected representation's
# media type and the chosen boundary parameter length, it can be less efficient to transfer many small
# disjoint parts than it is to transfer the entire selected representation.
range-coalescing-threshold = 80
# The maximum number of allowed ranges per request.
# Requests with more ranges will be rejected due to DOS suspicion.
range-count-limit = 16
} }

View file

@ -27,7 +27,7 @@ trait Directives extends RouteConcatenation
with MiscDirectives with MiscDirectives
with ParameterDirectives with ParameterDirectives
with PathDirectives with PathDirectives
//with RangeDirectives with RangeDirectives
with RespondWithDirectives with RespondWithDirectives
with RouteDirectives with RouteDirectives
with SchemeDirectives with SchemeDirectives

View file

@ -8,11 +8,16 @@ import com.typesafe.config.Config
import akka.actor.ActorRefFactory import akka.actor.ActorRefFactory
import akka.http.util._ import akka.http.util._
case class RoutingSettings(verboseErrorMessages: Boolean) case class RoutingSettings(
verboseErrorMessages: Boolean,
rangeCountLimit: Int,
rangeCoalescingThreshold: Long)
object RoutingSettings extends SettingsCompanion[RoutingSettings]("akka.http.routing") { object RoutingSettings extends SettingsCompanion[RoutingSettings]("akka.http.routing") {
def fromSubConfig(c: Config) = apply( def fromSubConfig(c: Config) = apply(
c getBoolean "verbose-error-messages") c getBoolean "verbose-error-messages",
c getInt "range-count-limit",
c getBytes "range-coalescing-threshold")
implicit def default(implicit refFactory: ActorRefFactory) = implicit def default(implicit refFactory: ActorRefFactory) =
apply(actorSystem) apply(actorSystem)

View file

@ -0,0 +1,126 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import akka.http.model.StatusCodes._
import akka.http.model._
import akka.http.model.headers._
import akka.http.server.RouteResult.Complete
import akka.http.util._
import akka.stream.scaladsl.Source
import scala.collection.immutable
trait RangeDirectives {
import akka.http.server.directives.BasicDirectives._
import akka.http.server.directives.RouteDirectives._
/**
* Answers GET requests with an `Accept-Ranges: bytes` header and converts HttpResponses coming back from its inner
* route into partial responses if the initial request contained a valid `Range` request header. The requested
* byte-ranges may be coalesced.
* This directive is transparent to non-GET requests
* Rejects requests with unsatisfiable ranges `UnsatisfiableRangeRejection`.
* Rejects requests with too many expected ranges.
*
* Note: if you want to combine this directive with `conditional(...)` you need to put
* it on the *inside* of the `conditional(...)` directive, i.e. `conditional(...)` must be
* on a higher level in your route structure in order to function correctly.
*
* @see https://tools.ietf.org/html/rfc7233
*/
def withRangeSupport: Directive0 =
extractRequestContext.flatMap { ctx
val settings = ctx.settings
implicit val log = ctx.log
import settings.{ rangeCountLimit, rangeCoalescingThreshold }
class IndexRange(val start: Long, val end: Long) {
def length = end - start
def apply(entity: UniversalEntity): UniversalEntity = entity.transformDataBytes(length, () StreamUtils.sliceBytesTransformer(start, length))
def distance(other: IndexRange) = mergedEnd(other) - mergedStart(other) - (length + other.length)
def mergeWith(other: IndexRange) = new IndexRange(mergedStart(other), mergedEnd(other))
def contentRange(entityLength: Long) = ContentRange(start, end - 1, entityLength)
private def mergedStart(other: IndexRange) = math.min(start, other.start)
private def mergedEnd(other: IndexRange) = math.max(end, other.end)
}
def indexRange(entityLength: Long)(range: ByteRange): IndexRange =
range match {
case ByteRange.Slice(start, end) new IndexRange(start, math.min(end + 1, entityLength))
case ByteRange.FromOffset(first) new IndexRange(first, entityLength)
case ByteRange.Suffix(suffixLength) new IndexRange(math.max(0, entityLength - suffixLength), entityLength)
}
// See comment of the `range-coalescing-threshold` setting in `reference.conf` for the rationale of this behavior.
def coalesceRanges(iRanges: Seq[IndexRange]): Seq[IndexRange] =
iRanges.foldLeft(Seq.empty[IndexRange]) { (acc, iRange)
val (mergeCandidates, otherCandidates) = acc.partition(_.distance(iRange) <= rangeCoalescingThreshold)
val merged = mergeCandidates.foldLeft(iRange)(_ mergeWith _)
otherCandidates :+ merged
}
def multipartRanges(ranges: Seq[ByteRange], entity: UniversalEntity): Multipart.ByteRanges = {
val length = entity.contentLength
val iRanges: Seq[IndexRange] = ranges.map(indexRange(length))
val bodyParts = coalesceRanges(iRanges).map(ir Multipart.ByteRanges.BodyPart(ir.contentRange(length), ir(entity)))
Multipart.ByteRanges(Source(bodyParts.toVector))
}
def rangeResponse(range: ByteRange, entity: UniversalEntity, length: Long, headers: immutable.Seq[HttpHeader]) = {
val aiRange = indexRange(length)(range)
HttpResponse(PartialContent, `Content-Range`(aiRange.contentRange(length)) +: headers, aiRange(entity))
}
def satisfiable(entityLength: Long)(range: ByteRange): Boolean =
range match {
case ByteRange.Slice(firstPos, _) firstPos < entityLength
case ByteRange.FromOffset(firstPos) firstPos < entityLength
case ByteRange.Suffix(length) length > 0
}
def universal(entity: HttpEntity): Option[UniversalEntity] = entity match {
case u: UniversalEntity Some(u)
case _ None
}
def applyRanges(ranges: Seq[ByteRange]): Directive0 =
extractRequestContext.flatMap { ctx
import ctx.executionContext
mapRouteResultWithPF {
case Complete(HttpResponse(OK, headers, entity, protocol))
universal(entity) match {
case Some(entity)
val length = entity.contentLength
ranges.filter(satisfiable(length)) match {
case Nil ctx.reject(UnsatisfiableRangeRejection(ranges, length))
case Seq(satisfiableRange) ctx.complete(rangeResponse(satisfiableRange, entity, length, headers))
case satisfiableRanges
ctx.complete(PartialContent, headers, multipartRanges(satisfiableRanges, entity))
}
case None
// Ranges not supported for Chunked or CloseDelimited responses
ctx.reject(UnsatisfiableRangeRejection(ranges, -1)) // FIXME: provide better error
}
}
}
def rangeHeaderOfGetRequests(ctx: RequestContext): Option[Range] =
if (ctx.request.method == HttpMethods.GET) ctx.request.header[Range] else None
extract(rangeHeaderOfGetRequests).flatMap {
case Some(Range(RangeUnits.Bytes, ranges))
if (ranges.size <= rangeCountLimit) applyRanges(ranges) & RangeDirectives.respondWithAcceptByteRangesHeader
else reject(TooManyRangesRejection(rangeCountLimit))
case _ MethodDirectives.get & RangeDirectives.respondWithAcceptByteRangesHeader | pass
}
}
}
object RangeDirectives extends RangeDirectives {
private val respondWithAcceptByteRangesHeader: Directive0 =
RespondWithDirectives.respondWithHeader(`Accept-Ranges`(RangeUnits.Bytes))
}