Nikdon 20535 check same origin (#20962)

* =htp checkSameOrigin shows allowed origins

add HttpOriginRangeDefault into the javadsl and refactor

resolving binary compatibility + add copyright

return back public static final in the HttpOriginRange

* =htp #20535 address bin compat issues in checkSameOrigin PR
This commit is contained in:
Nikolay Donets 2016-07-15 13:38:47 +02:00 committed by Konrad Malawski
parent 3871e18acd
commit b7567a5c55
10 changed files with 67 additions and 32 deletions

View file

@ -209,13 +209,13 @@ class HeaderDirectivesExamplesSpec extends RoutingSpec with Inside {
val invalidOriginHeader = Origin(invalidHttpOrigin)
Get("abc") ~> invalidOriginHeader ~> route ~> check {
inside(rejection) {
case InvalidOriginRejection(invalidOrigins)
invalidOrigins shouldEqual Seq(invalidHttpOrigin)
case InvalidOriginRejection(allowedOrigins)
allowedOrigins shouldEqual Seq(correctOrigin)
}
}
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
status shouldEqual StatusCodes.Forbidden
responseAs[String] should include(s"${invalidHttpOrigin.value}")
responseAs[String] should include(s"${correctOrigin.value}")
}
}
}

View file

@ -5,7 +5,6 @@
package akka.http.scaladsl.model.headers
import akka.http.impl.model.JavaInitialization
import akka.util.Unsafe
import language.implicitConversions
import scala.collection.immutable
@ -22,6 +21,7 @@ abstract class HttpOriginRange extends jm.headers.HttpOriginRange with ValueRend
/** Java API */
def matches(origin: jm.headers.HttpOrigin): Boolean = matches(origin.asScala)
}
object HttpOriginRange {
case object `*` extends HttpOriginRange {
def matches(origin: HttpOrigin) = true
@ -43,6 +43,7 @@ object HttpOriginRange {
final case class HttpOrigin(scheme: String, host: Host) extends jm.headers.HttpOrigin with ValueRenderable {
def render[R <: Rendering](r: R): r.type = host.renderValue(r ~~ scheme ~~ "://")
}
object HttpOrigin {
implicit val originsRenderer: Renderer[immutable.Seq[HttpOrigin]] = Renderer.seqRenderer(" ", "null")

View file

@ -206,4 +206,31 @@ public class HeaderDirectivesTest extends JUnitRouteTest {
.assertStatusCode(StatusCodes.FORBIDDEN);
}
@Test
public void testCheckSameOriginGivenALL() {
final HttpOrigin validOriginHeader = HttpOrigin.create("http://localhost", Host.create("8080"));
// not very interesting case, however here we check that the directive simply avoids performing the check
final HttpOriginRange everythingGoes = HttpOriginRanges.ALL;
final TestRoute route = testRoute(checkSameOrigin(everythingGoes, () -> complete("Result")));
route
.run(HttpRequest.create().addHeader(Origin.create(validOriginHeader)))
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
route
.run(HttpRequest.create())
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
final HttpOrigin otherOriginHeader = HttpOrigin.create("http://invalid.com", Host.create("8080"));
route
.run(HttpRequest.create().addHeader(Origin.create(otherOriginHeader)))
.assertStatusCode(StatusCodes.OK)
.assertEntity("Result");
}
}

View file

@ -193,12 +193,12 @@ class HeaderDirectivesSpec extends RoutingSpec with Inside {
val invalidOriginHeader = Origin(invalidHttpOrigin)
Get("abc") ~> invalidOriginHeader ~> route ~> check {
inside(rejection) {
case InvalidOriginRejection(invalidOrigins) invalidOrigins shouldEqual Seq(invalidHttpOrigin)
case InvalidOriginRejection(allowedOrigins) allowedOrigins shouldEqual Seq(correctOrigin)
}
}
Get("abc") ~> invalidOriginHeader ~> Route.seal(route) ~> check {
status shouldEqual StatusCodes.Forbidden
responseAs[String] should include(s"${invalidHttpOrigin.value}")
responseAs[String] should include(s"${correctOrigin.value}")
}
}
}

View file

@ -110,7 +110,7 @@ trait MalformedHeaderRejection extends Rejection {
* Signals that the request was rejected because `Origin` header value is invalid.
*/
trait InvalidOriginRejection extends Rejection {
def getInvalidOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin]
def getAllowedOrigins: java.util.List[akka.http.javadsl.model.headers.HttpOrigin]
}
/**

View file

@ -11,12 +11,12 @@ import akka.actor.ReflectiveDynamicAccess
import scala.compat.java8.OptionConverters
import scala.compat.java8.OptionConverters._
import akka.http.impl.util.JavaMapping.Implicits._
import akka.http.javadsl.model.headers.{ HttpOriginRange, HttpOriginRanges }
import akka.http.javadsl.model.{ HttpHeader, StatusCodes }
import akka.http.javadsl.model.headers.HttpOriginRange
import akka.http.javadsl.server.{ InvalidOriginRejection, MissingHeaderRejection, Route }
import akka.http.scaladsl.model.headers.HttpOriginRange.Default
import akka.http.scaladsl.model.headers.{ ModeledCustomHeader, ModeledCustomHeaderCompanion, Origin }
import akka.http.scaladsl.server.directives.{ HeaderMagnet, BasicDirectives B, HeaderDirectives D }
import akka.stream.ActorMaterializer
import akka.http.scaladsl.server.directives.{ HeaderMagnet, HeaderDirectives D }
import scala.reflect.ClassTag
import scala.util.{ Failure, Success }
@ -33,8 +33,15 @@ abstract class HeaderDirectives extends FutureDirectives {
*
* @group header
*/
def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route = RouteAdapter {
D.checkSameOrigin(allowed.asScala) { inner.get().delegate }
// TODO When breaking binary compatibility this should become HttpOriginRange.Default, see https://github.com/akka/akka/pull/20776/files#r70049845
def checkSameOrigin(allowed: HttpOriginRange, inner: jf.Supplier[Route]): Route =
allowed match {
case HttpOriginRanges.ALL | HttpOriginRange.ALL | akka.http.scaladsl.model.headers.HttpOriginRange.`*` pass(inner)
case _ RouteAdapter {
// safe, we know it's not the `*` header
val default = allowed.asInstanceOf[akka.http.scaladsl.model.headers.HttpOriginRange.Default]
D.checkSameOrigin(default) { inner.get().delegate }
}
}
/**

View file

@ -98,9 +98,9 @@ final case class MalformedHeaderRejection(headerName: String, errorMsg: String,
* Rejection created by [[akka.http.scaladsl.server.directives.HeaderDirectives.checkSameOrigin]].
* Signals that the request was rejected because `Origin` header value is invalid.
*/
final case class InvalidOriginRejection(invalidOrigins: immutable.Seq[SHttpOrigin])
final case class InvalidOriginRejection(allowedOrigins: immutable.Seq[SHttpOrigin])
extends jserver.InvalidOriginRejection with Rejection {
override def getInvalidOrigins: java.util.List[JHttpOrigin] = invalidOrigins.map(_.asJava).asJava
override def getAllowedOrigins: java.util.List[JHttpOrigin] = allowedOrigins.map(_.asJava).asJava
}
/**

View file

@ -167,8 +167,8 @@ object RejectionHandler {
complete((BadRequest, "Request is missing required HTTP header '" + headerName + '\''))
}
.handle {
case InvalidOriginRejection(invalidOrigin)
complete((Forbidden, s"Invalid `Origin` header values: ${invalidOrigin.mkString(", ")}"))
case InvalidOriginRejection(allowedOrigins)
complete((Forbidden, s"Allowed `Origin` header values: ${allowedOrigins.mkString(", ")}"))
}
.handle {
case MissingQueryParamRejection(paramName)

View file

@ -28,10 +28,10 @@ trait HeaderDirectives {
*
* @group header
*/
def checkSameOrigin(allowed: HttpOriginRange): Directive0 = {
def checkSameOrigin(allowed: HttpOriginRange.Default): Directive0 = {
headerValueByType[Origin]().flatMap { origin
if (origin.origins.exists(allowed.matches)) pass
else reject(InvalidOriginRejection(origin.origins))
else reject(InvalidOriginRejection(allowed.origins))
}
}