diff --git a/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala b/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala index 1636cc818b..8ba2d84642 100644 --- a/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala +++ b/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala @@ -63,7 +63,6 @@ trait RouteTestResultComponent { result = rr match { case RouteResult.Complete(response) ⇒ Some(Right(response)) case RouteResult.Rejected(rejections) ⇒ Some(Left(RejectionHandler.applyTransformations(rejections))) - case RouteResult.Failure(error) ⇒ sys.error("Route produced exception: " + error) } latch.countDown() } else failTest("Route completed/rejected more than once") diff --git a/akka-http-tests/src/test/scala/akka/http/server/BasicRouteSpecs.scala b/akka-http-tests/src/test/scala/akka/http/server/BasicRouteSpecs.scala index c016941012..309e52d149 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/BasicRouteSpecs.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/BasicRouteSpecs.scala @@ -4,7 +4,9 @@ package akka.http.server -import akka.http.model.HttpMethods._ +import akka.http.model +import model.HttpMethods._ +import model.StatusCodes import akka.http.server.PathMatchers.{ Segment, IntNumber } class BasicRouteSpecs extends RoutingSpec { @@ -74,6 +76,23 @@ class BasicRouteSpecs extends RoutingSpec { } ~> check { responseAs[String] shouldEqual "The cat 84 The cat" } } } + "Route disjunction" should { + "work" in { + val route = sealRoute((path("abc") | path("def")) { + completeOk + }) + + Get("/abc") ~> route ~> check { + status shouldEqual StatusCodes.OK + } + Get("/def") ~> route ~> check { + status shouldEqual StatusCodes.OK + } + Get("/ghi") ~> route ~> check { + status shouldEqual StatusCodes.NotFound + } + } + } "Case class extraction with Directive.as" should { "extract one argument" in { case class MyNumber(i: Int) diff --git a/akka-http/src/main/scala/akka/http/server/Directive.scala b/akka-http/src/main/scala/akka/http/server/Directive.scala index 4c8bc5087a..21c64b501d 100644 --- a/akka-http/src/main/scala/akka/http/server/Directive.scala +++ b/akka-http/src/main/scala/akka/http/server/Directive.scala @@ -4,6 +4,8 @@ package akka.http.server +import scala.concurrent.Future + import akka.http.server.directives.RouteDirectives import akka.http.server.util._ @@ -87,19 +89,22 @@ abstract class Directive[L](implicit val ev: Tuple[L]) { self ⇒ self.tapply { values ⇒ ctx ⇒ if (predicate(values)) f(values)(ctx) else ctx.reject(rejections: _*) } } + /** + * Creates a new directive that is able to recover from rejections that were produced by `this` Directive + * **before the inner route was applied**. + */ def recover[R >: L: Tuple](recovery: immutable.Seq[Rejection] ⇒ Directive[R]): Directive[R] = new Directive[R] { - def tapply(f: R ⇒ Route) = { ctx ⇒ - @volatile var rejectedFromInnerRoute = false - self.tapply({ list ⇒ c ⇒ rejectedFromInnerRoute = true; f(list)(c) }) { - ctx.withRejectionHandling { rejections ⇒ - if (rejectedFromInnerRoute) ctx.reject(rejections: _*) - else recovery(rejections).tapply(f)(ctx) - } - } + def tapply(inner: R ⇒ Route) = { ctx ⇒ + self.tapply(list ⇒ c ⇒ inner(list)(c))(ctx).recoverRejectionsWith { + rejections ⇒ recovery(rejections).tapply(inner)(ctx) + }(ctx.executionContext) } } + /** + * Variant of `recover` that only recovers from rejections handled by the given PartialFunction. + */ def recoverPF[R >: L: Tuple](recovery: PartialFunction[immutable.Seq[Rejection], Directive[R]]): Directive[R] = recover { rejections ⇒ if (recovery isDefinedAt rejections) recovery(rejections) @@ -143,4 +148,12 @@ object Directive { def filter(predicate: T ⇒ Boolean, rejections: Rejection*): Directive1[T] = underlying.tfilter({ case Tuple1(value) ⇒ predicate(value) }, rejections: _*) } + + /** + * Creates a Directive0 that maps the result produced by the inner route with the given function. + */ + def mapResult(f: (RequestContext, Future[RouteResult]) ⇒ Future[RouteResult]): Directive0 = + new Directive0 { + def tapply(inner: Unit ⇒ Route): Route = ctx ⇒ f(ctx, inner(())(ctx)) + } } diff --git a/akka-http/src/main/scala/akka/http/server/RequestContext.scala b/akka-http/src/main/scala/akka/http/server/RequestContext.scala index 6606fc2629..35b83ec445 100644 --- a/akka-http/src/main/scala/akka/http/server/RequestContext.scala +++ b/akka-http/src/main/scala/akka/http/server/RequestContext.scala @@ -4,7 +4,6 @@ package akka.http.server -import scala.collection.immutable import scala.concurrent.{ Future, ExecutionContext } import akka.event.LoggingAdapter import akka.http.marshalling.ToResponseMarshallable @@ -74,56 +73,6 @@ trait RequestContext { */ def withUnmatchedPathMapped(f: Uri.Path ⇒ Uri.Path): RequestContext - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withRouteResponseMapped(f: RouteResult ⇒ RouteResult): RequestContext - - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withRouteResponseMappedPF(f: PartialFunction[RouteResult, RouteResult]): RequestContext - - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withRouteResponseFlatMapped(f: RouteResult ⇒ Future[RouteResult]): RequestContext - - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withHttpResponseMapped(f: HttpResponse ⇒ HttpResponse): RequestContext - - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withHttpResponseEntityMapped(f: ResponseEntity ⇒ ResponseEntity): RequestContext - - /** - * Returns a copy of this context with the given response transformation function chained into the response chain. - */ - def withHttpResponseHeadersMapped(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): RequestContext - - /** - * Returns a copy of this context with the given rejection transformation function chained into the response chain. - */ - def withRejectionsMapped(f: immutable.Seq[Rejection] ⇒ immutable.Seq[Rejection]): RequestContext - - /** - * Returns a copy of this context with the given rejection handling function chained into the response chain. - */ - def withRejectionHandling(f: immutable.Seq[Rejection] ⇒ Future[RouteResult]): RequestContext - - /** - * Returns a copy of this context with the given exception handling function chained into the response chain. - */ - def withExceptionHandling(pf: PartialFunction[Throwable, Future[RouteResult]]): RequestContext - - /** - * Returns a copy of this context with the given function handling a part of the response space. - */ - def withRouteResponseHandling(pf: PartialFunction[RouteResult, Future[RouteResult]]): RequestContext - /** * Removes a potentially existing Accept header from the request headers. */ diff --git a/akka-http/src/main/scala/akka/http/server/RequestContextImpl.scala b/akka-http/src/main/scala/akka/http/server/RequestContextImpl.scala index 77daccf59c..8ab54782f5 100644 --- a/akka-http/src/main/scala/akka/http/server/RequestContextImpl.scala +++ b/akka-http/src/main/scala/akka/http/server/RequestContextImpl.scala @@ -4,11 +4,10 @@ package akka.http.server -import scala.collection.immutable import scala.concurrent.{ Future, ExecutionContext } import akka.event.LoggingAdapter import akka.http.marshalling.ToResponseMarshallable -import akka.http.util.{ FastFuture, identityFunc } +import akka.http.util.FastFuture import akka.http.model._ import FastFuture._ @@ -19,8 +18,7 @@ private[http] class RequestContextImpl( val request: HttpRequest, val unmatchedPath: Uri.Path, val executionContext: ExecutionContext, - val log: LoggingAdapter, - finish: RouteResult ⇒ Future[RouteResult] = FastFuture.successful) extends RequestContext { + val log: LoggingAdapter) extends RequestContext { def this(request: HttpRequest, log: LoggingAdapter)(implicit ec: ExecutionContext) = this(request, request.uri.path, ec, log) @@ -33,15 +31,13 @@ private[http] class RequestContextImpl( .fast.map(res ⇒ RouteResult.complete(res))(executionContext) .fast.recover { case RejectionError(rej) ⇒ RouteResult.rejected(rej :: Nil) - case error ⇒ RouteResult.failure(error) }(executionContext) - .fast.flatMap(finish)(executionContext) override def reject(rejections: Rejection*): Future[RouteResult] = - finish(RouteResult.rejected(rejections.toVector)) + FastFuture.successful(RouteResult.rejected(rejections.toVector)) override def fail(error: Throwable): Future[RouteResult] = - finish(RouteResult.failure(error)) + FastFuture.failed(error) override def withRequest(req: HttpRequest): RequestContext = copy(request = req) @@ -55,55 +51,12 @@ private[http] class RequestContextImpl( override def withUnmatchedPathMapped(f: Uri.Path ⇒ Uri.Path): RequestContext = copy(unmatchedPath = f(unmatchedPath)) - override def withRouteResponseMapped(f: RouteResult ⇒ RouteResult): RequestContext = - copy(finish = f andThen finish) - - override def withRouteResponseMappedPF(pf: PartialFunction[RouteResult, RouteResult]): RequestContext = - withRouteResponseMapped(pf.applyOrElse(_, identityFunc[RouteResult])) - - override def withRouteResponseFlatMapped(f: RouteResult ⇒ Future[RouteResult]): RequestContext = - copy(finish = rr ⇒ f(rr).fast.flatMap(finish)(executionContext)) - - override def withHttpResponseMapped(f: HttpResponse ⇒ HttpResponse): RequestContext = - withRouteResponseMappedPF { - case RouteResult.Complete(response) ⇒ RouteResult.complete(f(response)) - } - - override def withHttpResponseEntityMapped(f: ResponseEntity ⇒ ResponseEntity): RequestContext = - withHttpResponseMapped(_ mapEntity f) - - override def withHttpResponseHeadersMapped(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): RequestContext = - withHttpResponseMapped(_ mapHeaders f) - - override def withRejectionsMapped(f: immutable.Seq[Rejection] ⇒ immutable.Seq[Rejection]): RequestContext = - withRouteResponseMappedPF { - case RouteResult.Rejected(rejs) ⇒ RouteResult.rejected(f(rejs)) - } - - override def withRejectionHandling(f: immutable.Seq[Rejection] ⇒ Future[RouteResult]): RequestContext = - withRouteResponseHandling { - case RouteResult.Rejected(rejs) ⇒ - // `finish` is *not* chained in here, because the user already applied it when creating the result of f - f(rejs) - } - - override def withExceptionHandling(pf: PartialFunction[Throwable, Future[RouteResult]]): RequestContext = - withRouteResponseHandling { - case RouteResult.Failure(error) if pf isDefinedAt error ⇒ - // `finish` is *not* chained in here, because the user already applied it when creating the result of pf - pf(error) - } - - def withRouteResponseHandling(pf: PartialFunction[RouteResult, Future[RouteResult]]): RequestContext = - copy(finish = pf.applyOrElse(_, finish)) - override def withContentNegotiationDisabled: RequestContext = copy(request = request.withHeaders(request.headers filterNot (_.isInstanceOf[headers.Accept]))) private def copy(request: HttpRequest = request, unmatchedPath: Uri.Path = unmatchedPath, executionContext: ExecutionContext = executionContext, - log: LoggingAdapter = log, - finish: RouteResult ⇒ Future[RouteResult] = finish) = - new RequestContextImpl(request, unmatchedPath, executionContext, log, finish) + log: LoggingAdapter = log) = + new RequestContextImpl(request, unmatchedPath, executionContext, log) } diff --git a/akka-http/src/main/scala/akka/http/server/RouteConcatenation.scala b/akka-http/src/main/scala/akka/http/server/RouteConcatenation.scala index 9ed9305d2e..937140d425 100644 --- a/akka-http/src/main/scala/akka/http/server/RouteConcatenation.scala +++ b/akka-http/src/main/scala/akka/http/server/RouteConcatenation.scala @@ -9,18 +9,14 @@ trait RouteConcatenation { implicit def enhanceRouteWithConcatenation(route: Route) = new RouteConcatenation(route: Route) class RouteConcatenation(route: Route) { - /** * Returns a Route that chains two Routes. If the first Route rejects the request the second route is given a * chance to act upon the request. */ - def ~(other: Route): Route = { ctx ⇒ - route { - ctx.withRejectionHandling { rejections ⇒ - other(ctx.withRejectionsMapped(rejections ++ _)) - } - } - } + def ~(other: Route): Route = ctx ⇒ + route(ctx).recoverRejectionsWith(outerRejections ⇒ + other(ctx).recoverRejections(innerRejections ⇒ + RouteResult.rejected(outerRejections ++ innerRejections))(ctx.executionContext))(ctx.executionContext) } } diff --git a/akka-http/src/main/scala/akka/http/server/RouteResult.scala b/akka-http/src/main/scala/akka/http/server/RouteResult.scala index 6e23ecadc6..dba69d853f 100644 --- a/akka-http/src/main/scala/akka/http/server/RouteResult.scala +++ b/akka-http/src/main/scala/akka/http/server/RouteResult.scala @@ -18,9 +18,7 @@ sealed trait RouteResult object RouteResult { final case class Complete private[RouteResult] (response: HttpResponse) extends RouteResult - final case class Failure private[RouteResult] (exception: Throwable) extends RouteResult final case class Rejected private[RouteResult] (rejections: immutable.Seq[Rejection]) extends RouteResult private[http] def complete(response: HttpResponse) = Complete(response) - private[http] def failure(exception: Throwable) = Failure(exception) private[http] def rejected(rejections: immutable.Seq[Rejection]) = Rejected(rejections) } diff --git a/akka-http/src/main/scala/akka/http/server/ScalaRoutingDSL.scala b/akka-http/src/main/scala/akka/http/server/ScalaRoutingDSL.scala index 034fa602c2..f2dfacbc18 100644 --- a/akka-http/src/main/scala/akka/http/server/ScalaRoutingDSL.scala +++ b/akka-http/src/main/scala/akka/http/server/ScalaRoutingDSL.scala @@ -67,7 +67,6 @@ trait ScalaRoutingDSL extends Directives { sealedRoute(new RequestContextImpl(request, routingLog.requestLog(request))).fast.map { case RouteResult.Complete(response) ⇒ response case RouteResult.Rejected(rejected) ⇒ throw new IllegalStateException(s"Unhandled rejections '$rejected', unsealed RejectionHandler?!") - case RouteResult.Failure(error) ⇒ throw new IllegalStateException(s"Unhandled error '$error', unsealed ExceptionHandler?!") } } diff --git a/akka-http/src/main/scala/akka/http/server/directives/BasicDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/BasicDirectives.scala index aa15aff4dd..3019c6e64e 100644 --- a/akka-http/src/main/scala/akka/http/server/directives/BasicDirectives.scala +++ b/akka-http/src/main/scala/akka/http/server/directives/BasicDirectives.scala @@ -5,6 +5,9 @@ package akka.http.server package directives +import akka.http.util.FastFuture +import FastFuture._ + import scala.collection.immutable import akka.http.server.util.Tuple import akka.http.model._ @@ -22,22 +25,30 @@ trait BasicDirectives { mapRequestContext(_ withRequestMapped f) def mapRouteResponse(f: RouteResult ⇒ RouteResult): Directive0 = - mapRequestContext(_ withRouteResponseMapped f) + Directive.mapResult { (ctx, result) ⇒ + result.fast.map(f)(ctx.executionContext) + } def mapRouteResponsePF(f: PartialFunction[RouteResult, RouteResult]): Directive0 = - mapRequestContext(_ withRouteResponseMappedPF f) + mapRouteResponse { r ⇒ + if (f isDefinedAt r) f(r) else r + } def mapRejections(f: immutable.Seq[Rejection] ⇒ immutable.Seq[Rejection]): Directive0 = - mapRequestContext(_ withRejectionsMapped f) + Directive.mapResult { (ctx, result) ⇒ + result.recoverRejections(rejs ⇒ RouteResult.rejected(f(rejs)))(ctx.executionContext) + } def mapHttpResponse(f: HttpResponse ⇒ HttpResponse): Directive0 = - mapRequestContext(_ withHttpResponseMapped f) + Directive.mapResult { (ctx, result) ⇒ + result.mapResponse(r ⇒ RouteResult.complete(f(r)))(ctx.executionContext) + } def mapHttpResponseEntity(f: ResponseEntity ⇒ ResponseEntity): Directive0 = - mapRequestContext(_ withHttpResponseEntityMapped f) + mapHttpResponse(_.mapEntity(f)) def mapHttpResponseHeaders(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): Directive0 = - mapRequestContext(_ withHttpResponseHeadersMapped f) + mapHttpResponse(_.mapHeaders(f)) /** * A Directive0 that always passes the request on to its inner route diff --git a/akka-http/src/main/scala/akka/http/server/directives/ExecutionDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/ExecutionDirectives.scala index 0bd9dc16c6..49161dc427 100644 --- a/akka-http/src/main/scala/akka/http/server/directives/ExecutionDirectives.scala +++ b/akka-http/src/main/scala/akka/http/server/directives/ExecutionDirectives.scala @@ -13,15 +13,11 @@ trait ExecutionDirectives { * [[akka.http.server.ExceptionHandler]]. */ def handleExceptions(handler: ExceptionHandler): Directive0 = - mapInnerRoute { inner ⇒ - ctx ⇒ - def handleError = handler andThen (_(ctx.withContentNegotiationDisabled)) - try inner { - ctx withRouteResponseHandling { - case RouteResult.Failure(error) if handler isDefinedAt error ⇒ handleError(error) - } - } - catch handleError + Directive.mapResult { (ctx, result) ⇒ + def handleError = handler andThen (_(ctx.withContentNegotiationDisabled)) + result.recoverWith { + case error if handler isDefinedAt error ⇒ handleError(error) + }(ctx.executionContext) } /** @@ -29,17 +25,16 @@ trait ExecutionDirectives { * [[akka.http.server.RejectionHandler]]. */ def handleRejections(handler: RejectionHandler): Directive0 = - mapRequestContext { ctx ⇒ - ctx withRejectionHandling { rejections ⇒ - val filteredRejections = RejectionHandler.applyTransformations(rejections) - if (handler isDefinedAt filteredRejections) - handler(filteredRejections) { - ctx.withContentNegotiationDisabled withRejectionHandling { r ⇒ + Directive.mapResult { (ctx, result) ⇒ + result.recoverRejectionsWith { + case rejections ⇒ + val filteredRejections = RejectionHandler.applyTransformations(rejections) + if (handler isDefinedAt filteredRejections) + handler(filteredRejections)(ctx.withContentNegotiationDisabled).recoverRejections { r ⇒ sys.error(s"The RejectionHandler for $rejections must not itself produce rejections (received $r)!") - } - } - else ctx.reject(filteredRejections: _*) - } + }(ctx.executionContext) + else ctx.reject(filteredRejections: _*) + }(ctx.executionContext) } } diff --git a/akka-http/src/main/scala/akka/http/server/package.scala b/akka-http/src/main/scala/akka/http/server/package.scala index ab19fd1b0c..6aed6ff846 100644 --- a/akka-http/src/main/scala/akka/http/server/package.scala +++ b/akka-http/src/main/scala/akka/http/server/package.scala @@ -4,7 +4,14 @@ package akka.http -import scala.concurrent.Future +import scala.collection.immutable + +import scala.concurrent.{ ExecutionContext, Future } + +import akka.http.util.FastFuture +import FastFuture._ + +import akka.http.model.HttpResponse package object server { @@ -21,4 +28,24 @@ package object server { def Route(f: Route): Route = f def FIXME = throw new RuntimeException("Not yet implemented") + + private[http] implicit class EnhanceFutureRouteResult(val result: Future[RouteResult]) extends AnyVal { + def mapResponse(f: HttpResponse ⇒ RouteResult)(implicit ec: ExecutionContext): Future[RouteResult] = + mapResponseWith(response ⇒ FastFuture.successful(f(response))) + + def mapResponseWith(f: HttpResponse ⇒ Future[RouteResult])(implicit ec: ExecutionContext): Future[RouteResult] = + result.fast.flatMap { + case RouteResult.Complete(response) ⇒ f(response) + case r: RouteResult.Rejected ⇒ FastFuture.successful(r) + } + + def recoverRejections(f: immutable.Seq[Rejection] ⇒ RouteResult)(implicit ec: ExecutionContext): Future[RouteResult] = + recoverRejectionsWith(rej ⇒ FastFuture.successful(f(rej))) + + def recoverRejectionsWith(f: immutable.Seq[Rejection] ⇒ Future[RouteResult])(implicit ec: ExecutionContext): Future[RouteResult] = + result.fast.flatMap { + case c: RouteResult.Complete ⇒ FastFuture.successful(c) + case RouteResult.Rejected(rejections) ⇒ f(rejections) + } + } } \ No newline at end of file