Nikdon 20535 check same origin (#20962)
* =htp checkSameOrigin shows allowed origins add HttpOriginRangeDefault into the javadsl and refactor resolving binary compatibility + add copyright return back public static final in the HttpOriginRange * =htp #20535 address bin compat issues in checkSameOrigin PR
This commit is contained in:
parent
3871e18acd
commit
b7567a5c55
10 changed files with 67 additions and 32 deletions
|
|
@ -209,13 +209,13 @@ class HeaderDirectivesExamplesSpec extends RoutingSpec with Inside {
|
||||||
val invalidOriginHeader = Origin(invalidHttpOrigin)
|
val invalidOriginHeader = Origin(invalidHttpOrigin)
|
||||||
Get("abc") ~> invalidOriginHeader ~> route ~> check {
|
Get("abc") ~> invalidOriginHeader ~> route ~> check {
|
||||||
inside(rejection) {
|
inside(rejection) {
|
||||||
case InvalidOriginRejection(invalidOrigins) ⇒
|
case InvalidOriginRejection(allowedOrigins) ⇒
|
||||||
invalidOrigins shouldEqual Seq(invalidHttpOrigin)
|
allowedOrigins shouldEqual Seq(correctOrigin)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
|
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
|
||||||
status shouldEqual StatusCodes.Forbidden
|
status shouldEqual StatusCodes.Forbidden
|
||||||
responseAs[String] should include(s"${invalidHttpOrigin.value}")
|
responseAs[String] should include(s"${correctOrigin.value}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,18 +11,18 @@ import akka.http.impl.util.Util;
|
||||||
* @see HttpOriginRanges for convenience access to often used values.
|
* @see HttpOriginRanges for convenience access to often used values.
|
||||||
*/
|
*/
|
||||||
public abstract class HttpOriginRange {
|
public abstract class HttpOriginRange {
|
||||||
public abstract boolean matches(HttpOrigin origin);
|
public abstract boolean matches(HttpOrigin origin);
|
||||||
|
|
||||||
public static HttpOriginRange create(HttpOrigin... origins) {
|
public static HttpOriginRange create(HttpOrigin... origins) {
|
||||||
return HttpOriginRange$.MODULE$.apply(Util.<HttpOrigin, akka.http.scaladsl.model.headers.HttpOrigin>convertArray(origins));
|
return HttpOriginRange$.MODULE$.apply(Util.<HttpOrigin, akka.http.scaladsl.model.headers.HttpOrigin>convertArray(origins));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class).
|
* @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class).
|
||||||
* In some edge cases this field could end up containing a null value.
|
* In some edge cases this field could end up containing a null value.
|
||||||
* Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead.
|
* Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead.
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
// FIXME: Remove in Akka 3.0
|
// FIXME: Remove in Akka 3.0
|
||||||
public static final HttpOriginRange ALL = HttpOriginRanges.ALL;
|
public static final HttpOriginRange ALL = HttpOriginRanges.ALL;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@
|
||||||
package akka.http.scaladsl.model.headers
|
package akka.http.scaladsl.model.headers
|
||||||
|
|
||||||
import akka.http.impl.model.JavaInitialization
|
import akka.http.impl.model.JavaInitialization
|
||||||
import akka.util.Unsafe
|
|
||||||
|
|
||||||
import language.implicitConversions
|
import language.implicitConversions
|
||||||
import scala.collection.immutable
|
import scala.collection.immutable
|
||||||
|
|
@ -22,6 +21,7 @@ abstract class HttpOriginRange extends jm.headers.HttpOriginRange with ValueRend
|
||||||
/** Java API */
|
/** Java API */
|
||||||
def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala)
|
def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala)
|
||||||
}
|
}
|
||||||
|
|
||||||
object HttpOriginRange {
|
object HttpOriginRange {
|
||||||
case object `*` extends HttpOriginRange {
|
case object `*` extends HttpOriginRange {
|
||||||
def matches(origin: HttpOrigin) = true
|
def matches(origin: HttpOrigin) = true
|
||||||
|
|
@ -43,6 +43,7 @@ object HttpOriginRange {
|
||||||
final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable {
|
final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable {
|
||||||
def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://")
|
def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://")
|
||||||
}
|
}
|
||||||
|
|
||||||
object HttpOrigin {
|
object HttpOrigin {
|
||||||
implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null")
|
implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null")
|
||||||
|
|
||||||
|
|
@ -50,4 +51,4 @@ object HttpOrigin {
|
||||||
val parser = new UriParser(str, UTF8, Uri.ParsingMode.Relaxed)
|
val parser = new UriParser(str, UTF8, Uri.ParsingMode.Relaxed)
|
||||||
parser.parseOrigin()
|
parser.parseOrigin()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -205,5 +205,32 @@ public class HeaderDirectivesTest extends JUnitRouteTest {
|
||||||
.run(HttpRequest.create().addHeader(Origin.create(invalidOriginHeader)))
|
.run(HttpRequest.create().addHeader(Origin.create(invalidOriginHeader)))
|
||||||
.assertStatusCode(StatusCodes.FORBIDDEN);
|
.assertStatusCode(StatusCodes.FORBIDDEN);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCheckSameOriginGivenALL() {
|
||||||
|
final HttpOrigin validOriginHeader = HttpOrigin.create("http://localhost", Host.create("8080"));
|
||||||
|
|
||||||
|
// not very interesting case, however here we check that the directive simply avoids performing the check
|
||||||
|
final HttpOriginRange everythingGoes = HttpOriginRanges.ALL;
|
||||||
|
|
||||||
|
final TestRoute route = testRoute(checkSameOrigin(everythingGoes, () -> complete("Result")));
|
||||||
|
|
||||||
|
route
|
||||||
|
.run(HttpRequest.create().addHeader(Origin.create(validOriginHeader)))
|
||||||
|
.assertStatusCode(StatusCodes.OK)
|
||||||
|
.assertEntity("Result");
|
||||||
|
|
||||||
|
route
|
||||||
|
.run(HttpRequest.create())
|
||||||
|
.assertStatusCode(StatusCodes.OK)
|
||||||
|
.assertEntity("Result");
|
||||||
|
|
||||||
|
final HttpOrigin otherOriginHeader = HttpOrigin.create("http://invalid.com", Host.create("8080"));
|
||||||
|
|
||||||
|
route
|
||||||
|
.run(HttpRequest.create().addHeader(Origin.create(otherOriginHeader)))
|
||||||
|
.assertStatusCode(StatusCodes.OK)
|
||||||
|
.assertEntity("Result");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -193,12 +193,12 @@ class HeaderDirectivesSpec extends RoutingSpec with Inside {
|
||||||
val invalidOriginHeader = Origin(invalidHttpOrigin)
|
val invalidOriginHeader = Origin(invalidHttpOrigin)
|
||||||
Get("abc") ~> invalidOriginHeader ~> route ~> check {
|
Get("abc") ~> invalidOriginHeader ~> route ~> check {
|
||||||
inside(rejection) {
|
inside(rejection) {
|
||||||
case InvalidOriginRejection(invalidOrigins) ⇒ invalidOrigins shouldEqual Seq(invalidHttpOrigin)
|
case InvalidOriginRejection(allowedOrigins) ⇒ allowedOrigins shouldEqual Seq(correctOrigin)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
|
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
|
||||||
status shouldEqual StatusCodes.Forbidden
|
status shouldEqual StatusCodes.Forbidden
|
||||||
responseAs[String] should include(s"${invalidHttpOrigin.value}")
|
responseAs[String] should include(s"${correctOrigin.value}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ trait MalformedHeaderRejection extends Rejection {
|
||||||
* Signals that the request was rejected because `Origin` header value is invalid.
|
* Signals that the request was rejected because `Origin` header value is invalid.
|
||||||
*/
|
*/
|
||||||
trait InvalidOriginRejection extends Rejection {
|
trait InvalidOriginRejection extends Rejection {
|
||||||
def getInvalidOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin]
|
def getAllowedOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -11,12 +11,12 @@ import akka.actor.ReflectiveDynamicAccess
|
||||||
import scala.compat.java8.OptionConverters
|
import scala.compat.java8.OptionConverters
|
||||||
import scala.compat.java8.OptionConverters._
|
import scala.compat.java8.OptionConverters._
|
||||||
import akka.http.impl.util.JavaMapping.Implicits._
|
import akka.http.impl.util.JavaMapping.Implicits._
|
||||||
|
import akka.http.javadsl.model.headers.{ HttpOriginRange, HttpOriginRanges }
|
||||||
import akka.http.javadsl.model.{ HttpHeader, StatusCodes }
|
import akka.http.javadsl.model.{ HttpHeader, StatusCodes }
|
||||||
import akka.http.javadsl.model.headers.HttpOriginRange
|
|
||||||
import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route }
|
import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route }
|
||||||
|
import akka.http.scaladsl.model.headers.HttpOriginRange.Default
|
||||||
import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin }
|
import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin }
|
||||||
import akka.http.scaladsl.server.directives.{ HeaderMagnet, BasicDirectives ⇒ B, HeaderDirectives ⇒ D }
|
import akka.http.scaladsl.server.directives.{ HeaderMagnet, HeaderDirectives ⇒ D }
|
||||||
import akka.stream.ActorMaterializer
|
|
||||||
|
|
||||||
import scala.reflect.ClassTag
|
import scala.reflect.ClassTag
|
||||||
import scala.util.{ Failure, Success }
|
import scala.util.{ Failure, Success }
|
||||||
|
|
@ -33,9 +33,16 @@ abstract class HeaderDirectives extends FutureDirectives {
|
||||||
*
|
*
|
||||||
* @group header
|
* @group header
|
||||||
*/
|
*/
|
||||||
def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route = RouteAdapter {
|
// TODO When breaking binary compatibility this should become HttpOriginRange.Default, see https://github.com/akka/akka/pull/20776/files#r70049845
|
||||||
D.checkSameOrigin(allowed.asScala) { inner.get().delegate }
|
def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route =
|
||||||
}
|
allowed match {
|
||||||
|
case HttpOriginRanges.ALL | HttpOriginRange.ALL | akka.http.scaladsl.model.headers.HttpOriginRange.`*` ⇒ pass(inner)
|
||||||
|
case _ ⇒ RouteAdapter {
|
||||||
|
// safe, we know it's not the `*` header
|
||||||
|
val default = allowed.asInstanceOf[akka.http.scaladsl.model.headers.HttpOriginRange.Default]
|
||||||
|
D.checkSameOrigin(default) { inner.get().delegate }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts an HTTP header value using the given function. If the function result is undefined for all headers the
|
* Extracts an HTTP header value using the given function. If the function result is undefined for all headers the
|
||||||
|
|
|
||||||
|
|
@ -98,9 +98,9 @@ final case class MalformedHeaderRejection(headerName: String, errorMsg: String,
|
||||||
* Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]].
|
* Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]].
|
||||||
* Signals that the request was rejected because `Origin` header value is invalid.
|
* Signals that the request was rejected because `Origin` header value is invalid.
|
||||||
*/
|
*/
|
||||||
final case class InvalidOriginRejection(invalidOrigins: immutable.Seq[SHttpOrigin])
|
final case class InvalidOriginRejection(allowedOrigins: immutable.Seq[SHttpOrigin])
|
||||||
extends jserver.InvalidOriginRejection with Rejection {
|
extends jserver.InvalidOriginRejection with Rejection {
|
||||||
override def getInvalidOrigins: java.util.List[JHttpOrigin] = invalidOrigins.map(_.asJava).asJava
|
override def getAllowedOrigins: java.util.List[JHttpOrigin] = allowedOrigins.map(_.asJava).asJava
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -167,8 +167,8 @@ object RejectionHandler {
|
||||||
complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\''))
|
complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\''))
|
||||||
}
|
}
|
||||||
.handle {
|
.handle {
|
||||||
case InvalidOriginRejection(invalidOrigin) ⇒
|
case InvalidOriginRejection(allowedOrigins) ⇒
|
||||||
complete((Forbidden, s"Invalid `Origin` header values: ${invalidOrigin.mkString(", ")}"))
|
complete((Forbidden, s"Allowed `Origin` header values: ${allowedOrigins.mkString(", ")}"))
|
||||||
}
|
}
|
||||||
.handle {
|
.handle {
|
||||||
case MissingQueryParamRejection(paramName) ⇒
|
case MissingQueryParamRejection(paramName) ⇒
|
||||||
|
|
|
||||||
|
|
@ -28,10 +28,10 @@ trait HeaderDirectives {
|
||||||
*
|
*
|
||||||
* @group header
|
* @group header
|
||||||
*/
|
*/
|
||||||
def checkSameOrigin(allowed: HttpOriginRange): Directive0 = {
|
def checkSameOrigin(allowed: HttpOriginRange.Default): Directive0 = {
|
||||||
headerValueByType[Origin]().flatMap { origin ⇒
|
headerValueByType[Origin]().flatMap { origin ⇒
|
||||||
if (origin.origins.exists(allowed.matches)) pass
|
if (origin.origins.exists(allowed.matches)) pass
|
||||||
else reject(InvalidOriginRejection(origin.origins))
|
else reject(InvalidOriginRejection(allowed.origins))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue