Merge pull request #15985 from bthuillier/feature/directives-migration

Spray Directives Migration
This commit is contained in:
Björn Antonsson 2014-10-09 13:46:50 +02:00
commit 7cc9e9902f
10 changed files with 468 additions and 6 deletions

View file

@ -0,0 +1,100 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server.directives
import akka.http.model._
import headers._
import akka.http.server._
import org.scalatest.Inside
class HeaderDirectivesSpec extends RoutingSpec with Inside {
"The headerValuePF directive" should {
lazy val myHeaderValue = headerValuePF { case Connection(tokens) tokens.head }
"extract the respective header value if a matching request header is present" in {
Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check {
responseAs[String] shouldEqual "close"
}
}
"reject with an empty rejection set if no matching request header is present" in {
Get("/abc") ~> myHeaderValue { echoComplete } ~> check { rejections shouldEqual Nil }
}
"reject with a MalformedHeaderRejection if the extract function throws an exception" in {
Get("/abc") ~> addHeader(Connection("close")) ~> {
(headerValuePF { case _ sys.error("Naah!") }) { echoComplete }
} ~> check {
inside(rejection) { case MalformedHeaderRejection("Connection", "Naah!", _) }
}
}
}
"The headerValueByType directive" should {
lazy val route =
headerValueByType[Origin]() { origin
complete(s"The first origin was ${origin.origins.head}")
}
"extract a header if the type is matching" in {
val originHeader = Origin(HttpOrigin("http://localhost:8080"))
Get("abc") ~> originHeader ~> route ~> check {
responseAs[String] shouldEqual "The first origin was http://localhost:8080"
}
}
"reject a request if no header of the given type is present" in {
Get("abc") ~> route ~> check {
inside(rejection) {
case MissingHeaderRejection("Origin")
}
}
}
}
"The optionalHeaderValue directive" should {
lazy val myHeaderValue = optionalHeaderValue {
case Connection(tokens) Some(tokens.head)
case _ None
}
"extract the respective header value if a matching request header is present" in {
Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check {
responseAs[String] shouldEqual "Some(close)"
}
}
"extract None if no matching request header is present" in {
Get("/abc") ~> myHeaderValue { echoComplete } ~> check { responseAs[String] shouldEqual "None" }
}
"reject with a MalformedHeaderRejection if the extract function throws an exception" in {
Get("/abc") ~> addHeader(Connection("close")) ~> {
val myHeaderValue = optionalHeaderValue { case _ sys.error("Naaah!") }
myHeaderValue { echoComplete }
} ~> check {
inside(rejection) { case MalformedHeaderRejection("Connection", "Naaah!", _) }
}
}
}
"The optionalHeaderValueByType directive" should {
val route =
optionalHeaderValueByType[Origin]() {
case Some(origin) complete(s"The first origin was ${origin.origins.head}")
case None complete("No Origin header found.")
}
"extract Some(header) if the type is matching" in {
val originHeader = Origin(HttpOrigin("http://localhost:8080"))
Get("abc") ~> originHeader ~> route ~> check {
responseAs[String] shouldEqual "The first origin was http://localhost:8080"
}
}
"extract None if no header of the given type is present" in {
Get("abc") ~> route ~> check {
responseAs[String] shouldEqual "No Origin header found."
}
}
}
}

View file

@ -0,0 +1,55 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import akka.http.model.headers.Host
import org.scalatest.FreeSpec
class HostDirectivesSpec extends FreeSpec with GenericRoutingSpec {
"The 'host' directive" - {
"in its simple String form should" - {
"block requests to unmatched hosts" in {
Get() ~> Host("spray.io") ~> {
host("spray.com") { completeOk }
} ~> check { handled shouldEqual false }
}
"let requests to matching hosts pass" in {
Get() ~> Host("spray.io") ~> {
host("spray.com", "spray.io") { completeOk }
} ~> check { response shouldEqual Ok }
}
}
"in its simple RegEx form" - {
"block requests to unmatched hosts" in {
Get() ~> Host("spray.io") ~> {
host("hairspray.*".r) { echoComplete }
} ~> check { handled shouldEqual false }
}
"let requests to matching hosts pass and extract the full host" in {
Get() ~> Host("spray.io") ~> {
host("spra.*".r) { echoComplete }
} ~> check { responseAs[String] shouldEqual "spray.io" }
}
}
"in its group RegEx form" - {
"block requests to unmatched hosts" in {
Get() ~> Host("spray.io") ~> {
host("hairspray(.*)".r) { echoComplete }
} ~> check { handled shouldEqual false }
}
"let requests to matching hosts pass and extract the full host" in {
Get() ~> Host("spray.io") ~> {
host("spra(.*)".r) { echoComplete }
} ~> check { responseAs[String] shouldEqual "y.io" }
}
}
}
}

View file

@ -0,0 +1,33 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import akka.http.model._
import headers._
import HttpMethods._
import MediaTypes._
import Uri._
class MiscDirectivesSpec extends RoutingSpec {
"the clientIP directive" should {
"extract from a X-Forwarded-For header" in {
Get() ~> addHeaders(`X-Forwarded-For`("2.3.4.5"), RawHeader("x-real-ip", "1.2.3.4")) ~> {
clientIP { echoComplete }
} ~> check { responseAs[String] shouldEqual "2.3.4.5" }
}
"extract from a Remote-Address header" in {
Get() ~> addHeaders(RawHeader("x-real-ip", "1.2.3.4"), `Remote-Address`(RemoteAddress("5.6.7.8"))) ~> {
clientIP { echoComplete }
} ~> check { responseAs[String] shouldEqual "5.6.7.8" }
}
"extract from a X-Real-IP header" in {
Get() ~> addHeader(RawHeader("x-real-ip", "1.2.3.4")) ~> {
clientIP { echoComplete }
} ~> check { responseAs[String] shouldEqual "1.2.3.4" }
}
}
}

View file

@ -240,4 +240,4 @@ class PathDirectivesSpec extends RoutingSpec {
case None failTest("Example '" + exampleString + "' doesn't contain a test uri")
}
}
}
}

View file

@ -141,4 +141,4 @@ object Directive {
def filter(predicate: T Boolean, rejections: Rejection*): Directive1[T] =
underlying.tfilter({ case Tuple1(value) predicate(value) }, rejections: _*)
}
}
}

View file

@ -19,11 +19,11 @@ trait Directives extends RouteConcatenation
//with FileAndResourceDirectives
//with FormFieldDirectives
//with FutureDirectives
//with HeaderDirectives
//with HostDirectives
with HeaderDirectives
with HostDirectives
//with MarshallingDirectives
with MethodDirectives
//with MiscDirectives
with MiscDirectives
//with ParameterDirectives
with PathDirectives
//with RangeDirectives
@ -32,4 +32,4 @@ trait Directives extends RouteConcatenation
//with SchemeDirectives
//with SecurityDirectives
object Directives extends Directives
object Directives extends Directives

View file

@ -0,0 +1,29 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server.directives
import scala.reflect.ClassTag
/** A magnet that wraps a ClassTag */
trait ClassMagnet[T] {
def classTag: ClassTag[T]
def runtimeClass: Class[T]
/**
* Returns a partial function that checks if the input value is of runtime type
* T and returns the value if it does. Doesn't take erased information into account.
*/
def extractPF: PartialFunction[Any, T]
}
object ClassMagnet {
implicit def apply[T](u: Unit)(implicit tag: ClassTag[T]): ClassMagnet[T] =
new ClassMagnet[T] {
val classTag: ClassTag[T] = tag
val runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]]
val extractPF: PartialFunction[Any, T] = {
case x: T x
}
}
}

View file

@ -0,0 +1,107 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import scala.util.control.NonFatal
import akka.http.model._
import akka.http.util._
trait HeaderDirectives {
import BasicDirectives._
import RouteDirectives._
/**
* Extracts an HTTP header value using the given function. If the function result is undefined for all headers the
* request is rejected with an empty rejection set. If the given function throws an exception the request is rejected
* with a [[spray.routing.MalformedHeaderRejection]].
*/
def headerValue[T](f: HttpHeader Option[T]): Directive1[T] = {
val protectedF: HttpHeader Option[Either[Rejection, T]] = header
try f(header).map(Right.apply)
catch {
case NonFatal(e) Some(Left(MalformedHeaderRejection(header.name, e.getMessage.nullAsEmpty, Some(e))))
}
extract(_.request.headers.collectFirst(Function.unlift(protectedF))).flatMap {
case Some(Right(a)) provide(a)
case Some(Left(rejection)) reject(rejection)
case None reject
}
}
/**
* Extracts an HTTP header value using the given partial function. If the function is undefined for all headers the
* request is rejected with an empty rejection set.
*/
def headerValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[T] = headerValue(pf.lift)
/**
* Extracts the value of the HTTP request header with the given name.
* If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
*/
def headerValueByName(headerName: Symbol): Directive1[String] = headerValueByName(headerName.toString)
/**
* Extracts the value of the HTTP request header with the given name.
* If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
*/
def headerValueByName(headerName: String): Directive1[String] =
headerValue(optionalValue(headerName.toLowerCase)) | reject(MissingHeaderRejection(headerName))
/**
* Extracts the HTTP request header of the given type.
* If no header with a matching type is found the request is rejected with a [[spray.routing.MissingHeaderRejection]].
*/
def headerValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[T] =
headerValuePF(magnet.extractPF) | reject(MissingHeaderRejection(magnet.runtimeClass.getSimpleName))
/**
* Extracts an optional HTTP header value using the given function.
* If the given function throws an exception the request is rejected
* with a [[spray.routing.MalformedHeaderRejection]].
*/
def optionalHeaderValue[T](f: HttpHeader Option[T]): Directive1[Option[T]] =
headerValue(f).map(Some(_): Option[T]).recoverPF {
case Nil provide(None)
}
/**
* Extracts an optional HTTP header value using the given partial function.
* If the given function throws an exception the request is rejected
* with a [[spray.routing.MalformedHeaderRejection]].
*/
def optionalHeaderValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[Option[T]] =
optionalHeaderValue(pf.lift)
/**
* Extracts the value of the optional HTTP request header with the given name.
*/
def optionalHeaderValueByName(headerName: Symbol): Directive1[Option[String]] =
optionalHeaderValueByName(headerName.toString)
/**
* Extracts the value of the optional HTTP request header with the given name.
*/
def optionalHeaderValueByName(headerName: String): Directive1[Option[String]] = {
val lowerCaseName = headerName.toLowerCase
extract(_.request.headers.collectFirst {
case HttpHeader(`lowerCaseName`, value) value
})
}
/**
* Extract the header value of the optional HTTP request header with the given type.
*/
def optionalHeaderValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[Option[T]] =
optionalHeaderValuePF(magnet.extractPF)
private def optionalValue(lowerCaseName: String): HttpHeader Option[String] = {
case HttpHeader(`lowerCaseName`, value) Some(value)
case _ None
}
}
object HeaderDirectives extends HeaderDirectives

View file

@ -0,0 +1,61 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import scala.util.matching.Regex
import akka.http.util._
trait HostDirectives {
import BasicDirectives._
import RouteDirectives._
/**
* Extracts the hostname part of the Host header value in the request.
*/
def hostName: Directive1[String] = HostDirectives._hostName
/**
* Rejects all requests with a host name different from the given ones.
*/
def host(hostNames: String*): Directive0 = host(hostNames.contains(_))
/**
* Rejects all requests for whose host name the given predicate function returns false.
*/
def host(predicate: String Boolean): Directive0 = hostName.require(predicate)
/**
* Rejects all requests with a host name that doesn't have a prefix matching the given regular expression.
* For all matching requests the prefix string matching the regex is extracted and passed to the inner route.
* If the regex contains a capturing group only the string matched by this group is extracted.
* If the regex contains more than one capturing group an IllegalArgumentException is thrown.
*/
def host(regex: Regex): Directive1[String] = {
def forFunc(regexMatch: String Option[String]): Directive1[String] = {
hostName.flatMap { name
regexMatch(name) match {
case Some(matched) provide(matched)
case None reject
}
}
}
regex.groupCount match {
case 0 forFunc(regex.findPrefixOf(_))
case 1 forFunc(regex.findPrefixMatchOf(_).map(_.group(1)))
case _ throw new IllegalArgumentException("Path regex '" + regex.pattern.pattern +
"' must not contain more than one capturing group")
}
}
}
object HostDirectives extends HostDirectives {
import BasicDirectives._
private val _hostName: Directive1[String] =
extract(_.request.uri.authority.host.address)
}

View file

@ -0,0 +1,77 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import scala.reflect.{ classTag, ClassTag }
import akka.http.model._
import akka.parboiled2.CharPredicate
import headers._
import MediaTypes._
import RouteResult._
trait MiscDirectives {
import BasicDirectives._
import RouteDirectives._
/**
* Returns a Directive which checks the given condition before passing on the [[spray.routing.RequestContext]] to
* its inner Route. If the condition fails the route is rejected with a [[spray.routing.ValidationRejection]].
*/
def validate(check: Boolean, errorMsg: String): Directive0 =
new Directive0 {
def tapply(f: Unit Route) = if (check) f() else reject(ValidationRejection(errorMsg))
}
/**
* Directive extracting the IP of the client from either the X-Forwarded-For, Remote-Address or X-Real-IP header
* (in that order of priority).
*/
def clientIP: Directive1[RemoteAddress] = MiscDirectives._clientIP
/**
* Rejects the request if its entity is not empty.
*/
def requestEntityEmpty: Directive0 = MiscDirectives._requestEntityEmpty
/**
* Rejects empty requests with a RequestEntityExpectedRejection.
* Non-empty requests are passed on unchanged to the inner route.
*/
def requestEntityPresent: Directive0 = MiscDirectives._requestEntityPresent
/**
* Converts responses with an empty entity into (empty) rejections.
* This way you can, for example, have the marshalling of a ''None'' option be treated as if the request could
* not be matched.
*/
def rejectEmptyResponse: Directive0 = MiscDirectives._rejectEmptyResponse
}
object MiscDirectives extends MiscDirectives {
import BasicDirectives._
import HeaderDirectives._
import RouteDirectives._
import CharPredicate._
private val validJsonpChars = AlphaNum ++ '.' ++ '_' ++ '$'
private val _clientIP: Directive1[RemoteAddress] =
headerValuePF { case `X-Forwarded-For`(Seq(address, _*)) address } |
headerValuePF { case `Remote-Address`(address) address } |
headerValuePF { case h if h.is("x-real-ip") RemoteAddress(h.value) }
private val _requestEntityEmpty: Directive0 =
extract(_.request.entity.isKnownEmpty).flatMap(if (_) pass else reject)
private val _requestEntityPresent: Directive0 =
extract(_.request.entity.isKnownEmpty).flatMap(if (_) reject else pass)
private val _rejectEmptyResponse: Directive0 =
mapRouteResponse {
case Complete(response) if response.entity.isKnownEmpty rejected(Nil)
case x x
}
}