Merge pull request #15985 from bthuillier/feature/directives-migration
Spray Directives Migration
This commit is contained in:
commit
7cc9e9902f
10 changed files with 468 additions and 6 deletions
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server.directives
|
||||
|
||||
import akka.http.model._
|
||||
import headers._
|
||||
import akka.http.server._
|
||||
import org.scalatest.Inside
|
||||
|
||||
class HeaderDirectivesSpec extends RoutingSpec with Inside {
|
||||
|
||||
"The headerValuePF directive" should {
|
||||
lazy val myHeaderValue = headerValuePF { case Connection(tokens) ⇒ tokens.head }
|
||||
|
||||
"extract the respective header value if a matching request header is present" in {
|
||||
Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check {
|
||||
responseAs[String] shouldEqual "close"
|
||||
}
|
||||
}
|
||||
|
||||
"reject with an empty rejection set if no matching request header is present" in {
|
||||
Get("/abc") ~> myHeaderValue { echoComplete } ~> check { rejections shouldEqual Nil }
|
||||
}
|
||||
|
||||
"reject with a MalformedHeaderRejection if the extract function throws an exception" in {
|
||||
Get("/abc") ~> addHeader(Connection("close")) ~> {
|
||||
(headerValuePF { case _ ⇒ sys.error("Naah!") }) { echoComplete }
|
||||
} ~> check {
|
||||
inside(rejection) { case MalformedHeaderRejection("Connection", "Naah!", _) ⇒ }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"The headerValueByType directive" should {
|
||||
lazy val route =
|
||||
headerValueByType[Origin]() { origin ⇒
|
||||
complete(s"The first origin was ${origin.origins.head}")
|
||||
}
|
||||
"extract a header if the type is matching" in {
|
||||
val originHeader = Origin(HttpOrigin("http://localhost:8080"))
|
||||
Get("abc") ~> originHeader ~> route ~> check {
|
||||
responseAs[String] shouldEqual "The first origin was http://localhost:8080"
|
||||
}
|
||||
}
|
||||
"reject a request if no header of the given type is present" in {
|
||||
Get("abc") ~> route ~> check {
|
||||
inside(rejection) {
|
||||
case MissingHeaderRejection("Origin") ⇒
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"The optionalHeaderValue directive" should {
|
||||
lazy val myHeaderValue = optionalHeaderValue {
|
||||
case Connection(tokens) ⇒ Some(tokens.head)
|
||||
case _ ⇒ None
|
||||
}
|
||||
|
||||
"extract the respective header value if a matching request header is present" in {
|
||||
Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check {
|
||||
responseAs[String] shouldEqual "Some(close)"
|
||||
}
|
||||
}
|
||||
|
||||
"extract None if no matching request header is present" in {
|
||||
Get("/abc") ~> myHeaderValue { echoComplete } ~> check { responseAs[String] shouldEqual "None" }
|
||||
}
|
||||
|
||||
"reject with a MalformedHeaderRejection if the extract function throws an exception" in {
|
||||
Get("/abc") ~> addHeader(Connection("close")) ~> {
|
||||
val myHeaderValue = optionalHeaderValue { case _ ⇒ sys.error("Naaah!") }
|
||||
myHeaderValue { echoComplete }
|
||||
} ~> check {
|
||||
inside(rejection) { case MalformedHeaderRejection("Connection", "Naaah!", _) ⇒ }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"The optionalHeaderValueByType directive" should {
|
||||
val route =
|
||||
optionalHeaderValueByType[Origin]() {
|
||||
case Some(origin) ⇒ complete(s"The first origin was ${origin.origins.head}")
|
||||
case None ⇒ complete("No Origin header found.")
|
||||
}
|
||||
"extract Some(header) if the type is matching" in {
|
||||
val originHeader = Origin(HttpOrigin("http://localhost:8080"))
|
||||
Get("abc") ~> originHeader ~> route ~> check {
|
||||
responseAs[String] shouldEqual "The first origin was http://localhost:8080"
|
||||
}
|
||||
}
|
||||
"extract None if no header of the given type is present" in {
|
||||
Get("abc") ~> route ~> check {
|
||||
responseAs[String] shouldEqual "No Origin header found."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server
|
||||
package directives
|
||||
|
||||
import akka.http.model.headers.Host
|
||||
import org.scalatest.FreeSpec
|
||||
|
||||
class HostDirectivesSpec extends FreeSpec with GenericRoutingSpec {
|
||||
"The 'host' directive" - {
|
||||
"in its simple String form should" - {
|
||||
"block requests to unmatched hosts" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("spray.com") { completeOk }
|
||||
} ~> check { handled shouldEqual false }
|
||||
}
|
||||
|
||||
"let requests to matching hosts pass" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("spray.com", "spray.io") { completeOk }
|
||||
} ~> check { response shouldEqual Ok }
|
||||
}
|
||||
}
|
||||
|
||||
"in its simple RegEx form" - {
|
||||
"block requests to unmatched hosts" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("hairspray.*".r) { echoComplete }
|
||||
} ~> check { handled shouldEqual false }
|
||||
}
|
||||
|
||||
"let requests to matching hosts pass and extract the full host" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("spra.*".r) { echoComplete }
|
||||
} ~> check { responseAs[String] shouldEqual "spray.io" }
|
||||
}
|
||||
}
|
||||
|
||||
"in its group RegEx form" - {
|
||||
"block requests to unmatched hosts" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("hairspray(.*)".r) { echoComplete }
|
||||
} ~> check { handled shouldEqual false }
|
||||
}
|
||||
|
||||
"let requests to matching hosts pass and extract the full host" in {
|
||||
Get() ~> Host("spray.io") ~> {
|
||||
host("spra(.*)".r) { echoComplete }
|
||||
} ~> check { responseAs[String] shouldEqual "y.io" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server
|
||||
package directives
|
||||
|
||||
import akka.http.model._
|
||||
import headers._
|
||||
import HttpMethods._
|
||||
import MediaTypes._
|
||||
import Uri._
|
||||
|
||||
class MiscDirectivesSpec extends RoutingSpec {
|
||||
|
||||
"the clientIP directive" should {
|
||||
"extract from a X-Forwarded-For header" in {
|
||||
Get() ~> addHeaders(`X-Forwarded-For`("2.3.4.5"), RawHeader("x-real-ip", "1.2.3.4")) ~> {
|
||||
clientIP { echoComplete }
|
||||
} ~> check { responseAs[String] shouldEqual "2.3.4.5" }
|
||||
}
|
||||
"extract from a Remote-Address header" in {
|
||||
Get() ~> addHeaders(RawHeader("x-real-ip", "1.2.3.4"), `Remote-Address`(RemoteAddress("5.6.7.8"))) ~> {
|
||||
clientIP { echoComplete }
|
||||
} ~> check { responseAs[String] shouldEqual "5.6.7.8" }
|
||||
}
|
||||
"extract from a X-Real-IP header" in {
|
||||
Get() ~> addHeader(RawHeader("x-real-ip", "1.2.3.4")) ~> {
|
||||
clientIP { echoComplete }
|
||||
} ~> check { responseAs[String] shouldEqual "1.2.3.4" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -240,4 +240,4 @@ class PathDirectivesSpec extends RoutingSpec {
|
|||
case None ⇒ failTest("Example '" + exampleString + "' doesn't contain a test uri")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,4 +141,4 @@ object Directive {
|
|||
def filter(predicate: T ⇒ Boolean, rejections: Rejection*): Directive1[T] =
|
||||
underlying.tfilter({ case Tuple1(value) ⇒ predicate(value) }, rejections: _*)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ trait Directives extends RouteConcatenation
|
|||
//with FileAndResourceDirectives
|
||||
//with FormFieldDirectives
|
||||
//with FutureDirectives
|
||||
//with HeaderDirectives
|
||||
//with HostDirectives
|
||||
with HeaderDirectives
|
||||
with HostDirectives
|
||||
//with MarshallingDirectives
|
||||
with MethodDirectives
|
||||
//with MiscDirectives
|
||||
with MiscDirectives
|
||||
//with ParameterDirectives
|
||||
with PathDirectives
|
||||
//with RangeDirectives
|
||||
|
|
@ -32,4 +32,4 @@ trait Directives extends RouteConcatenation
|
|||
//with SchemeDirectives
|
||||
//with SecurityDirectives
|
||||
|
||||
object Directives extends Directives
|
||||
object Directives extends Directives
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server.directives
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
/** A magnet that wraps a ClassTag */
|
||||
trait ClassMagnet[T] {
|
||||
def classTag: ClassTag[T]
|
||||
def runtimeClass: Class[T]
|
||||
|
||||
/**
|
||||
* Returns a partial function that checks if the input value is of runtime type
|
||||
* T and returns the value if it does. Doesn't take erased information into account.
|
||||
*/
|
||||
def extractPF: PartialFunction[Any, T]
|
||||
}
|
||||
object ClassMagnet {
|
||||
implicit def apply[T](u: Unit)(implicit tag: ClassTag[T]): ClassMagnet[T] =
|
||||
new ClassMagnet[T] {
|
||||
val classTag: ClassTag[T] = tag
|
||||
val runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]]
|
||||
val extractPF: PartialFunction[Any, T] = {
|
||||
case x: T ⇒ x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server
|
||||
package directives
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
import akka.http.model._
|
||||
import akka.http.util._
|
||||
|
||||
trait HeaderDirectives {
|
||||
import BasicDirectives._
|
||||
import RouteDirectives._
|
||||
|
||||
/**
|
||||
* Extracts an HTTP header value using the given function. If the function result is undefined for all headers the
|
||||
* request is rejected with an empty rejection set. If the given function throws an exception the request is rejected
|
||||
* with a [[spray.routing.MalformedHeaderRejection]].
|
||||
*/
|
||||
def headerValue[T](f: HttpHeader ⇒ Option[T]): Directive1[T] = {
|
||||
val protectedF: HttpHeader ⇒ Option[Either[Rejection, T]] = header ⇒
|
||||
try f(header).map(Right.apply)
|
||||
catch {
|
||||
case NonFatal(e) ⇒ Some(Left(MalformedHeaderRejection(header.name, e.getMessage.nullAsEmpty, Some(e))))
|
||||
}
|
||||
|
||||
extract(_.request.headers.collectFirst(Function.unlift(protectedF))).flatMap {
|
||||
case Some(Right(a)) ⇒ provide(a)
|
||||
case Some(Left(rejection)) ⇒ reject(rejection)
|
||||
case None ⇒ reject
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts an HTTP header value using the given partial function. If the function is undefined for all headers the
|
||||
* request is rejected with an empty rejection set.
|
||||
*/
|
||||
def headerValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[T] = headerValue(pf.lift)
|
||||
|
||||
/**
|
||||
* Extracts the value of the HTTP request header with the given name.
|
||||
* If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
|
||||
*/
|
||||
def headerValueByName(headerName: Symbol): Directive1[String] = headerValueByName(headerName.toString)
|
||||
|
||||
/**
|
||||
* Extracts the value of the HTTP request header with the given name.
|
||||
* If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
|
||||
*/
|
||||
def headerValueByName(headerName: String): Directive1[String] =
|
||||
headerValue(optionalValue(headerName.toLowerCase)) | reject(MissingHeaderRejection(headerName))
|
||||
|
||||
/**
|
||||
* Extracts the HTTP request header of the given type.
|
||||
* If no header with a matching type is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
|
||||
*/
|
||||
def headerValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[T] =
|
||||
headerValuePF(magnet.extractPF) | reject(MissingHeaderRejection(magnet.runtimeClass.getSimpleName))
|
||||
|
||||
/**
|
||||
* Extracts an optional HTTP header value using the given function.
|
||||
* If the given function throws an exception the request is rejected
|
||||
* with a [[spray.routing.MalformedHeaderRejection]].
|
||||
*/
|
||||
def optionalHeaderValue[T](f: HttpHeader ⇒ Option[T]): Directive1[Option[T]] =
|
||||
headerValue(f).map(Some(_): Option[T]).recoverPF {
|
||||
case Nil ⇒ provide(None)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts an optional HTTP header value using the given partial function.
|
||||
* If the given function throws an exception the request is rejected
|
||||
* with a [[spray.routing.MalformedHeaderRejection]].
|
||||
*/
|
||||
def optionalHeaderValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[Option[T]] =
|
||||
optionalHeaderValue(pf.lift)
|
||||
|
||||
/**
|
||||
* Extracts the value of the optional HTTP request header with the given name.
|
||||
*/
|
||||
def optionalHeaderValueByName(headerName: Symbol): Directive1[Option[String]] =
|
||||
optionalHeaderValueByName(headerName.toString)
|
||||
|
||||
/**
|
||||
* Extracts the value of the optional HTTP request header with the given name.
|
||||
*/
|
||||
def optionalHeaderValueByName(headerName: String): Directive1[Option[String]] = {
|
||||
val lowerCaseName = headerName.toLowerCase
|
||||
extract(_.request.headers.collectFirst {
|
||||
case HttpHeader(`lowerCaseName`, value) ⇒ value
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the header value of the optional HTTP request header with the given type.
|
||||
*/
|
||||
def optionalHeaderValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[Option[T]] =
|
||||
optionalHeaderValuePF(magnet.extractPF)
|
||||
|
||||
private def optionalValue(lowerCaseName: String): HttpHeader ⇒ Option[String] = {
|
||||
case HttpHeader(`lowerCaseName`, value) ⇒ Some(value)
|
||||
case _ ⇒ None
|
||||
}
|
||||
}
|
||||
|
||||
object HeaderDirectives extends HeaderDirectives
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server
|
||||
package directives
|
||||
|
||||
import scala.util.matching.Regex
|
||||
import akka.http.util._
|
||||
|
||||
trait HostDirectives {
|
||||
import BasicDirectives._
|
||||
import RouteDirectives._
|
||||
|
||||
/**
|
||||
* Extracts the hostname part of the Host header value in the request.
|
||||
*/
|
||||
def hostName: Directive1[String] = HostDirectives._hostName
|
||||
|
||||
/**
|
||||
* Rejects all requests with a host name different from the given ones.
|
||||
*/
|
||||
def host(hostNames: String*): Directive0 = host(hostNames.contains(_))
|
||||
|
||||
/**
|
||||
* Rejects all requests for whose host name the given predicate function returns false.
|
||||
*/
|
||||
def host(predicate: String ⇒ Boolean): Directive0 = hostName.require(predicate)
|
||||
|
||||
/**
|
||||
* Rejects all requests with a host name that doesn't have a prefix matching the given regular expression.
|
||||
* For all matching requests the prefix string matching the regex is extracted and passed to the inner route.
|
||||
* If the regex contains a capturing group only the string matched by this group is extracted.
|
||||
* If the regex contains more than one capturing group an IllegalArgumentException is thrown.
|
||||
*/
|
||||
def host(regex: Regex): Directive1[String] = {
|
||||
def forFunc(regexMatch: String ⇒ Option[String]): Directive1[String] = {
|
||||
hostName.flatMap { name ⇒
|
||||
regexMatch(name) match {
|
||||
case Some(matched) ⇒ provide(matched)
|
||||
case None ⇒ reject
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
regex.groupCount match {
|
||||
case 0 ⇒ forFunc(regex.findPrefixOf(_))
|
||||
case 1 ⇒ forFunc(regex.findPrefixMatchOf(_).map(_.group(1)))
|
||||
case _ ⇒ throw new IllegalArgumentException("Path regex '" + regex.pattern.pattern +
|
||||
"' must not contain more than one capturing group")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object HostDirectives extends HostDirectives {
|
||||
import BasicDirectives._
|
||||
|
||||
private val _hostName: Directive1[String] =
|
||||
extract(_.request.uri.authority.host.address)
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.server
|
||||
package directives
|
||||
|
||||
import scala.reflect.{ classTag, ClassTag }
|
||||
import akka.http.model._
|
||||
import akka.parboiled2.CharPredicate
|
||||
import headers._
|
||||
import MediaTypes._
|
||||
import RouteResult._
|
||||
|
||||
trait MiscDirectives {
|
||||
import BasicDirectives._
|
||||
import RouteDirectives._
|
||||
|
||||
/**
|
||||
* Returns a Directive which checks the given condition before passing on the [[spray.routing.RequestContext]] to
|
||||
* its inner Route. If the condition fails the route is rejected with a [[spray.routing.ValidationRejection]].
|
||||
*/
|
||||
def validate(check: ⇒ Boolean, errorMsg: String): Directive0 =
|
||||
new Directive0 {
|
||||
def tapply(f: Unit ⇒ Route) = if (check) f() else reject(ValidationRejection(errorMsg))
|
||||
}
|
||||
|
||||
/**
|
||||
* Directive extracting the IP of the client from either the X-Forwarded-For, Remote-Address or X-Real-IP header
|
||||
* (in that order of priority).
|
||||
*/
|
||||
def clientIP: Directive1[RemoteAddress] = MiscDirectives._clientIP
|
||||
|
||||
/**
|
||||
* Rejects the request if its entity is not empty.
|
||||
*/
|
||||
def requestEntityEmpty: Directive0 = MiscDirectives._requestEntityEmpty
|
||||
|
||||
/**
|
||||
* Rejects empty requests with a RequestEntityExpectedRejection.
|
||||
* Non-empty requests are passed on unchanged to the inner route.
|
||||
*/
|
||||
def requestEntityPresent: Directive0 = MiscDirectives._requestEntityPresent
|
||||
|
||||
/**
|
||||
* Converts responses with an empty entity into (empty) rejections.
|
||||
* This way you can, for example, have the marshalling of a ''None'' option be treated as if the request could
|
||||
* not be matched.
|
||||
*/
|
||||
def rejectEmptyResponse: Directive0 = MiscDirectives._rejectEmptyResponse
|
||||
}
|
||||
|
||||
object MiscDirectives extends MiscDirectives {
|
||||
import BasicDirectives._
|
||||
import HeaderDirectives._
|
||||
import RouteDirectives._
|
||||
import CharPredicate._
|
||||
|
||||
private val validJsonpChars = AlphaNum ++ '.' ++ '_' ++ '$'
|
||||
|
||||
private val _clientIP: Directive1[RemoteAddress] =
|
||||
headerValuePF { case `X-Forwarded-For`(Seq(address, _*)) ⇒ address } |
|
||||
headerValuePF { case `Remote-Address`(address) ⇒ address } |
|
||||
headerValuePF { case h if h.is("x-real-ip") ⇒ RemoteAddress(h.value) }
|
||||
|
||||
private val _requestEntityEmpty: Directive0 =
|
||||
extract(_.request.entity.isKnownEmpty).flatMap(if (_) pass else reject)
|
||||
|
||||
private val _requestEntityPresent: Directive0 =
|
||||
extract(_.request.entity.isKnownEmpty).flatMap(if (_) reject else pass)
|
||||
|
||||
private val _rejectEmptyResponse: Directive0 =
|
||||
mapRouteResponse {
|
||||
case Complete(response) if response.entity.isKnownEmpty ⇒ rejected(Nil)
|
||||
case x ⇒ x
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue