From a2bcc0fb06568ebe9929750fa0e07235fc64dff8 Mon Sep 17 00:00:00 2001 From: Benjamin Thuillier Date: Fri, 26 Sep 2014 10:51:18 +0200 Subject: [PATCH 1/3] +htp #15926 Import HostDirectives from spray --- .../directives/HostDirectivesSpec.scala | 55 +++++++++++++++++ .../directives/PathDirectivesSpec.scala | 2 +- .../scala/akka/http/server/Directives.scala | 4 +- .../server/directives/HostDirectives.scala | 61 +++++++++++++++++++ 4 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 akka-http-tests/src/test/scala/akka/http/server/directives/HostDirectivesSpec.scala create mode 100644 akka-http/src/main/scala/akka/http/server/directives/HostDirectives.scala diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/HostDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/HostDirectivesSpec.scala new file mode 100644 index 0000000000..706a52df7f --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/HostDirectivesSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model.headers.Host +import org.scalatest.FreeSpec + +class HostDirectivesSpec extends FreeSpec with GenericRoutingSpec { + "The 'host' directive" - { + "in its simple String form should" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("spray.com") { completeOk } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass" in { + Get() ~> Host("spray.io") ~> { + host("spray.com", "spray.io") { completeOk } + } ~> check { response shouldEqual Ok } + } + } + + "in its simple RegEx form" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("hairspray.*".r) { echoComplete } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass and extract the full host" in { + Get() ~> Host("spray.io") ~> { + host("spra.*".r) { echoComplete } + } ~> check { responseAs[String] shouldEqual "spray.io" } + } + } + + "in its group RegEx form" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("hairspray(.*)".r) { echoComplete } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass and extract the full host" in { + Get() ~> Host("spray.io") ~> { + host("spra(.*)".r) { echoComplete } + } ~> check { responseAs[String] shouldEqual "y.io" } + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/PathDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/PathDirectivesSpec.scala index f472af4fd8..6676384edd 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/PathDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/PathDirectivesSpec.scala @@ -235,4 +235,4 @@ class PathDirectivesSpec extends RoutingSpec { case None ⇒ failTest("Example '" + exampleString + "' doesn't contain a test uri") } } -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index dbd4d0a3e2..1017e3fbc4 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -20,7 +20,7 @@ trait Directives extends RouteConcatenation //with FormFieldDirectives //with FutureDirectives //with HeaderDirectives - //with HostDirectives + with HostDirectives //with MarshallingDirectives with MethodDirectives //with MiscDirectives @@ -32,4 +32,4 @@ trait Directives extends RouteConcatenation //with SchemeDirectives //with SecurityDirectives -object Directives extends Directives \ No newline at end of file +object Directives extends Directives diff --git a/akka-http/src/main/scala/akka/http/server/directives/HostDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/HostDirectives.scala new file mode 100644 index 0000000000..26c274a7e6 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/HostDirectives.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import scala.util.matching.Regex +import akka.http.util._ + +trait HostDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Extracts the hostname part of the Host header value in the request. + */ + def hostName: Directive1[String] = HostDirectives._hostName + + /** + * Rejects all requests with a host name different from the given ones. + */ + def host(hostNames: String*): Directive0 = host(hostNames.contains(_)) + + /** + * Rejects all requests for whose host name the given predicate function returns false. + */ + def host(predicate: String ⇒ Boolean): Directive0 = hostName.require(predicate) + + /** + * Rejects all requests with a host name that doesn't have a prefix matching the given regular expression. + * For all matching requests the prefix string matching the regex is extracted and passed to the inner route. + * If the regex contains a capturing group only the string matched by this group is extracted. + * If the regex contains more than one capturing group an IllegalArgumentException is thrown. + */ + def host(regex: Regex): Directive1[String] = { + def forFunc(regexMatch: String ⇒ Option[String]): Directive1[String] = { + hostName.flatMap { name ⇒ + regexMatch(name) match { + case Some(matched) ⇒ provide(matched) + case None ⇒ reject + } + } + } + + regex.groupCount match { + case 0 ⇒ forFunc(regex.findPrefixOf(_)) + case 1 ⇒ forFunc(regex.findPrefixMatchOf(_).map(_.group(1))) + case _ ⇒ throw new IllegalArgumentException("Path regex '" + regex.pattern.pattern + + "' must not contain more than one capturing group") + } + } + +} + +object HostDirectives extends HostDirectives { + import BasicDirectives._ + + private val _hostName: Directive1[String] = + extract(_.request.uri.authority.host.address) +} From 6ad76226214bfa466071eff51f0cba3c4736a1c3 Mon Sep 17 00:00:00 2001 From: Benjamin Thuillier Date: Tue, 7 Oct 2014 14:05:12 +0200 Subject: [PATCH 2/3] +htp #15925 Import HeaderDirectives from spray --- .../directives/HeaderDirectivesSpec.scala | 100 ++++++++++++++++ .../scala/akka/http/server/Directives.scala | 2 +- .../http/server/directives/ClassMagnet.scala | 29 +++++ .../server/directives/HeaderDirectives.scala | 107 ++++++++++++++++++ 4 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 akka-http-tests/src/test/scala/akka/http/server/directives/HeaderDirectivesSpec.scala create mode 100644 akka-http/src/main/scala/akka/http/server/directives/ClassMagnet.scala create mode 100644 akka-http/src/main/scala/akka/http/server/directives/HeaderDirectives.scala diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/HeaderDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/HeaderDirectivesSpec.scala new file mode 100644 index 0000000000..70186a3b71 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/HeaderDirectivesSpec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server.directives + +import akka.http.model._ +import headers._ +import akka.http.server._ +import org.scalatest.Inside + +class HeaderDirectivesSpec extends RoutingSpec with Inside { + + "The headerValuePF directive" should { + lazy val myHeaderValue = headerValuePF { case Connection(tokens) ⇒ tokens.head } + + "extract the respective header value if a matching request header is present" in { + Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check { + responseAs[String] shouldEqual "close" + } + } + + "reject with an empty rejection set if no matching request header is present" in { + Get("/abc") ~> myHeaderValue { echoComplete } ~> check { rejections shouldEqual Nil } + } + + "reject with a MalformedHeaderRejection if the extract function throws an exception" in { + Get("/abc") ~> addHeader(Connection("close")) ~> { + (headerValuePF { case _ ⇒ sys.error("Naah!") }) { echoComplete } + } ~> check { + inside(rejection) { case MalformedHeaderRejection("Connection", "Naah!", _) ⇒ } + } + } + } + + "The headerValueByType directive" should { + lazy val route = + headerValueByType[Origin]() { origin ⇒ + complete(s"The first origin was ${origin.origins.head}") + } + "extract a header if the type is matching" in { + val originHeader = Origin(HttpOrigin("http://localhost:8080")) + Get("abc") ~> originHeader ~> route ~> check { + responseAs[String] shouldEqual "The first origin was http://localhost:8080" + } + } + "reject a request if no header of the given type is present" in { + Get("abc") ~> route ~> check { + inside(rejection) { + case MissingHeaderRejection("Origin") ⇒ + } + } + } + } + + "The optionalHeaderValue directive" should { + lazy val myHeaderValue = optionalHeaderValue { + case Connection(tokens) ⇒ Some(tokens.head) + case _ ⇒ None + } + + "extract the respective header value if a matching request header is present" in { + Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check { + responseAs[String] shouldEqual "Some(close)" + } + } + + "extract None if no matching request header is present" in { + Get("/abc") ~> myHeaderValue { echoComplete } ~> check { responseAs[String] shouldEqual "None" } + } + + "reject with a MalformedHeaderRejection if the extract function throws an exception" in { + Get("/abc") ~> addHeader(Connection("close")) ~> { + val myHeaderValue = optionalHeaderValue { case _ ⇒ sys.error("Naaah!") } + myHeaderValue { echoComplete } + } ~> check { + inside(rejection) { case MalformedHeaderRejection("Connection", "Naaah!", _) ⇒ } + } + } + } + + "The optionalHeaderValueByType directive" should { + val route = + optionalHeaderValueByType[Origin]() { + case Some(origin) ⇒ complete(s"The first origin was ${origin.origins.head}") + case None ⇒ complete("No Origin header found.") + } + "extract Some(header) if the type is matching" in { + val originHeader = Origin(HttpOrigin("http://localhost:8080")) + Get("abc") ~> originHeader ~> route ~> check { + responseAs[String] shouldEqual "The first origin was http://localhost:8080" + } + } + "extract None if no header of the given type is present" in { + Get("abc") ~> route ~> check { + responseAs[String] shouldEqual "No Origin header found." + } + } + } +} diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index 1017e3fbc4..fe2d6a4068 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -19,7 +19,7 @@ trait Directives extends RouteConcatenation //with FileAndResourceDirectives //with FormFieldDirectives //with FutureDirectives - //with HeaderDirectives + with HeaderDirectives with HostDirectives //with MarshallingDirectives with MethodDirectives diff --git a/akka-http/src/main/scala/akka/http/server/directives/ClassMagnet.scala b/akka-http/src/main/scala/akka/http/server/directives/ClassMagnet.scala new file mode 100644 index 0000000000..145ff6aeaf --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/ClassMagnet.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server.directives + +import scala.reflect.ClassTag + +/** A magnet that wraps a ClassTag */ +trait ClassMagnet[T] { + def classTag: ClassTag[T] + def runtimeClass: Class[T] + + /** + * Returns a partial function that checks if the input value is of runtime type + * T and returns the value if it does. Doesn't take erased information into account. + */ + def extractPF: PartialFunction[Any, T] +} +object ClassMagnet { + implicit def apply[T](u: Unit)(implicit tag: ClassTag[T]): ClassMagnet[T] = + new ClassMagnet[T] { + val classTag: ClassTag[T] = tag + val runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]] + val extractPF: PartialFunction[Any, T] = { + case x: T ⇒ x + } + } +} diff --git a/akka-http/src/main/scala/akka/http/server/directives/HeaderDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/HeaderDirectives.scala new file mode 100644 index 0000000000..b21c3dc047 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/HeaderDirectives.scala @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import scala.util.control.NonFatal +import akka.http.model._ +import akka.http.util._ + +trait HeaderDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Extracts an HTTP header value using the given function. If the function result is undefined for all headers the + * request is rejected with an empty rejection set. If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def headerValue[T](f: HttpHeader ⇒ Option[T]): Directive1[T] = { + val protectedF: HttpHeader ⇒ Option[Either[Rejection, T]] = header ⇒ + try f(header).map(Right.apply) + catch { + case NonFatal(e) ⇒ Some(Left(MalformedHeaderRejection(header.name, e.getMessage.nullAsEmpty, Some(e)))) + } + + extract(_.request.headers.collectFirst(Function.unlift(protectedF))).flatMap { + case Some(Right(a)) ⇒ provide(a) + case Some(Left(rejection)) ⇒ reject(rejection) + case None ⇒ reject + } + } + + /** + * Extracts an HTTP header value using the given partial function. If the function is undefined for all headers the + * request is rejected with an empty rejection set. + */ + def headerValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[T] = headerValue(pf.lift) + + /** + * Extracts the value of the HTTP request header with the given name. + * If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByName(headerName: Symbol): Directive1[String] = headerValueByName(headerName.toString) + + /** + * Extracts the value of the HTTP request header with the given name. + * If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByName(headerName: String): Directive1[String] = + headerValue(optionalValue(headerName.toLowerCase)) | reject(MissingHeaderRejection(headerName)) + + /** + * Extracts the HTTP request header of the given type. + * If no header with a matching type is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[T] = + headerValuePF(magnet.extractPF) | reject(MissingHeaderRejection(magnet.runtimeClass.getSimpleName)) + + /** + * Extracts an optional HTTP header value using the given function. + * If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def optionalHeaderValue[T](f: HttpHeader ⇒ Option[T]): Directive1[Option[T]] = + headerValue(f).map(Some(_): Option[T]).recoverPF { + case Nil ⇒ provide(None) + } + + /** + * Extracts an optional HTTP header value using the given partial function. + * If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def optionalHeaderValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[Option[T]] = + optionalHeaderValue(pf.lift) + + /** + * Extracts the value of the optional HTTP request header with the given name. + */ + def optionalHeaderValueByName(headerName: Symbol): Directive1[Option[String]] = + optionalHeaderValueByName(headerName.toString) + + /** + * Extracts the value of the optional HTTP request header with the given name. + */ + def optionalHeaderValueByName(headerName: String): Directive1[Option[String]] = { + val lowerCaseName = headerName.toLowerCase + extract(_.request.headers.collectFirst { + case HttpHeader(`lowerCaseName`, value) ⇒ value + }) + } + + /** + * Extract the header value of the optional HTTP request header with the given type. + */ + def optionalHeaderValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[Option[T]] = + optionalHeaderValuePF(magnet.extractPF) + + private def optionalValue(lowerCaseName: String): HttpHeader ⇒ Option[String] = { + case HttpHeader(`lowerCaseName`, value) ⇒ Some(value) + case _ ⇒ None + } +} + +object HeaderDirectives extends HeaderDirectives From b98f21cb807e6f3e7334d96e383422a29dd18b50 Mon Sep 17 00:00:00 2001 From: Benjamin Thuillier Date: Tue, 7 Oct 2014 14:08:31 +0200 Subject: [PATCH 3/3] +htp #15928 Import MiscDirectives from spray --- .../directives/MiscDirectivesSpec.scala | 33 ++++++++ .../scala/akka/http/server/Directive.scala | 2 +- .../scala/akka/http/server/Directives.scala | 2 +- .../server/directives/MiscDirectives.scala | 77 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 akka-http-tests/src/test/scala/akka/http/server/directives/MiscDirectivesSpec.scala create mode 100644 akka-http/src/main/scala/akka/http/server/directives/MiscDirectives.scala diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/MiscDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/MiscDirectivesSpec.scala new file mode 100644 index 0000000000..41bfea88dd --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/MiscDirectivesSpec.scala @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model._ +import headers._ +import HttpMethods._ +import MediaTypes._ +import Uri._ + +class MiscDirectivesSpec extends RoutingSpec { + + "the clientIP directive" should { + "extract from a X-Forwarded-For header" in { + Get() ~> addHeaders(`X-Forwarded-For`("2.3.4.5"), RawHeader("x-real-ip", "1.2.3.4")) ~> { + clientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "2.3.4.5" } + } + "extract from a Remote-Address header" in { + Get() ~> addHeaders(RawHeader("x-real-ip", "1.2.3.4"), `Remote-Address`(RemoteAddress("5.6.7.8"))) ~> { + clientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "5.6.7.8" } + } + "extract from a X-Real-IP header" in { + Get() ~> addHeader(RawHeader("x-real-ip", "1.2.3.4")) ~> { + clientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "1.2.3.4" } + } + } +} diff --git a/akka-http/src/main/scala/akka/http/server/Directive.scala b/akka-http/src/main/scala/akka/http/server/Directive.scala index 869d6d375e..8c05ca6875 100644 --- a/akka-http/src/main/scala/akka/http/server/Directive.scala +++ b/akka-http/src/main/scala/akka/http/server/Directive.scala @@ -141,4 +141,4 @@ object Directive { def filter(predicate: T ⇒ Boolean, rejections: Rejection*): Directive1[T] = underlying.tfilter({ case Tuple1(value) ⇒ predicate(value) }, rejections: _*) } -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index fe2d6a4068..fc5eed7dcc 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -23,7 +23,7 @@ trait Directives extends RouteConcatenation with HostDirectives //with MarshallingDirectives with MethodDirectives - //with MiscDirectives + with MiscDirectives //with ParameterDirectives with PathDirectives //with RangeDirectives diff --git a/akka-http/src/main/scala/akka/http/server/directives/MiscDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/MiscDirectives.scala new file mode 100644 index 0000000000..f6f435582e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/MiscDirectives.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.server +package directives + +import scala.reflect.{ classTag, ClassTag } +import akka.http.model._ +import akka.parboiled2.CharPredicate +import headers._ +import MediaTypes._ +import RouteResult._ + +trait MiscDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Returns a Directive which checks the given condition before passing on the [[spray.routing.RequestContext]] to + * its inner Route. If the condition fails the route is rejected with a [[spray.routing.ValidationRejection]]. + */ + def validate(check: ⇒ Boolean, errorMsg: String): Directive0 = + new Directive0 { + def tapply(f: Unit ⇒ Route) = if (check) f() else reject(ValidationRejection(errorMsg)) + } + + /** + * Directive extracting the IP of the client from either the X-Forwarded-For, Remote-Address or X-Real-IP header + * (in that order of priority). + */ + def clientIP: Directive1[RemoteAddress] = MiscDirectives._clientIP + + /** + * Rejects the request if its entity is not empty. + */ + def requestEntityEmpty: Directive0 = MiscDirectives._requestEntityEmpty + + /** + * Rejects empty requests with a RequestEntityExpectedRejection. + * Non-empty requests are passed on unchanged to the inner route. + */ + def requestEntityPresent: Directive0 = MiscDirectives._requestEntityPresent + + /** + * Converts responses with an empty entity into (empty) rejections. + * This way you can, for example, have the marshalling of a ''None'' option be treated as if the request could + * not be matched. + */ + def rejectEmptyResponse: Directive0 = MiscDirectives._rejectEmptyResponse +} + +object MiscDirectives extends MiscDirectives { + import BasicDirectives._ + import HeaderDirectives._ + import RouteDirectives._ + import CharPredicate._ + + private val validJsonpChars = AlphaNum ++ '.' ++ '_' ++ '$' + + private val _clientIP: Directive1[RemoteAddress] = + headerValuePF { case `X-Forwarded-For`(Seq(address, _*)) ⇒ address } | + headerValuePF { case `Remote-Address`(address) ⇒ address } | + headerValuePF { case h if h.is("x-real-ip") ⇒ RemoteAddress(h.value) } + + private val _requestEntityEmpty: Directive0 = + extract(_.request.entity.isKnownEmpty).flatMap(if (_) pass else reject) + + private val _requestEntityPresent: Directive0 = + extract(_.request.entity.isKnownEmpty).flatMap(if (_) reject else pass) + + private val _rejectEmptyResponse: Directive0 = + mapRouteResponse { + case Complete(response) if response.entity.isKnownEmpty ⇒ rejected(Nil) + case x ⇒ x + } +}