!htc #16494 provide content negotiation fallback to exception handlers

This commit is contained in:
2beaucoup 2014-12-10 12:35:50 +01:00 committed by André Rüdiger
parent a1d7c33663
commit ec7156698a
7 changed files with 45 additions and 6 deletions

View file

@ -13,6 +13,8 @@ import akka.http.model.japi.MediaRange;
public abstract class Accept extends akka.http.model.HttpHeader { public abstract class Accept extends akka.http.model.HttpHeader {
public abstract Iterable<MediaRange> getMediaRanges(); public abstract Iterable<MediaRange> getMediaRanges();
public abstract boolean acceptsAll();
public static Accept create(MediaRange... mediaRanges) { public static Accept create(MediaRange... mediaRanges) {
return new akka.http.model.headers.Accept(akka.http.model.japi.Util.<MediaRange, akka.http.model.MediaRange>convertArray(mediaRanges)); return new akka.http.model.headers.Accept(akka.http.model.japi.Util.<MediaRange, akka.http.model.MediaRange>convertArray(mediaRanges));
} }

View file

@ -22,6 +22,7 @@ sealed abstract class MediaRange extends japi.MediaRange with Renderable with Wi
def isMultipart = false def isMultipart = false
def isText = false def isText = false
def isVideo = false def isVideo = false
def isWildcard = mainType == "*"
/** /**
* Returns a copy of this instance with the params replaced by the given ones. * Returns a copy of this instance with the params replaced by the given ones.
@ -123,6 +124,7 @@ object MediaRanges extends ObjectRegistry[String, MediaRange] {
def matches(mediaType: MediaType) = true def matches(mediaType: MediaType) = true
def specimen = MediaTypes.`text/plain` def specimen = MediaTypes.`text/plain`
} }
val `*/*(minQ)` = `*/*`.withQValue(Float.MinPositiveValue)
val `application/*` = new PredefinedMediaRange("application/*") { val `application/*` = new PredefinedMediaRange("application/*") {
def matches(mediaType: MediaType) = mediaType.isApplication def matches(mediaType: MediaType) = mediaType.isApplication
override def isApplication = true override def isApplication = true

View file

@ -129,6 +129,7 @@ final case class Accept(mediaRanges: immutable.Seq[MediaRange]) extends japi.hea
import Accept.mediaRangesRenderer import Accept.mediaRangesRenderer
def renderValue[R <: Rendering](r: R): r.type = r ~~ mediaRanges def renderValue[R <: Rendering](r: R): r.type = r ~~ mediaRanges
protected def companion = Accept protected def companion = Accept
def acceptsAll = mediaRanges.exists(mr mr.isWildcard && mr.qValue > 0f)
/** Java API */ /** Java API */
def getMediaRanges: Iterable[japi.MediaRange] = mediaRanges.asJava def getMediaRanges: Iterable[japi.MediaRange] = mediaRanges.asJava

View file

@ -5,7 +5,8 @@
package akka.http.server package akka.http.server
package directives package directives
import akka.http.model.StatusCodes import akka.http.model.{ MediaTypes, MediaRanges, StatusCodes }
import akka.http.model.headers._
import scala.concurrent.Future import scala.concurrent.Future
@ -72,6 +73,27 @@ class ExecutionDirectivesSpec extends RoutingSpec {
responseAs[String] shouldEqual "There was an internal server error." responseAs[String] shouldEqual "There was an internal server error."
} }
} }
"always fall back to a default content type" in {
Get("/abc") ~> Accept(MediaTypes.`application/json`) ~>
get {
handleExceptions(handler) {
throw new RuntimeException
}
} ~> check {
status shouldEqual StatusCodes.InternalServerError
responseAs[String] shouldEqual "There was an internal server error."
}
Get("/abc") ~> Accept(MediaTypes.`text/xml`, MediaRanges.`*/*`.withQValue(0f)) ~>
get {
handleExceptions(handler) {
throw new RuntimeException
}
} ~> check {
status shouldEqual StatusCodes.InternalServerError
responseAs[String] shouldEqual "There was an internal server error."
}
}
} }
def exceptionShouldBeHandled(route: Route) = def exceptionShouldBeHandled(route: Route) =

View file

@ -112,5 +112,5 @@ trait RequestContext {
/** /**
* Removes a potentially existing Accept header from the request headers. * Removes a potentially existing Accept header from the request headers.
*/ */
def withContentNegotiationDisabled: RequestContext def withAcceptAll: RequestContext
} }

View file

@ -69,8 +69,20 @@ private[http] class RequestContextImpl(
override def mapUnmatchedPath(f: Uri.Path Uri.Path): RequestContext = override def mapUnmatchedPath(f: Uri.Path Uri.Path): RequestContext =
copy(unmatchedPath = f(unmatchedPath)) copy(unmatchedPath = f(unmatchedPath))
override def withContentNegotiationDisabled: RequestContext = override def withAcceptAll: RequestContext = request.header[headers.Accept] match {
copy(request = request.withHeaders(request.headers filterNot (_.isInstanceOf[headers.Accept]))) case Some(accept @ headers.Accept(mediaRanges)) if !accept.acceptsAll
mapRequest(_.mapHeaders(_.map {
case `accept`
val acceptAll =
if (mediaRanges.exists(_.isWildcard))
mediaRanges.map(mr if (mr.isWildcard) mr.withQValue(Float.MinPositiveValue) else mr)
else
mediaRanges :+ MediaRanges.`*/*(minQ)`
accept.copy(mediaRanges = acceptAll)
case x x
}))
case _ this
}
private def copy(request: HttpRequest = request, private def copy(request: HttpRequest = request,
unmatchedPath: Uri.Path = unmatchedPath, unmatchedPath: Uri.Path = unmatchedPath,

View file

@ -23,7 +23,7 @@ trait ExecutionDirectives {
ctx ctx
import ctx.executionContext import ctx.executionContext
def handleException: PartialFunction[Throwable, Future[RouteResult]] = def handleException: PartialFunction[Throwable, Future[RouteResult]] =
handler andThen (_(ctx.withContentNegotiationDisabled)) handler andThen (_(ctx.withAcceptAll))
try innerRouteBuilder(())(ctx).fast.recoverWith(handleException) try innerRouteBuilder(())(ctx).fast.recoverWith(handleException)
catch { catch {
case NonFatal(e) handleException.applyOrElse[Throwable, Future[RouteResult]](e, throw _) case NonFatal(e) handleException.applyOrElse[Throwable, Future[RouteResult]](e, throw _)
@ -42,7 +42,7 @@ trait ExecutionDirectives {
val errorMsg = "The RejectionHandler for %s must not itself produce rejections (received %s)!" val errorMsg = "The RejectionHandler for %s must not itself produce rejections (received %s)!"
recoverRejections(r sys.error(errorMsg.format(filteredRejections, r))) { recoverRejections(r sys.error(errorMsg.format(filteredRejections, r))) {
handler(filteredRejections) handler(filteredRejections)
}(ctx.withContentNegotiationDisabled) }(ctx.withAcceptAll)
} else FastFuture.successful(RouteResult.Rejected(filteredRejections)) } else FastFuture.successful(RouteResult.Rejected(filteredRejections))
} }
} }