!htp #16835 Refactor RejectionHandler infrastructure for cleanliness and independence from rejection ordering

This commit is contained in:
Mathias 2015-02-26 15:08:06 +01:00
parent a4f4cb298a
commit 9c3124f344
7 changed files with 220 additions and 119 deletions

View file

@ -18,10 +18,12 @@ object MyRejectionHandler {
import StatusCodes._
import Directives._
implicit val myRejectionHandler = RejectionHandler {
case MissingCookieRejection(cookieName) :: _ =>
complete(HttpResponse(BadRequest, entity = "No cookies, no service!!!"))
}
implicit val myRejectionHandler = RejectionHandler.newBuilder()
.handle {
case MissingCookieRejection(cookieName) =>
complete(HttpResponse(BadRequest, entity = "No cookies, no service!!!"))
}
.result()
object MyApp {
implicit val system = ActorSystem()

View file

@ -29,10 +29,9 @@ class ExecutionDirectivesExamplesSpec extends RoutingSpec {
}
}
"handleRejections" in {
val totallyMissingHandler = RejectionHandler {
case Nil /* secret code for path not found */ =>
complete(StatusCodes.NotFound, "Oh man, what you are looking for is long gone.")
}
val totallyMissingHandler = RejectionHandler.newBuilder()
.handleNotFound { complete(StatusCodes.NotFound, "Oh man, what you are looking for is long gone.") }
.result()
val route =
pathPrefix("handled") {
handleRejections(totallyMissingHandler) {

View file

@ -163,5 +163,16 @@ class BasicRouteSpecs extends RoutingSpec {
status shouldEqual StatusCodes.InternalServerError
}
}
"always prioritize MethodRejections over AuthorizationFailedRejections" in {
Get("/abc") ~> Route.seal {
post { completeOk } ~
authorize(false) { completeOk }
} ~> check { status shouldEqual StatusCodes.MethodNotAllowed }
Get("/abc") ~> Route.seal {
authorize(false) { completeOk } ~
post { completeOk }
} ~> check { status shouldEqual StatusCodes.MethodNotAllowed }
}
}
}

View file

@ -4,6 +4,8 @@
package akka.http.server
import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.collection.immutable
import scala.concurrent.ExecutionContext
import akka.http.model._
@ -12,114 +14,203 @@ import headers._
import directives.RouteDirectives._
import AuthenticationFailedRejection._
trait RejectionHandler extends RejectionHandler.PF {
def isDefault: Boolean
trait RejectionHandler extends (immutable.Seq[Rejection] Option[Route]) { self
import RejectionHandler._
/**
* Creates a new [[RejectionHandler]] which uses the given one as fallback for this one.
*/
def withFallback(that: RejectionHandler): RejectionHandler =
(this, that) match {
case (a: BuiltRejectionHandler, b: BuiltRejectionHandler)
new BuiltRejectionHandler(a.cases ++ b.cases, a.notFound orElse b.notFound, a.isSealed || b.isSealed)
case _ new RejectionHandler {
def apply(rejections: immutable.Seq[Rejection]): Option[Route] =
self(rejections) orElse that(rejections)
}
}
/**
* "Seals" this handler by attaching a default handler as fallback if necessary.
*/
def seal(implicit ec: ExecutionContext): RejectionHandler =
this match {
case x: BuiltRejectionHandler if x.isSealed x
case _ withFallback(default)
}
}
object RejectionHandler {
type PF = PartialFunction[immutable.Seq[Rejection], Route]
implicit def apply(pf: PF): RejectionHandler = apply(default = false)(pf)
/**
* Creates a new [[RejectionHandler]] builder.
*/
def newBuilder(): Builder = new Builder
private def apply(default: Boolean)(pf: PF): RejectionHandler =
new RejectionHandler {
def isDefault = default
def isDefinedAt(rejections: immutable.Seq[Rejection]) = pf.isDefinedAt(rejections)
def apply(rejections: immutable.Seq[Rejection]) = pf(rejections)
final class Builder {
private[this] val cases = new immutable.VectorBuilder[Handler]
private[this] var notFound: Option[Route] = None
private[this] var hasCatchAll: Boolean = false
/**
* Handles a single [[Rejection]] with the given partial function.
*/
def handle(pf: PartialFunction[Rejection, Route]): this.type = {
cases += CaseHandler(pf)
hasCatchAll ||= pf.isDefinedAt(PrivateRejection)
this
}
def default(implicit ec: ExecutionContext) = apply(default = true) {
case Nil complete(NotFound, "The requested resource could not be found.")
/**
* Handles several Rejections of the same type at the same time.
* The seq passed to the given function is guaranteed to be non-empty.
*/
def handleAll[T <: Rejection: ClassTag](f: immutable.Seq[T] Route): this.type = {
val runtimeClass = implicitly[ClassTag[T]].runtimeClass
cases += TypeHandler[T](runtimeClass, f)
hasCatchAll ||= runtimeClass == classOf[Rejection]
this
}
case rejections @ (AuthenticationFailedRejection(cause, _) +: _)
val rejectionMessage = cause match {
case CredentialsMissing "The resource requires authentication, which was not supplied with the request"
case CredentialsRejected "The supplied authentication is invalid"
}
val challenges = rejections.collect { case AuthenticationFailedRejection(_, challenge) challenge }
// Multiple challenges per WWW-Authenticate header are allowed per spec,
// however, it seems many browsers will ignore all challenges but the first.
// Therefore, multiple WWW-Authenticate headers are rendered, instead.
//
// See https://code.google.com/p/chromium/issues/detail?id=103220
// and https://bugzilla.mozilla.org/show_bug.cgi?id=669675
val authenticateHeaders = challenges.map(`WWW-Authenticate`(_))
/**
* Handles the special "not found" case using the given [[Route]].
*/
def handleNotFound(route: Route): this.type = {
notFound = Some(route)
this
}
complete(Unauthorized, authenticateHeaders, rejectionMessage)
case AuthorizationFailedRejection +: _
complete(Forbidden, "The supplied authentication is not authorized to access this resource")
case MalformedFormFieldRejection(name, msg, _) +: _
complete(BadRequest, "The form field '" + name + "' was malformed:\n" + msg)
case MalformedHeaderRejection(headerName, msg, _) +: _
complete(BadRequest, s"The value of HTTP header '$headerName' was malformed:\n" + msg)
case MalformedQueryParamRejection(name, msg, _) +: _
complete(BadRequest, "The query parameter '" + name + "' was malformed:\n" + msg)
case MalformedRequestContentRejection(msg, _) +: _
complete(BadRequest, "The request content was malformed:\n" + msg)
case rejections @ (MethodRejection(_) +: _)
val methods = rejections.collect { case MethodRejection(method) method }
complete(MethodNotAllowed, List(Allow(methods)), "HTTP method not allowed, supported methods: " + methods.mkString(", "))
case rejections @ (SchemeRejection(_) +: _)
val schemes = rejections.collect { case SchemeRejection(scheme) scheme }
complete(BadRequest, "Uri scheme not allowed, supported schemes: " + schemes.mkString(", "))
case MissingCookieRejection(cookieName) +: _
complete(BadRequest, "Request is missing required cookie '" + cookieName + '\'')
case MissingFormFieldRejection(fieldName) +: _
complete(BadRequest, "Request is missing required form field '" + fieldName + '\'')
case MissingHeaderRejection(headerName) +: _
complete(BadRequest, "Request is missing required HTTP header '" + headerName + '\'')
case MissingQueryParamRejection(paramName) +: _
complete(NotFound, "Request is missing required query parameter '" + paramName + '\'')
case RequestEntityExpectedRejection +: _
complete(BadRequest, "Request entity expected but not supplied")
case TooManyRangesRejection(_) +: _
complete(RequestedRangeNotSatisfiable, "Request contains too many ranges.")
case UnsatisfiableRangeRejection(unsatisfiableRanges, actualEntityLength) +: _
complete(RequestedRangeNotSatisfiable, List(`Content-Range`(ContentRange.Unsatisfiable(actualEntityLength))),
unsatisfiableRanges.mkString("None of the following requested Ranges were satisfiable:\n", "\n", ""))
case rejections @ (UnacceptedResponseContentTypeRejection(_) +: _)
val supported = rejections.flatMap {
case UnacceptedResponseContentTypeRejection(x) x
case _ Nil
}
complete(NotAcceptable, "Resource representation is only available with these Content-Types:\n" + supported.map(_.value).mkString("\n"))
case rejections @ (UnacceptedResponseEncodingRejection(_) +: _)
val supported = rejections.flatMap {
case UnacceptedResponseEncodingRejection(x) x
case _ Nil
}
complete(NotAcceptable, "Resource representation is only available with these Content-Encodings:\n" + supported.map(_.value).mkString("\n"))
case rejections @ (UnsupportedRequestContentTypeRejection(_) +: _)
val supported = rejections.collect { case UnsupportedRequestContentTypeRejection(x) x }
complete(UnsupportedMediaType, "The request's Content-Type is not supported. Expected:\n" + supported.mkString(" or "))
case rejections @ (UnsupportedRequestEncodingRejection(_) +: _)
val supported = rejections.collect { case UnsupportedRequestEncodingRejection(x) x }
complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported.map(_.value).mkString(" or "))
case ValidationRejection(msg, _) +: _
complete(BadRequest, msg)
case x +: _ sys.error("Unhandled rejection: " + x)
def result(): RejectionHandler =
new BuiltRejectionHandler(cases.result(), notFound, hasCatchAll && notFound.isDefined)
}
private sealed abstract class Handler
private final case class CaseHandler(pf: PartialFunction[Rejection, Route]) extends Handler
private final case class TypeHandler[T <: Rejection](
runtimeClass: Class[_], f: immutable.Seq[T] Route) extends Handler with PartialFunction[Rejection, T] {
def isDefinedAt(rejection: Rejection) = runtimeClass isInstance rejection
def apply(rejection: Rejection) = rejection.asInstanceOf[T]
}
private class BuiltRejectionHandler(val cases: Vector[Handler],
val notFound: Option[Route],
val isSealed: Boolean) extends RejectionHandler {
def apply(rejections: immutable.Seq[Rejection]): Option[Route] =
if (rejections.nonEmpty) {
@tailrec def rec(ix: Int): Option[Route] =
if (ix < cases.length) {
cases(ix) match {
case CaseHandler(pf)
val route = rejections collectFirst pf
if (route.isEmpty) rec(ix + 1) else route
case x @ TypeHandler(_, f)
val rejs = rejections collect x
if (rejs.isEmpty) rec(ix + 1) else Some(f(rejs))
}
} else None
rec(0)
} else notFound
}
/**
* Creates a new default [[RejectionHandler]] instance.
*/
def default(implicit ec: ExecutionContext) =
newBuilder()
.handleAll[SchemeRejection] { rejections
val schemes = rejections.map(_.supported).mkString(", ")
complete(BadRequest, "Uri scheme not allowed, supported schemes: " + schemes)
}
.handleAll[MethodRejection] { rejections
val methods = rejections.map(_.supported)
complete(MethodNotAllowed, List(Allow(methods)), "HTTP method not allowed, supported methods: " + methods.mkString(", "))
}
.handle {
case AuthorizationFailedRejection
complete(Forbidden, "The supplied authentication is not authorized to access this resource")
}
.handle {
case MalformedFormFieldRejection(name, msg, _)
complete(BadRequest, "The form field '" + name + "' was malformed:\n" + msg)
}
.handle {
case MalformedHeaderRejection(headerName, msg, _)
complete(BadRequest, s"The value of HTTP header '$headerName' was malformed:\n" + msg)
}
.handle {
case MalformedQueryParamRejection(name, msg, _)
complete(BadRequest, "The query parameter '" + name + "' was malformed:\n" + msg)
}
.handle {
case MalformedRequestContentRejection(msg, _)
complete(BadRequest, "The request content was malformed:\n" + msg)
}
.handle {
case MissingCookieRejection(cookieName)
complete(BadRequest, "Request is missing required cookie '" + cookieName + '\'')
}
.handle {
case MissingFormFieldRejection(fieldName)
complete(BadRequest, "Request is missing required form field '" + fieldName + '\'')
}
.handle {
case MissingHeaderRejection(headerName)
complete(BadRequest, "Request is missing required HTTP header '" + headerName + '\'')
}
.handle {
case MissingQueryParamRejection(paramName)
complete(NotFound, "Request is missing required query parameter '" + paramName + '\'')
}
.handle {
case RequestEntityExpectedRejection
complete(BadRequest, "Request entity expected but not supplied")
}
.handle {
case TooManyRangesRejection(_)
complete(RequestedRangeNotSatisfiable, "Request contains too many ranges.")
}
.handle {
case UnsatisfiableRangeRejection(unsatisfiableRanges, actualEntityLength)
complete(RequestedRangeNotSatisfiable, List(`Content-Range`(ContentRange.Unsatisfiable(actualEntityLength))),
unsatisfiableRanges.mkString("None of the following requested Ranges were satisfiable:\n", "\n", ""))
}
.handleAll[AuthenticationFailedRejection] { rejections
val rejectionMessage = rejections.head.cause match {
case CredentialsMissing "The resource requires authentication, which was not supplied with the request"
case CredentialsRejected "The supplied authentication is invalid"
}
// Multiple challenges per WWW-Authenticate header are allowed per spec,
// however, it seems many browsers will ignore all challenges but the first.
// Therefore, multiple WWW-Authenticate headers are rendered, instead.
//
// See https://code.google.com/p/chromium/issues/detail?id=103220
// and https://bugzilla.mozilla.org/show_bug.cgi?id=669675
val authenticateHeaders = rejections.map(r `WWW-Authenticate`(r.challenge))
complete(Unauthorized, authenticateHeaders, rejectionMessage)
}
.handleAll[UnacceptedResponseContentTypeRejection] { rejections
val supported = rejections.flatMap(_.supported)
complete(NotAcceptable, "Resource representation is only available with these Content-Types:\n" +
supported.map(_.value).mkString("\n"))
}
.handleAll[UnacceptedResponseEncodingRejection] { rejections
val supported = rejections.flatMap(_.supported)
complete(NotAcceptable, "Resource representation is only available with these Content-Encodings:\n" +
supported.map(_.value).mkString("\n"))
}
.handleAll[UnsupportedRequestContentTypeRejection] { rejections
val supported = rejections.flatMap(_.supported).mkString(" or ")
complete(UnsupportedMediaType, "The request's Content-Type is not supported. Expected:\n" + supported)
}
.handleAll[UnsupportedRequestEncodingRejection] { rejections
val supported = rejections.map(_.supported.value).mkString(" or ")
complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported)
}
.handle { case ValidationRejection(msg, _) complete(BadRequest, msg) }
.handle { case x sys.error("Unhandled rejection: " + x) }
.handleNotFound { complete(NotFound, "The requested resource could not be found.") }
.result()
/**
* Filters out all TransformationRejections from the given sequence and applies them (in order) to the
* remaining rejections.
@ -130,4 +221,6 @@ object RejectionHandler {
case (remaining, transformation) transformation.transform(remaining)
}
}
private object PrivateRejection extends Rejection
}

View file

@ -40,7 +40,7 @@ private[http] class RequestContextImpl(
}(executionContext)
override def reject(rejections: Rejection*): Future[RouteResult] =
FastFuture.successful(RouteResult.Rejected(rejections.toVector))
FastFuture.successful(RouteResult.Rejected(rejections.toList))
override def fail(error: Throwable): Future[RouteResult] =
FastFuture.failed(error)

View file

@ -25,11 +25,8 @@ object Route {
val sealedExceptionHandler =
if (exceptionHandler.isDefault) exceptionHandler
else exceptionHandler orElse ExceptionHandler.default(settings)
val sealedRejectionHandler =
if (rejectionHandler.isDefault) rejectionHandler
else rejectionHandler orElse RejectionHandler.default
handleExceptions(sealedExceptionHandler) {
handleRejections(sealedRejectionHandler) {
handleRejections(rejectionHandler.seal) {
route
}
}

View file

@ -5,11 +5,10 @@
package akka.http.server
package directives
import akka.http.util.FastFuture
import FastFuture._
import scala.concurrent.Future
import scala.util.control.NonFatal
import akka.http.util.FastFuture
import FastFuture._
trait ExecutionDirectives {
import BasicDirectives._
@ -38,12 +37,12 @@ trait ExecutionDirectives {
extractRequestContext flatMap { ctx
recoverRejectionsWith { rejections
val filteredRejections = RejectionHandler.applyTransformations(rejections)
if (handler isDefinedAt filteredRejections) {
val errorMsg = "The RejectionHandler for %s must not itself produce rejections (received %s)!"
recoverRejections(r sys.error(errorMsg.format(filteredRejections, r))) {
handler(filteredRejections)
}(ctx.withAcceptAll)
} else FastFuture.successful(RouteResult.Rejected(filteredRejections))
handler(filteredRejections) match {
case Some(route)
val errorMsg = "The RejectionHandler for %s must not itself produce rejections (received %s)!"
recoverRejections(r sys.error(errorMsg.format(filteredRejections, r)))(route)(ctx.withAcceptAll)
case None FastFuture.successful(RouteResult.Rejected(filteredRejections))
}
}
}
}