diff --git a/akka-docs/rst/scala/code/docs/http/scaladsl/server/directives/HeaderDirectivesExamplesSpec.scala b/akka-docs/rst/scala/code/docs/http/scaladsl/server/directives/HeaderDirectivesExamplesSpec.scala index 695f23dcf8..e417ca35a4 100644 --- a/akka-docs/rst/scala/code/docs/http/scaladsl/server/directives/HeaderDirectivesExamplesSpec.scala +++ b/akka-docs/rst/scala/code/docs/http/scaladsl/server/directives/HeaderDirectivesExamplesSpec.scala @@ -209,13 +209,13 @@ class HeaderDirectivesExamplesSpec extends RoutingSpec with Inside { val invalidOriginHeader = Origin(invalidHttpOrigin) Get("abc") ~> invalidOriginHeader ~> route ~> check { inside(rejection) { - case InvalidOriginRejection(invalidOrigins) ⇒ - invalidOrigins shouldEqual Seq(invalidHttpOrigin) + case InvalidOriginRejection(allowedOrigins) ⇒ + allowedOrigins shouldEqual Seq(correctOrigin) } } Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check { status shouldEqual StatusCodes.Forbidden - responseAs[String] should include(s"${invalidHttpOrigin.value}") + responseAs[String] should include(s"${correctOrigin.value}") } } } diff --git a/akka-http-core/src/main/java/akka/http/javadsl/model/headers/HttpOriginRange.java b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/HttpOriginRange.java index 0c0ae05863..70a949883f 100644 --- a/akka-http-core/src/main/java/akka/http/javadsl/model/headers/HttpOriginRange.java +++ b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/HttpOriginRange.java @@ -11,18 +11,18 @@ import akka.http.impl.util.Util; * @see HttpOriginRanges for convenience access to often used values. */ public abstract class HttpOriginRange { - public abstract boolean matches(HttpOrigin origin); + public abstract boolean matches(HttpOrigin origin); - public static HttpOriginRange create(HttpOrigin... origins) { - return HttpOriginRange$.MODULE$.apply(Util.convertArray(origins)); - } + public static HttpOriginRange create(HttpOrigin... origins) { + return HttpOriginRange$.MODULE$.apply(Util.convertArray(origins)); + } - /** - * @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class). - * In some edge cases this field could end up containing a null value. - * Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead. - */ - @Deprecated - // FIXME: Remove in Akka 3.0 - public static final HttpOriginRange ALL = HttpOriginRanges.ALL; + /** + * @deprecated because of troublesome initialisation order (with regards to scaladsl class implementing this class). + * In some edge cases this field could end up containing a null value. + * Will be removed in Akka 3.x, use {@link HttpEncodingRanges#ALL} instead. + */ + @Deprecated + // FIXME: Remove in Akka 3.0 + public static final HttpOriginRange ALL = HttpOriginRanges.ALL; } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/HttpOrigin.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/HttpOrigin.scala index 27545e3848..bd32d48471 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/HttpOrigin.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/HttpOrigin.scala @@ -5,7 +5,6 @@ package akka.http.scaladsl.model.headers import akka.http.impl.model.JavaInitialization -import akka.util.Unsafe import language.implicitConversions import scala.collection.immutable @@ -22,6 +21,7 @@ abstract class HttpOriginRange extends jm.headers.HttpOriginRange with ValueRend /** Java API */ def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala) } + object HttpOriginRange { case object `*` extends HttpOriginRange { def matches(origin: HttpOrigin) = true @@ -43,6 +43,7 @@ object HttpOriginRange { final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable { def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://") } + object HttpOrigin { implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null") @@ -50,4 +51,4 @@ object HttpOrigin { val parser = new UriParser(str, UTF8, Uri.ParsingMode.Relaxed) parser.parseOrigin() } -} \ No newline at end of file +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/directives/HeaderDirectivesTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/directives/HeaderDirectivesTest.java index 8fe7fac43f..fe9ecad53f 100644 --- a/akka-http-tests/src/test/java/akka/http/javadsl/server/directives/HeaderDirectivesTest.java +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/directives/HeaderDirectivesTest.java @@ -205,5 +205,32 @@ public class HeaderDirectivesTest extends JUnitRouteTest { .run(HttpRequest.create().addHeader(Origin.create(invalidOriginHeader))) .assertStatusCode(StatusCodes.FORBIDDEN); } + + @Test + public void testCheckSameOriginGivenALL() { + final HttpOrigin validOriginHeader = HttpOrigin.create("http://localhost", Host.create("8080")); + + // not very interesting case, however here we check that the directive simply avoids performing the check + final HttpOriginRange everythingGoes = HttpOriginRanges.ALL; + + final TestRoute route = testRoute(checkSameOrigin(everythingGoes, () -> complete("Result"))); + + route + .run(HttpRequest.create().addHeader(Origin.create(validOriginHeader))) + .assertStatusCode(StatusCodes.OK) + .assertEntity("Result"); + + route + .run(HttpRequest.create()) + .assertStatusCode(StatusCodes.OK) + .assertEntity("Result"); + + final HttpOrigin otherOriginHeader = HttpOrigin.create("http://invalid.com", Host.create("8080")); + + route + .run(HttpRequest.create().addHeader(Origin.create(otherOriginHeader))) + .assertStatusCode(StatusCodes.OK) + .assertEntity("Result"); + } } diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala index 1c0b893b64..6b636192d8 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala @@ -193,12 +193,12 @@ class HeaderDirectivesSpec extends RoutingSpec with Inside { val invalidOriginHeader = Origin(invalidHttpOrigin) Get("abc") ~> invalidOriginHeader ~> route ~> check { inside(rejection) { - case InvalidOriginRejection(invalidOrigins) ⇒ invalidOrigins shouldEqual Seq(invalidHttpOrigin) + case InvalidOriginRejection(allowedOrigins) ⇒ allowedOrigins shouldEqual Seq(correctOrigin) } } Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check { status shouldEqual StatusCodes.Forbidden - responseAs[String] should include(s"${invalidHttpOrigin.value}") + responseAs[String] should include(s"${correctOrigin.value}") } } } diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala index 95d2e3b011..5dc4b11c6a 100644 --- a/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala @@ -110,7 +110,7 @@ trait MalformedHeaderRejection extends Rejection { * Signals that the request was rejected because `Origin` header value is invalid. */ trait InvalidOriginRejection extends Rejection { - def getInvalidOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin] + def getAllowedOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin] } /** diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/HeaderDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/HeaderDirectives.scala index ebc94285bb..5324b77a5e 100644 --- a/akka-http/src/main/scala/akka/http/javadsl/server/directives/HeaderDirectives.scala +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/HeaderDirectives.scala @@ -11,12 +11,12 @@ import akka.actor.ReflectiveDynamicAccess import scala.compat.java8.OptionConverters import scala.compat.java8.OptionConverters._ import akka.http.impl.util.JavaMapping.Implicits._ +import akka.http.javadsl.model.headers.{ HttpOriginRange, HttpOriginRanges } import akka.http.javadsl.model.{ HttpHeader, StatusCodes } -import akka.http.javadsl.model.headers.HttpOriginRange import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route } +import akka.http.scaladsl.model.headers.HttpOriginRange.Default import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin } -import akka.http.scaladsl.server.directives.{ HeaderMagnet, BasicDirectives ⇒ B, HeaderDirectives ⇒ D } -import akka.stream.ActorMaterializer +import akka.http.scaladsl.server.directives.{ HeaderMagnet, HeaderDirectives ⇒ D } import scala.reflect.ClassTag import scala.util.{ Failure, Success } @@ -33,9 +33,16 @@ abstract class HeaderDirectives extends FutureDirectives { * * @group header */ - def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route = RouteAdapter { - D.checkSameOrigin(allowed.asScala) { inner.get().delegate } - } + // TODO When breaking binary compatibility this should become HttpOriginRange.Default, see https://github.com/akka/akka/pull/20776/files#r70049845 + def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route = + allowed match { + case HttpOriginRanges.ALL | HttpOriginRange.ALL | akka.http.scaladsl.model.headers.HttpOriginRange.`*` ⇒ pass(inner) + case _ ⇒ RouteAdapter { + // safe, we know it's not the `*` header + val default = allowed.asInstanceOf[akka.http.scaladsl.model.headers.HttpOriginRange.Default] + D.checkSameOrigin(default) { inner.get().delegate } + } + } /** * Extracts an HTTP header value using the given function. If the function result is undefined for all headers the diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala index 7a2a266dca..eb912fc931 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala @@ -98,9 +98,9 @@ final case class MalformedHeaderRejection(headerName: String, errorMsg: String, * Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]]. * Signals that the request was rejected because `Origin` header value is invalid. */ -final case class InvalidOriginRejection(invalidOrigins: immutable.Seq[SHttpOrigin]) +final case class InvalidOriginRejection(allowedOrigins: immutable.Seq[SHttpOrigin]) extends jserver.InvalidOriginRejection with Rejection { - override def getInvalidOrigins: java.util.List[JHttpOrigin] = invalidOrigins.map(_.asJava).asJava + override def getAllowedOrigins: java.util.List[JHttpOrigin] = allowedOrigins.map(_.asJava).asJava } /** diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala index acdb5715af..e076c73a52 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala @@ -167,8 +167,8 @@ object RejectionHandler { complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\'')) } .handle { - case InvalidOriginRejection(invalidOrigin) ⇒ - complete((Forbidden, s"Invalid `Origin` header values: ${invalidOrigin.mkString(", ")}")) + case InvalidOriginRejection(allowedOrigins) ⇒ + complete((Forbidden, s"Allowed `Origin` header values: ${allowedOrigins.mkString(", ")}")) } .handle { case MissingQueryParamRejection(paramName) ⇒ diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala index 0243e58990..1fa50f54bd 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala @@ -28,10 +28,10 @@ trait HeaderDirectives { * * @group header */ - def checkSameOrigin(allowed: HttpOriginRange): Directive0 = { + def checkSameOrigin(allowed: HttpOriginRange.Default): Directive0 = { headerValueByType[Origin]().flatMap { origin ⇒ if (origin.origins.exists(allowed.matches)) pass - else reject(InvalidOriginRejection(origin.origins)) + else reject(InvalidOriginRejection(allowed.origins)) } }