Nikdon 20535 check same origin (#20962)

* =htp checkSameOrigin shows allowed origins

add HttpOriginRangeDefault into the javadsl and refactor

resolving binary compatibility + add copyright

return back public static final in the HttpOriginRange

* =htp #20535 address bin compat issues in checkSameOrigin PR
This commit is contained in:
Nikolay Donets 2016-07-15 13:38:47 +02:00 committed by Konrad Malawski
parent 3871e18acd
commit b7567a5c55
10 changed files with 67 additions and 32 deletions

View file

@ -209,13 +209,13 @@ class HeaderDirectivesExamplesSpec extends RoutingSpec with Inside {
val invalidOriginHeader = Origin(invalidHttpOrigin) val invalidOriginHeader = Origin(invalidHttpOrigin)
Get("abc") ~> invalidOriginHeader ~> route ~> check { Get("abc") ~> invalidOriginHeader ~> route ~> check {
inside(rejection) { inside(rejection) {
case InvalidOriginRejection(invalidOrigins) case InvalidOriginRejection(allowedOrigins)
invalidOrigins shouldEqual Seq(invalidHttpOrigin) allowedOrigins shouldEqual Seq(correctOrigin)
} }
} }
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check { Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
status shouldEqual StatusCodes.Forbidden status shouldEqual StatusCodes.Forbidden
responseAs[String] should include(s"${invalidHttpOrigin.value}") responseAs[String] should include(s"${correctOrigin.value}")
} }
} }
} }

View file

@ -11,18 +11,18 @@ import akka.http.impl.util.Util;
* @see HttpOriginRanges for convenience access to often used values. * @see HttpOriginRanges for convenience access to often used values.
*/ */
public abstract class HttpOriginRange { public abstract class HttpOriginRange {
public abstract boolean matches(HttpOrigin origin); public abstract boolean matches(HttpOrigin origin);
public static HttpOriginRange create(HttpOrigin... origins) { public static HttpOriginRange create(HttpOrigin... origins) {
return HttpOriginRange$.MODULE$.apply(Util.<HttpOrigin, akka.http.scaladsl.model.headers.HttpOrigin>convertArray(origins)); return HttpOriginRange$.MODULE$.apply(Util.<HttpOrigin, akka.http.scaladsl.model.headers.HttpOrigin>convertArray(origins));
} }
/** /**
* @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class). * @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class).
* In some edge cases this field could end up containing a null value. * In some edge cases this field could end up containing a null value.
* Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead. * Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead.
*/ */
@Deprecated @Deprecated
// FIXME: Remove in Akka 3.0 // FIXME: Remove in Akka 3.0
public static final HttpOriginRange ALL = HttpOriginRanges.ALL; public static final HttpOriginRange ALL = HttpOriginRanges.ALL;
} }

View file

@ -5,7 +5,6 @@
package akka.http.scaladsl.model.headers package akka.http.scaladsl.model.headers
import akka.http.impl.model.JavaInitialization import akka.http.impl.model.JavaInitialization
import akka.util.Unsafe
import language.implicitConversions import language.implicitConversions
import scala.collection.immutable import scala.collection.immutable
@ -22,6 +21,7 @@ abstract class HttpOriginRange extends jm.headers.HttpOriginRange with ValueRend
/** Java API */ /** Java API */
def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala) def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala)
} }
object HttpOriginRange { object HttpOriginRange {
case object `*` extends HttpOriginRange { case object `*` extends HttpOriginRange {
def matches(origin: HttpOrigin) = true def matches(origin: HttpOrigin) = true
@ -43,6 +43,7 @@ object HttpOriginRange {
final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable { final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable {
def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://") def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://")
} }
object HttpOrigin { object HttpOrigin {
implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null") implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null")
@ -50,4 +51,4 @@ object HttpOrigin {
val parser = new UriParser(str, UTF8, Uri.ParsingMode.Relaxed) val parser = new UriParser(str, UTF8, Uri.ParsingMode.Relaxed)
parser.parseOrigin() parser.parseOrigin()
} }
} }

View file

@ -205,5 +205,32 @@ public class HeaderDirectivesTest extends JUnitRouteTest {
.run(HttpRequest.create().addHeader(Origin.create(invalidOriginHeader))) .run(HttpRequest.create().addHeader(Origin.create(invalidOriginHeader)))
.assertStatusCode(StatusCodes.FORBIDDEN); .assertStatusCode(StatusCodes.FORBIDDEN);
} }
@Test
public void testCheckSameOriginGivenALL() {
final HttpOrigin validOriginHeader = HttpOrigin.create("http://localhost", Host.create("8080"));
// not very interesting case, however here we check that the directive simply avoids performing the check
final HttpOriginRange everythingGoes = HttpOriginRanges.ALL;
final TestRoute route = testRoute(checkSameOrigin(everythingGoes, () -> complete("Result")));
route
.run(HttpRequest.create().addHeader(Origin.create(validOriginHeader)))
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
route
.run(HttpRequest.create())
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
final HttpOrigin otherOriginHeader = HttpOrigin.create("http://invalid.com", Host.create("8080"));
route
.run(HttpRequest.create().addHeader(Origin.create(otherOriginHeader)))
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
}
} }

View file

@ -193,12 +193,12 @@ class HeaderDirectivesSpec extends RoutingSpec with Inside {
val invalidOriginHeader = Origin(invalidHttpOrigin) val invalidOriginHeader = Origin(invalidHttpOrigin)
Get("abc") ~> invalidOriginHeader ~> route ~> check { Get("abc") ~> invalidOriginHeader ~> route ~> check {
inside(rejection) { inside(rejection) {
case InvalidOriginRejection(invalidOrigins) invalidOrigins shouldEqual Seq(invalidHttpOrigin) case InvalidOriginRejection(allowedOrigins) allowedOrigins shouldEqual Seq(correctOrigin)
} }
} }
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check { Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
status shouldEqual StatusCodes.Forbidden status shouldEqual StatusCodes.Forbidden
responseAs[String] should include(s"${invalidHttpOrigin.value}") responseAs[String] should include(s"${correctOrigin.value}")
} }
} }
} }

View file

@ -110,7 +110,7 @@ trait MalformedHeaderRejection extends Rejection {
* Signals that the request was rejected because `Origin` header value is invalid. * Signals that the request was rejected because `Origin` header value is invalid.
*/ */
trait InvalidOriginRejection extends Rejection { trait InvalidOriginRejection extends Rejection {
def getInvalidOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin] def getAllowedOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin]
} }
/** /**

View file

@ -11,12 +11,12 @@ import akka.actor.ReflectiveDynamicAccess
import scala.compat.java8.OptionConverters import scala.compat.java8.OptionConverters
import scala.compat.java8.OptionConverters._ import scala.compat.java8.OptionConverters._
import akka.http.impl.util.JavaMapping.Implicits._ import akka.http.impl.util.JavaMapping.Implicits._
import akka.http.javadsl.model.headers.{ HttpOriginRange, HttpOriginRanges }
import akka.http.javadsl.model.{ HttpHeader, StatusCodes } import akka.http.javadsl.model.{ HttpHeader, StatusCodes }
import akka.http.javadsl.model.headers.HttpOriginRange
import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route } import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route }
import akka.http.scaladsl.model.headers.HttpOriginRange.Default
import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin } import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin }
import akka.http.scaladsl.server.directives.{ HeaderMagnet, BasicDirectives B, HeaderDirectives D } import akka.http.scaladsl.server.directives.{ HeaderMagnet, HeaderDirectives D }
import akka.stream.ActorMaterializer
import scala.reflect.ClassTag import scala.reflect.ClassTag
import scala.util.{ Failure, Success } import scala.util.{ Failure, Success }
@ -33,9 +33,16 @@ abstract class HeaderDirectives extends FutureDirectives {
* *
* @group header * @group header
*/ */
def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route = RouteAdapter { // TODO When breaking binary compatibility this should become HttpOriginRange.Default, see https://github.com/akka/akka/pull/20776/files#r70049845
D.checkSameOrigin(allowed.asScala) { inner.get().delegate } def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route =
} allowed match {
case HttpOriginRanges.ALL | HttpOriginRange.ALL | akka.http.scaladsl.model.headers.HttpOriginRange.`*` pass(inner)
case _ RouteAdapter {
// safe, we know it's not the `*` header
val default = allowed.asInstanceOf[akka.http.scaladsl.model.headers.HttpOriginRange.Default]
D.checkSameOrigin(default) { inner.get().delegate }
}
}
/** /**
* Extracts an HTTP header value using the given function. If the function result is undefined for all headers the * Extracts an HTTP header value using the given function. If the function result is undefined for all headers the

View file

@ -98,9 +98,9 @@ final case class MalformedHeaderRejection(headerName: String, errorMsg: String,
* Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]]. * Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]].
* Signals that the request was rejected because `Origin` header value is invalid. * Signals that the request was rejected because `Origin` header value is invalid.
*/ */
final case class InvalidOriginRejection(invalidOrigins: immutable.Seq[SHttpOrigin]) final case class InvalidOriginRejection(allowedOrigins: immutable.Seq[SHttpOrigin])
extends jserver.InvalidOriginRejection with Rejection { extends jserver.InvalidOriginRejection with Rejection {
override def getInvalidOrigins: java.util.List[JHttpOrigin] = invalidOrigins.map(_.asJava).asJava override def getAllowedOrigins: java.util.List[JHttpOrigin] = allowedOrigins.map(_.asJava).asJava
} }
/** /**

View file

@ -167,8 +167,8 @@ object RejectionHandler {
complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\'')) complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\''))
} }
.handle { .handle {
case InvalidOriginRejection(invalidOrigin) case InvalidOriginRejection(allowedOrigins)
complete((Forbidden, s"Invalid `Origin` header values: ${invalidOrigin.mkString(", ")}")) complete((Forbidden, s"Allowed `Origin` header values: ${allowedOrigins.mkString(", ")}"))
} }
.handle { .handle {
case MissingQueryParamRejection(paramName) case MissingQueryParamRejection(paramName)

View file

@ -28,10 +28,10 @@ trait HeaderDirectives {
* *
* @group header * @group header
*/ */
def checkSameOrigin(allowed: HttpOriginRange): Directive0 = { def checkSameOrigin(allowed: HttpOriginRange.Default): Directive0 = {
headerValueByType[Origin]().flatMap { origin headerValueByType[Origin]().flatMap { origin
if (origin.origins.exists(allowed.matches)) pass if (origin.origins.exists(allowed.matches)) pass
else reject(InvalidOriginRejection(origin.origins)) else reject(InvalidOriginRejection(allowed.origins))
} }
} }