diff --git a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/JUnitRouteTest.scala b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/JUnitRouteTest.scala new file mode 100644 index 0000000000..75a9ad372e --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/JUnitRouteTest.scala @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.testkit + +import org.junit.rules.ExternalResource +import org.junit.{ Rule, Assert } +import scala.concurrent.duration._ +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.model.HttpResponse + +/** + * A RouteTest that uses JUnit assertions. + */ +abstract class JUnitRouteTestBase extends RouteTest { + protected def systemResource: ActorSystemResource + implicit def system: ActorSystem = systemResource.system + implicit def materializer: ActorFlowMaterializer = systemResource.materializer + + protected def createTestResponse(response: HttpResponse): TestResponse = + new TestResponse(response, awaitDuration)(system.dispatcher, materializer) { + protected def assertEquals(expected: AnyRef, actual: AnyRef, message: String): Unit = + Assert.assertEquals(message, expected, actual) + + protected def assertEquals(expected: Int, actual: Int, message: String): Unit = + Assert.assertEquals(message, expected, actual) + + protected def assertTrue(predicate: Boolean, message: String): Unit = + Assert.assertTrue(message, predicate) + + protected def fail(message: String): Nothing = { + Assert.fail(message) + throw new IllegalStateException("Assertion should have failed") + } + } +} +abstract class JUnitRouteTest extends JUnitRouteTestBase { + private[this] val _systemResource = new ActorSystemResource + @Rule + protected def systemResource: ActorSystemResource = _systemResource +} + +class ActorSystemResource extends ExternalResource { + protected def createSystem(): ActorSystem = ActorSystem() + protected def createFlowMaterializer(system: ActorSystem): ActorFlowMaterializer = ActorFlowMaterializer()(system) + + implicit def system: ActorSystem = _system + implicit def materializer: ActorFlowMaterializer = _materializer + + private[this] var _system: ActorSystem = null + private[this] var _materializer: ActorFlowMaterializer = null + + override def before(): Unit = { + require((_system eq null) && (_materializer eq null)) + _system = createSystem() + _materializer = createFlowMaterializer(_system) + } + override def after(): Unit = { + _system.shutdown() + _system.awaitTermination(5.seconds) + _system = null + _materializer = null + } +} diff --git a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/RouteTest.scala b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/RouteTest.scala new file mode 100644 index 0000000000..eef8688a8b --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/RouteTest.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.testkit + +import scala.annotation.varargs +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.server +import akka.http.javadsl.model.HttpRequest +import akka.http.javadsl.server.{ Route, Directives } +import akka.http.impl.util.JavaMapping.Implicits._ +import akka.http.impl.server.RouteImplementation +import akka.http.scaladsl.model.HttpResponse +import akka.http.scaladsl.server.{ RouteResult, RoutingSettings, Route ⇒ ScalaRoute } +import akka.actor.ActorSystem +import akka.event.NoLogging +import akka.http.impl.util._ + +abstract class RouteTest { + implicit def system: ActorSystem + implicit def materializer: ActorFlowMaterializer + implicit def executionContext: ExecutionContext = system.dispatcher + + protected def awaitDuration: FiniteDuration = 500.millis + + def runRoute(route: Route, request: HttpRequest): TestResponse = { + val scalaRoute = ScalaRoute.seal(RouteImplementation(route)) + val result = scalaRoute(new server.RequestContextImpl(request.asScala, NoLogging, RoutingSettings(system))) + + result.awaitResult(awaitDuration) match { + case RouteResult.Complete(response) ⇒ createTestResponse(response) + } + } + + @varargs + def testRoute(first: Route, others: Route*): TestRoute = + new TestRoute { + val underlying: Route = Directives.route(first, others: _*) + + def run(request: HttpRequest): TestResponse = runRoute(underlying, request) + } + + protected def createTestResponse(response: HttpResponse): TestResponse +} diff --git a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestResponse.scala b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestResponse.scala new file mode 100644 index 0000000000..ee4bbee1f3 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestResponse.scala @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.testkit + +import scala.reflect.ClassTag +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.FiniteDuration +import akka.util.ByteString +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.http.scaladsl.model.HttpResponse +import akka.http.impl.util._ +import akka.http.impl.server.UnmarshallerImpl +import akka.http.impl.util.JavaMapping.Implicits._ +import akka.http.javadsl.server.Unmarshaller +import akka.http.javadsl.model._ + +/** + * A wrapper for responses + */ +abstract class TestResponse(_response: HttpResponse, awaitAtMost: FiniteDuration)(implicit ec: ExecutionContext, materializer: ActorFlowMaterializer) { + lazy val entity: HttpEntityStrict = + _response.entity.toStrict(awaitAtMost).awaitResult(awaitAtMost) + lazy val response: HttpResponse = _response.withEntity(entity) + + // FIXME: add header getters / assertions + + def mediaType: MediaType = extractFromResponse(_.entity.contentType.mediaType) + def mediaTypeString: String = mediaType.toString + def entityBytes: ByteString = entity.data() + def entityAs[T](unmarshaller: Unmarshaller[T]): T = + Unmarshal(response) + .to(unmarshaller.asInstanceOf[UnmarshallerImpl[T]].scalaUnmarshaller(ec, materializer), ec) + .awaitResult(awaitAtMost) + def entityAsString: String = entity.data().utf8String + def status: StatusCode = response.status.asJava + def statusCode: Int = response.status.intValue + def header[T <: HttpHeader](clazz: Class[T]): T = + response.header(ClassTag(clazz)) + .getOrElse(fail(s"Expected header of type ${clazz.getSimpleName} but wasn't found.")) + + def assertStatusCode(expected: Int): TestResponse = + assertStatusCode(StatusCodes.get(expected)) + def assertStatusCode(expected: StatusCode): TestResponse = + assertEqualsKind(expected, status, "status code") + def assertMediaType(expected: String): TestResponse = + assertEqualsKind(expected, mediaTypeString, "media type") + def assertMediaType(expected: MediaType): TestResponse = + assertEqualsKind(expected, mediaType, "media type") + def assertEntity(expected: String): TestResponse = + assertEqualsKind(expected, entityAsString, "entity") + def assertEntityBytes(expected: ByteString): TestResponse = + assertEqualsKind(expected, entityBytes, "entity") + def assertEntityAs[T <: AnyRef](unmarshaller: Unmarshaller[T], expected: T): TestResponse = + assertEqualsKind(expected, entityAs(unmarshaller), "entity") + def assertHeaderExists(expected: HttpHeader): TestResponse = { + assertTrue(response.headers.exists(_ == expected), s"Header $expected was missing.") + this + } + def assertHeaderKindExists(name: String): TestResponse = { + val lowercased = name.toRootLowerCase + assertTrue(response.headers.exists(_.is(lowercased)), s"Expected `$name` header was missing.") + this + } + def assertHeaderExists(name: String, value: String): TestResponse = { + val lowercased = name.toRootLowerCase + val headers = response.headers.filter(_.is(lowercased)) + if (headers.isEmpty) fail(s"Expected `$name` header was missing.") + else assertTrue(headers.exists(_.value == value), + s"`$name` header was found but had the wrong value. Found headers: ${headers.mkString(", ")}") + + this + } + + private[this] def extractFromResponse[T](f: HttpResponse ⇒ T): T = + if (response eq null) fail("Request didn't complete with response") + else f(response) + + protected def assertEqualsKind(expected: AnyRef, actual: AnyRef, kind: String): TestResponse = { + assertEquals(expected, actual, s"Unexpected $kind!") + this + } + protected def assertEqualsKind(expected: Int, actual: Int, kind: String): TestResponse = { + assertEquals(expected, actual, s"Unexpected $kind!") + this + } + + protected def fail(message: String): Nothing + protected def assertEquals(expected: AnyRef, actual: AnyRef, message: String): Unit + protected def assertEquals(expected: Int, actual: Int, message: String): Unit + protected def assertTrue(predicate: Boolean, message: String): Unit +} \ No newline at end of file diff --git a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRoute.scala b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRoute.scala new file mode 100644 index 0000000000..68d329e396 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRoute.scala @@ -0,0 +1,13 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.testkit + +import akka.http.javadsl.model.HttpRequest +import akka.http.javadsl.server.Route + +trait TestRoute { + def underlying: Route + def run(request: HttpRequest): TestResponse +} \ No newline at end of file diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/MarshallingTestUtils.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/MarshallingTestUtils.scala new file mode 100644 index 0000000000..13416eb595 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/MarshallingTestUtils.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import scala.concurrent.duration._ +import scala.concurrent.{ ExecutionContext, Await } +import akka.http.scaladsl.unmarshalling.{ Unmarshal, FromEntityUnmarshaller } +import akka.http.scaladsl.marshalling._ +import akka.http.scaladsl.model.HttpEntity +import akka.stream.FlowMaterializer + +import scala.util.Try + +trait MarshallingTestUtils { + def marshal[T: ToEntityMarshaller](value: T)(implicit ec: ExecutionContext, mat: FlowMaterializer): HttpEntity.Strict = + Await.result(Marshal(value).to[HttpEntity].flatMap(_.toStrict(1.second)), 1.second) + + def unmarshalValue[T: FromEntityUnmarshaller](entity: HttpEntity)(implicit ec: ExecutionContext, mat: FlowMaterializer): T = + unmarshal(entity).get + + def unmarshal[T: FromEntityUnmarshaller](entity: HttpEntity)(implicit ec: ExecutionContext, mat: FlowMaterializer): Try[T] = { + val fut = Unmarshal(entity).to[T] + Await.ready(fut, 1.second) + fut.value.get + } +} + diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala new file mode 100644 index 0000000000..2ec70f1c74 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import com.typesafe.config.{ ConfigFactory, Config } +import scala.collection.immutable +import scala.concurrent.{ Await, Future } +import scala.concurrent.duration._ +import scala.util.DynamicVariable +import scala.reflect.ClassTag +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.client.RequestBuilding +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling._ +import akka.http.scaladsl.model._ +import headers.Host +import FastFuture._ + +trait RouteTest extends RequestBuilding with RouteTestResultComponent with MarshallingTestUtils { + this: TestFrameworkInterface ⇒ + + /** Override to supply a custom ActorSystem */ + protected def createActorSystem(): ActorSystem = + ActorSystem(actorSystemNameFrom(getClass), testConfig) + + def actorSystemNameFrom(clazz: Class[_]) = + clazz.getName + .replace('.', '-') + .replace('_', '-') + .filter(_ != '$') + + def testConfigSource: String = "" + def testConfig: Config = { + val source = testConfigSource + val config = if (source.isEmpty) ConfigFactory.empty() else ConfigFactory.parseString(source) + config.withFallback(ConfigFactory.load()) + } + implicit val system = createActorSystem() + implicit def executor = system.dispatcher + implicit val materializer = ActorFlowMaterializer() + + def cleanUp(): Unit = system.shutdown() + + private val dynRR = new DynamicVariable[RouteTestResult](null) + private def result = + if (dynRR.value ne null) dynRR.value + else sys.error("This value is only available inside of a `check` construct!") + + def check[T](body: ⇒ T): RouteTestResult ⇒ T = result ⇒ dynRR.withValue(result.awaitResult)(body) + + def handled: Boolean = result.handled + def response: HttpResponse = result.response + def responseEntity: HttpEntity = result.entity + def chunks: immutable.Seq[HttpEntity.ChunkStreamPart] = result.chunks + def entityAs[T: FromEntityUnmarshaller: ClassTag](implicit timeout: Duration = 1.second): T = { + def msg(e: Throwable) = s"Could not unmarshal entity to type '${implicitly[ClassTag[T]]}' for `entityAs` assertion: $e\n\nResponse was: $response" + Await.result(Unmarshal(responseEntity).to[T].fast.recover[T] { case error ⇒ failTest(msg(error)) }, timeout) + } + def responseAs[T: FromResponseUnmarshaller: ClassTag](implicit timeout: Duration = 1.second): T = { + def msg(e: Throwable) = s"Could not unmarshal response to type '${implicitly[ClassTag[T]]}' for `responseAs` assertion: $e\n\nResponse was: $response" + Await.result(Unmarshal(response).to[T].fast.recover[T] { case error ⇒ failTest(msg(error)) }, timeout) + } + def contentType: ContentType = responseEntity.contentType + def mediaType: MediaType = contentType.mediaType + def charset: HttpCharset = contentType.charset + def definedCharset: Option[HttpCharset] = contentType.definedCharset + def headers: immutable.Seq[HttpHeader] = response.headers + def header[T <: HttpHeader: ClassTag]: Option[T] = response.header[T] + def header(name: String): Option[HttpHeader] = response.headers.find(_.is(name.toLowerCase)) + def status: StatusCode = response.status + + def closingExtension: String = chunks.lastOption match { + case Some(HttpEntity.LastChunk(extension, _)) ⇒ extension + case _ ⇒ "" + } + def trailer: immutable.Seq[HttpHeader] = chunks.lastOption match { + case Some(HttpEntity.LastChunk(_, trailer)) ⇒ trailer + case _ ⇒ Nil + } + + def rejections: immutable.Seq[Rejection] = result.rejections + def rejection: Rejection = { + val r = rejections + if (r.size == 1) r.head else failTest("Expected a single rejection but got %s (%s)".format(r.size, r)) + } + + /** + * A dummy that can be used as `~> runRoute` to run the route but without blocking for the result. + * The result of the pipeline is the result that can later be checked with `check`. See the + * "separate running route from checking" example from ScalatestRouteTestSpec.scala. + */ + def runRoute: RouteTestResult ⇒ RouteTestResult = akka.http.impl.util.identityFunc + + // there is already an implicit class WithTransformation in scope (inherited from akka.http.scaladsl.testkit.TransformerPipelineSupport) + // however, this one takes precedence + implicit class WithTransformation2(request: HttpRequest) { + def ~>[A, B](f: A ⇒ B)(implicit ta: TildeArrow[A, B]): ta.Out = ta(request, f) + } + + abstract class TildeArrow[A, B] { + type Out + def apply(request: HttpRequest, f: A ⇒ B): Out + } + + case class DefaultHostInfo(host: Host, securedConnection: Boolean) + object DefaultHostInfo { + implicit def defaultHost: DefaultHostInfo = DefaultHostInfo(Host("example.com"), securedConnection = false) + } + object TildeArrow { + implicit object InjectIntoRequestTransformer extends TildeArrow[HttpRequest, HttpRequest] { + type Out = HttpRequest + def apply(request: HttpRequest, f: HttpRequest ⇒ HttpRequest) = f(request) + } + implicit def injectIntoRoute(implicit timeout: RouteTestTimeout, setup: RoutingSetup, + defaultHostInfo: DefaultHostInfo) = + new TildeArrow[RequestContext, Future[RouteResult]] { + type Out = RouteTestResult + def apply(request: HttpRequest, route: Route): Out = { + val routeTestResult = new RouteTestResult(timeout.duration) + val effectiveRequest = + request.withEffectiveUri( + securedConnection = defaultHostInfo.securedConnection, + defaultHostHeader = defaultHostInfo.host) + val ctx = new RequestContextImpl(effectiveRequest, setup.routingLog.requestLog(effectiveRequest), setup.settings) + val sealedExceptionHandler = setup.exceptionHandler.seal(setup.settings) + val semiSealedRoute = // sealed for exceptions but not for rejections + Directives.handleExceptions(sealedExceptionHandler) { route } + val deferrableRouteResult = semiSealedRoute(ctx) + deferrableRouteResult.fast.foreach(routeTestResult.handleResult)(setup.executor) + routeTestResult + } + } + } +} + +//FIXME: trait Specs2RouteTest extends RouteTest with Specs2Interface diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestResultComponent.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestResultComponent.scala new file mode 100644 index 0000000000..af58f64294 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestResultComponent.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import java.util.concurrent.CountDownLatch +import scala.collection.immutable +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext +import akka.stream.FlowMaterializer +import akka.stream.scaladsl._ +import akka.http.scaladsl.model.HttpEntity.ChunkStreamPart +import akka.http.scaladsl.server._ +import akka.http.scaladsl.model._ +import akka.http.impl.util._ + +trait RouteTestResultComponent { + + def failTest(msg: String): Nothing + + /** + * A receptacle for the response or rejections created by a route. + */ + class RouteTestResult(timeout: FiniteDuration)(implicit fm: FlowMaterializer) { + private[this] var result: Option[Either[immutable.Seq[Rejection], HttpResponse]] = None + private[this] val latch = new CountDownLatch(1) + + def handled: Boolean = synchronized { result.isDefined && result.get.isRight } + + def rejections: immutable.Seq[Rejection] = synchronized { + result match { + case Some(Left(rejections)) ⇒ rejections + case Some(Right(response)) ⇒ failTest("Request was not rejected, response was " + response) + case None ⇒ failNeitherCompletedNorRejected() + } + } + + def response: HttpResponse = rawResponse.copy(entity = entity) + + /** Returns a "fresh" entity with a "fresh" unconsumed byte- or chunk stream (if not strict) */ + def entity: ResponseEntity = entityRecreator() + + def chunks: immutable.Seq[ChunkStreamPart] = + entity match { + case HttpEntity.Chunked(_, chunks) ⇒ awaitAllElements[ChunkStreamPart](chunks) + case _ ⇒ Nil + } + + def ~>[T](f: RouteTestResult ⇒ T): T = f(this) + + private def rawResponse: HttpResponse = synchronized { + result match { + case Some(Right(response)) ⇒ response + case Some(Left(Nil)) ⇒ failTest("Request was rejected") + case Some(Left(rejection :: Nil)) ⇒ failTest("Request was rejected with rejection " + rejection) + case Some(Left(rejections)) ⇒ failTest("Request was rejected with rejections " + rejections) + case None ⇒ failNeitherCompletedNorRejected() + } + } + + private[testkit] def handleResult(rr: RouteResult)(implicit ec: ExecutionContext): Unit = + synchronized { + if (result.isEmpty) { + result = rr match { + case RouteResult.Complete(response) ⇒ Some(Right(response)) + case RouteResult.Rejected(rejections) ⇒ Some(Left(RejectionHandler.applyTransformations(rejections))) + } + latch.countDown() + } else failTest("Route completed/rejected more than once") + } + + private[testkit] def awaitResult: this.type = { + latch.await(timeout.toMillis, MILLISECONDS) + this + } + + private[this] lazy val entityRecreator: () ⇒ ResponseEntity = + rawResponse.entity match { + case s: HttpEntity.Strict ⇒ () ⇒ s + + case HttpEntity.Default(contentType, contentLength, data) ⇒ + val dataChunks = awaitAllElements(data); + { () ⇒ HttpEntity.Default(contentType, contentLength, Source(dataChunks)) } + + case HttpEntity.CloseDelimited(contentType, data) ⇒ + val dataChunks = awaitAllElements(data); + { () ⇒ HttpEntity.CloseDelimited(contentType, Source(dataChunks)) } + + case HttpEntity.Chunked(contentType, chunks) ⇒ + val dataChunks = awaitAllElements(chunks); + { () ⇒ HttpEntity.Chunked(contentType, Source(dataChunks)) } + } + + private def failNeitherCompletedNorRejected(): Nothing = + failTest("Request was neither completed nor rejected within " + timeout) + + private def awaitAllElements[T](data: Source[T, _]): immutable.Seq[T] = + data.grouped(100000).runWith(Sink.head).awaitResult(timeout) + } +} \ No newline at end of file diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestTimeout.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestTimeout.scala new file mode 100644 index 0000000000..3773c30b32 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTestTimeout.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import scala.concurrent.duration._ +import akka.actor.ActorSystem +import akka.testkit._ + +case class RouteTestTimeout(duration: FiniteDuration) + +object RouteTestTimeout { + implicit def default(implicit system: ActorSystem) = RouteTestTimeout(1.second dilated) +} diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/ScalatestUtils.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/ScalatestUtils.scala new file mode 100644 index 0000000000..04c583849a --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/ScalatestUtils.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import scala.util.Try +import scala.concurrent.{ ExecutionContext, Future, Await } +import scala.concurrent.duration._ +import org.scalatest.Suite +import org.scalatest.matchers.Matcher +import akka.stream.FlowMaterializer +import akka.http.scaladsl.model.HttpEntity +import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller + +trait ScalatestUtils extends MarshallingTestUtils { + import org.scalatest.Matchers._ + def evaluateTo[T](value: T): Matcher[Future[T]] = + equal(value).matcher[T] compose (x ⇒ Await.result(x, 1.second)) + + def haveFailedWith(t: Throwable): Matcher[Future[_]] = + equal(t).matcher[Throwable] compose (x ⇒ Await.result(x.failed, 1.second)) + + def unmarshalToValue[T: FromEntityUnmarshaller](value: T)(implicit ec: ExecutionContext, mat: FlowMaterializer): Matcher[HttpEntity] = + equal(value).matcher[T] compose (unmarshalValue(_)) + + def unmarshalTo[T: FromEntityUnmarshaller](value: Try[T])(implicit ec: ExecutionContext, mat: FlowMaterializer): Matcher[HttpEntity] = + equal(value).matcher[Try[T]] compose (unmarshal(_)) +} + +trait ScalatestRouteTest extends RouteTest with TestFrameworkInterface.Scalatest with ScalatestUtils { this: Suite ⇒ } \ No newline at end of file diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/TestFrameworkInterface.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/TestFrameworkInterface.scala new file mode 100644 index 0000000000..78ed7ba88d --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/TestFrameworkInterface.scala @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import org.scalatest.exceptions.TestFailedException +import org.scalatest.{ BeforeAndAfterAll, Suite } + +//# source-quote +trait TestFrameworkInterface { + + def cleanUp() + + def failTest(msg: String): Nothing +} +//# + +object TestFrameworkInterface { + + trait Scalatest extends TestFrameworkInterface with BeforeAndAfterAll { + this: Suite ⇒ + + def failTest(msg: String) = throw new TestFailedException(msg, 11) + + abstract override protected def afterAll(): Unit = { + cleanUp() + super.afterAll() + } + } + +} diff --git a/akka-http-testkit/src/test/scala/akka/http/scaladsl/testkit/ScalatestRouteTestSpec.scala b/akka-http-testkit/src/test/scala/akka/http/scaladsl/testkit/ScalatestRouteTestSpec.scala new file mode 100644 index 0000000000..cb302186f4 --- /dev/null +++ b/akka-http-testkit/src/test/scala/akka/http/scaladsl/testkit/ScalatestRouteTestSpec.scala @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import scala.concurrent.duration._ +import org.scalatest.FreeSpec +import org.scalatest.Matchers +import akka.testkit.TestProbe +import akka.util.Timeout +import akka.pattern.ask +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.server._ +import akka.http.scaladsl.model._ +import StatusCodes._ +import HttpMethods._ +import Directives._ + +class ScalatestRouteTestSpec extends FreeSpec with Matchers with ScalatestRouteTest { + + "The ScalatestRouteTest should support" - { + + "the most simple and direct route test" in { + Get() ~> complete(HttpResponse()) ~> { rr ⇒ rr.awaitResult; rr.response } shouldEqual HttpResponse() + } + + "a test using a directive and some checks" in { + val pinkHeader = RawHeader("Fancy", "pink") + Get() ~> addHeader(pinkHeader) ~> { + respondWithHeader(pinkHeader) { + complete("abc") + } + } ~> check { + status shouldEqual OK + responseEntity shouldEqual HttpEntity(ContentTypes.`text/plain(UTF-8)`, "abc") + header("Fancy") shouldEqual Some(pinkHeader) + } + } + + "proper rejection collection" in { + Post("/abc", "content") ~> { + (get | put) { + complete("naah") + } + } ~> check { + rejections shouldEqual List(MethodRejection(GET), MethodRejection(PUT)) + } + } + + "separation of route execution from checking" in { + val pinkHeader = RawHeader("Fancy", "pink") + + case object Command + val service = TestProbe() + val handler = TestProbe() + implicit def serviceRef = service.ref + implicit val askTimeout: Timeout = 1.second + + val result = + Get() ~> pinkHeader ~> { + respondWithHeader(pinkHeader) { + complete(handler.ref.ask(Command).mapTo[String]) + } + } ~> runRoute + + handler.expectMsg(Command) + handler.reply("abc") + + check { + status shouldEqual OK + responseEntity shouldEqual HttpEntity(ContentTypes.`text/plain(UTF-8)`, "abc") + header("Fancy") shouldEqual Some(pinkHeader) + }(result) + } + } + + // TODO: remove once RespondWithDirectives have been ported + def respondWithHeader(responseHeader: HttpHeader): Directive0 = + mapResponseHeaders(responseHeader +: _) +} diff --git a/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/Pet.java b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/Pet.java new file mode 100644 index 0000000000..0c6cfe1adb --- /dev/null +++ b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/Pet.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.petstore; + +public class Pet { + private int id; + private String name; + + private Pet(){} + public Pet(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + public void setId(int id) { + this.id = id; + } + + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } +} diff --git a/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreController.java b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreController.java new file mode 100644 index 0000000000..a972413f27 --- /dev/null +++ b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreController.java @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.petstore; + +import akka.http.javadsl.server.RequestContext; +import akka.http.javadsl.server.RouteResult; + +import java.util.Map; + +public class PetStoreController { + private Map dataStore; + + public PetStoreController(Map dataStore) { + this.dataStore = dataStore; + } + public RouteResult deletePet(RequestContext ctx, int petId) { + dataStore.remove(petId); + return ctx.completeWithStatus(200); + } +} diff --git a/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreExample.java b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreExample.java new file mode 100644 index 0000000000..ce465679f7 --- /dev/null +++ b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/petstore/PetStoreExample.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.petstore; + +import akka.actor.ActorSystem; +import akka.http.javadsl.marshallers.jackson.Jackson; +import akka.http.javadsl.server.*; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static akka.http.javadsl.server.Directives.*; + +public class PetStoreExample { + static PathMatcher petId = PathMatchers.integerNumber(); + static RequestVal petEntity = RequestVals.entityAs(Jackson.jsonAs(Pet.class)); + + public static Route appRoute(final Map pets) { + PetStoreController controller = new PetStoreController(pets); + + final RequestVal existingPet = RequestVals.lookupInMap(petId, Pet.class, pets); + + Handler1 putPetHandler = new Handler1() { + public RouteResult handle(RequestContext ctx, Pet thePet) { + pets.put(thePet.getId(), thePet); + return ctx.completeAs(Jackson.json(), thePet); + } + }; + + return + route( + path().route( + getFromResource("web/index.html") + ), + path("pet", petId).route( + // demonstrates three different ways of handling requests: + + // 1. using a predefined route that completes with an extraction + get(extractAndComplete(Jackson.json(), existingPet)), + + // 2. using a handler + put(handleWith(petEntity, putPetHandler)), + + // 3. calling a method of a controller instance reflectively + delete(handleWith(controller, "deletePet", petId)) + ) + ); + } + + public static void main(String[] args) throws IOException { + Map pets = new ConcurrentHashMap(); + Pet dog = new Pet(0, "dog"); + Pet cat = new Pet(1, "cat"); + pets.put(0, dog); + pets.put(1, cat); + + ActorSystem system = ActorSystem.create(); + try { + HttpService.bindRoute("localhost", 8080, appRoute(pets), system); + System.out.println("Type RETURN to exit"); + System.in.read(); + } finally { + system.shutdown(); + } + } +} \ No newline at end of file diff --git a/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/simple/SimpleServerApp.java b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/simple/SimpleServerApp.java new file mode 100644 index 0000000000..df2ee240e2 --- /dev/null +++ b/akka-http-tests/src/main/java/akka/http/javadsl/server/examples/simple/SimpleServerApp.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.simple; + +import akka.actor.ActorSystem; +import akka.http.javadsl.server.*; + +import java.io.IOException; + +public class SimpleServerApp extends HttpApp { + static Parameter x = Parameters.integer("x"); + static Parameter y = Parameters.integer("y"); + + static PathMatcher xSegment = PathMatchers.integerNumber(); + static PathMatcher ySegment = PathMatchers.integerNumber(); + + public static RouteResult multiply(RequestContext ctx, int x, int y) { + int result = x * y; + return ctx.complete(String.format("%d * %d = %d", x, y, result)); + } + + @Override + public Route createRoute() { + Handler addHandler = new Handler() { + @Override + public RouteResult handle(RequestContext ctx) { + int xVal = x.get(ctx); + int yVal = y.get(ctx); + int result = xVal + yVal; + return ctx.complete(String.format("%d + %d = %d", xVal, yVal, result)); + } + }; + Handler2 subtractHandler = new Handler2() { + public RouteResult handle(RequestContext ctx, Integer xVal, Integer yVal) { + int result = xVal - yVal; + return ctx.complete(String.format("%d - %d = %d", xVal, yVal, result)); + } + }; + return + route( + // matches the empty path + pathSingleSlash().route( + getFromResource("web/calculator.html") + ), + // matches paths like this: /add?x=42&y=23 + path("add").route( + handleWith(addHandler, x, y) + ), + path("subtract").route( + handleWith(x, y, subtractHandler) + ), + // matches paths like this: /multiply/{x}/{y} + path("multiply", xSegment, ySegment).route( + // bind handler by reflection + handleWith(SimpleServerApp.class, "multiply", xSegment, ySegment) + ) + ); + } + + public static void main(String[] args) throws IOException { + ActorSystem system = ActorSystem.create(); + new SimpleServerApp().bindRoute("localhost", 8080, system); + System.out.println("Type RETURN to exit"); + System.in.read(); + system.shutdown(); + } +} \ No newline at end of file diff --git a/akka-http-tests/src/main/resources/web/calculator.html b/akka-http-tests/src/main/resources/web/calculator.html new file mode 100644 index 0000000000..a32b054287 --- /dev/null +++ b/akka-http-tests/src/main/resources/web/calculator.html @@ -0,0 +1,23 @@ + + +

Calculator

+ +

Add

+
+ + + +
+ +

Subtract

+
+ + + +
+ +

Multiply

+/multiply/42/23 + + + \ No newline at end of file diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/AuthenticationDirectivesTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/AuthenticationDirectivesTest.java new file mode 100644 index 0000000000..ae3e0e5b4b --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/AuthenticationDirectivesTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import org.junit.Test; +import scala.Option; +import scala.concurrent.Future; +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.model.headers.Authorization; +import akka.http.javadsl.testkit.*; +import static akka.http.javadsl.server.Directives.*; + +public class AuthenticationDirectivesTest extends JUnitRouteTest { + HttpBasicAuthenticator authenticatedUser = + new HttpBasicAuthenticator("test-realm") { + @Override + public Future> authenticate(BasicUserCredentials credentials) { + if (credentials.available() && // no anonymous access + credentials.userName().equals("sina") && + credentials.verifySecret("1234")) + return authenticateAs("Sina"); + else return refuseAccess(); + } + }; + + Handler1 helloWorldHandler = + new Handler1() { + @Override + public RouteResult handle(RequestContext ctx, String user) { + return ctx.complete("Hello "+user+"!"); + } + }; + + TestRoute route = + testRoute( + path("secure").route( + authenticatedUser.route( + handleWith(authenticatedUser, helloWorldHandler) + ) + ) + ); + + @Test + public void testCorrectUser() { + HttpRequest authenticatedRequest = + HttpRequest.GET("/secure") + .addHeader(Authorization.basic("sina", "1234")); + + route.run(authenticatedRequest) + .assertStatusCode(200) + .assertEntity("Hello Sina!"); + } + @Test + public void testRejectAnonymousAccess() { + route.run(HttpRequest.GET("/secure")) + .assertStatusCode(401) + .assertEntity("The resource requires authentication, which was not supplied with the request") + .assertHeaderExists("WWW-Authenticate", "Basic realm=\"test-realm\""); + } + @Test + public void testRejectUnknownUser() { + HttpRequest authenticatedRequest = + HttpRequest.GET("/secure") + .addHeader(Authorization.basic("joe", "0000")); + + route.run(authenticatedRequest) + .assertStatusCode(401) + .assertEntity("The supplied authentication is invalid"); + } + @Test + public void testRejectWrongPassword() { + HttpRequest authenticatedRequest = + HttpRequest.GET("/secure") + .addHeader(Authorization.basic("sina", "1235")); + + route.run(authenticatedRequest) + .assertStatusCode(401) + .assertEntity("The supplied authentication is invalid"); + } +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/CodingDirectivesTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/CodingDirectivesTest.java new file mode 100644 index 0000000000..d21a6f367b --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/CodingDirectivesTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import static akka.http.javadsl.server.Directives.*; + +import akka.actor.ActorSystem; +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.model.headers.AcceptEncoding; +import akka.http.javadsl.model.headers.ContentEncoding; +import akka.http.javadsl.model.headers.HttpEncodings; +import akka.stream.ActorFlowMaterializer; +import akka.util.ByteString; +import org.junit.*; +import scala.concurrent.Await; +import scala.concurrent.duration.Duration; +import akka.http.javadsl.testkit.*; +import java.util.concurrent.TimeUnit; + +public class CodingDirectivesTest extends JUnitRouteTest { + + static ActorSystem system; + + @BeforeClass + public static void setup() { + system = ActorSystem.create("FlowGraphDocTest"); + } + + @AfterClass + public static void tearDown() { + system.shutdown(); + system.awaitTermination(); + system = null; + } + + final ActorFlowMaterializer mat = ActorFlowMaterializer.create(system); + + @Test + public void testAutomaticEncodingWhenNoEncodingRequested() throws Exception { + TestRoute route = + testRoute( + encodeResponse( + complete("TestString") + ) + ); + + TestResponse response = route.run(HttpRequest.create()); + response + .assertStatusCode(200); + + Assert.assertEquals("TestString", response.entityBytes().utf8String()); + } + @Test + public void testAutomaticEncodingWhenDeflateRequested() throws Exception { + TestRoute route = + testRoute( + encodeResponse( + complete("tester") + ) + ); + + HttpRequest request = HttpRequest.create().addHeader(AcceptEncoding.create(HttpEncodings.DEFLATE)); + TestResponse response = route.run(request); + response + .assertStatusCode(200) + .assertHeaderExists(ContentEncoding.create(HttpEncodings.DEFLATE)); + + ByteString decompressed = + Await.result(Coder.Deflate.decode(response.entityBytes(), mat), Duration.apply(3, TimeUnit.SECONDS)); + Assert.assertEquals("tester", decompressed.utf8String()); + } + @Test + public void testEncodingWhenDeflateRequestedAndGzipSupported() { + TestRoute route = + testRoute( + encodeResponse(Coder.Gzip).route( + complete("tester") + ) + ); + + HttpRequest request = HttpRequest.create().addHeader(AcceptEncoding.create(HttpEncodings.DEFLATE)); + route.run(request) + .assertStatusCode(406) + .assertEntity("Resource representation is only available with these Content-Encodings:\ngzip"); + } + + @Test + public void testAutomaticDecoding() {} + @Test + public void testGzipDecoding() {} +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/CompleteTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/CompleteTest.java new file mode 100644 index 0000000000..2d410a958d --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/CompleteTest.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import akka.dispatch.Futures; +import akka.http.javadsl.marshallers.jackson.Jackson; +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.model.MediaTypes; +import org.junit.Test; +import akka.http.javadsl.testkit.*; +import java.util.concurrent.Callable; + +import static akka.http.javadsl.server.Directives.*; + +public class CompleteTest extends JUnitRouteTest { + @Test + public void completeWithString() { + Route route = complete("Everything OK!"); + + HttpRequest request = HttpRequest.create(); + + runRoute(route, request) + .assertStatusCode(200) + .assertMediaType(MediaTypes.TEXT_PLAIN) + .assertEntity("Everything OK!"); + } + + @Test + public void completeAsJacksonJson() { + class Person { + public String getFirstName() { return "Peter"; } + public String getLastName() { return "Parker"; } + public int getAge() { return 138; } + } + Route route = completeAs(Jackson.json(), new Person()); + + HttpRequest request = HttpRequest.create(); + + runRoute(route, request) + .assertStatusCode(200) + .assertMediaType("application/json") + .assertEntity("{\"age\":138,\"firstName\":\"Peter\",\"lastName\":\"Parker\"}"); + } + @Test + public void completeWithFuture() { + Parameter x = Parameters.integer("x"); + Parameter y = Parameters.integer("y"); + + Handler2 slowCalc = new Handler2() { + @Override + public RouteResult handle(final RequestContext ctx, final Integer x, final Integer y) { + return ctx.completeWith(Futures.future(new Callable() { + @Override + public RouteResult call() throws Exception { + int result = x + y; + return ctx.complete(String.format("%d + %d = %d",x, y, result)); + } + }, executionContext())); + } + }; + + Route route = handleWith(x, y, slowCalc); + runRoute(route, HttpRequest.GET("add?x=42&y=23")) + .assertStatusCode(200) + .assertEntity("42 + 23 = 65"); + } +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/HandlerBindingTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/HandlerBindingTest.java new file mode 100644 index 0000000000..508ef336cf --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/HandlerBindingTest.java @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import akka.http.scaladsl.model.HttpRequest; +import org.junit.Test; +import akka.http.javadsl.testkit.*; +import static akka.http.javadsl.server.Directives.*; + +public class HandlerBindingTest extends JUnitRouteTest { + @Test + public void testHandlerWithoutExtractions() { + Route route = handleWith( + new Handler() { + @Override + public RouteResult handle(RequestContext ctx) { + return ctx.complete("Ok"); + } + } + ); + runRoute(route, HttpRequest.GET("/")) + .assertEntity("Ok"); + } + @Test + public void testHandlerWithSomeExtractions() { + final Parameter a = Parameters.integer("a"); + final Parameter b = Parameters.integer("b"); + + Route route = handleWith( + new Handler() { + @Override + public RouteResult handle(RequestContext ctx) { + return ctx.complete("Ok a:" + a.get(ctx) +" b:" + b.get(ctx)); + } + }, a, b + ); + runRoute(route, HttpRequest.GET("?a=23&b=42")) + .assertEntity("Ok a:23 b:42"); + } + @Test + public void testHandlerIfExtractionFails() { + final Parameter a = Parameters.integer("a"); + + Route route = handleWith( + new Handler() { + @Override + public RouteResult handle(RequestContext ctx) { + return ctx.complete("Ok " + a.get(ctx)); + } + }, a + ); + runRoute(route, HttpRequest.GET("/")) + .assertStatusCode(404) + .assertEntity("Request is missing required query parameter 'a'"); + } + @Test + public void testHandler1() { + final Parameter a = Parameters.integer("a"); + + Route route = handleWith( + a, + new Handler1() { + @Override + public RouteResult handle(RequestContext ctx, Integer a) { + return ctx.complete("Ok " + a); + } + } + ); + runRoute(route, HttpRequest.GET("?a=23")) + .assertStatusCode(200) + .assertEntity("Ok 23"); + } + @Test + public void testHandler2() { + Route route = handleWith( + Parameters.integer("a"), + Parameters.integer("b"), + new Handler2() { + @Override + public RouteResult handle(RequestContext ctx, Integer a, Integer b) { + return ctx.complete("Sum: " + (a + b)); + } + } + ); + runRoute(route, HttpRequest.GET("?a=23&b=42")) + .assertStatusCode(200) + .assertEntity("Sum: 65"); + } + @Test + public void testHandler3() { + Route route = handleWith( + Parameters.integer("a"), + Parameters.integer("b"), + Parameters.integer("c"), + new Handler3() { + @Override + public RouteResult handle(RequestContext ctx, Integer a, Integer b, Integer c) { + return ctx.complete("Sum: " + (a + b + c)); + } + } + ); + TestResponse response = runRoute(route, HttpRequest.GET("?a=23&b=42&c=30")); + response.assertStatusCode(200); + response.assertEntity("Sum: 95"); + } + @Test + public void testHandler4() { + Route route = handleWith( + Parameters.integer("a"), + Parameters.integer("b"), + Parameters.integer("c"), + Parameters.integer("d"), + new Handler4() { + @Override + public RouteResult handle(RequestContext ctx, Integer a, Integer b, Integer c, Integer d) { + return ctx.complete("Sum: " + (a + b + c + d)); + } + } + ); + runRoute(route, HttpRequest.GET("?a=23&b=42&c=30&d=45")) + .assertStatusCode(200) + .assertEntity("Sum: 140"); + } + @Test + public void testReflectiveInstanceHandler() { + class Test { + public RouteResult negate(RequestContext ctx, int a) { + return ctx.complete("Negated: " + (- a)); + } + } + Route route = handleWith(new Test(), "negate", Parameters.integer("a")); + runRoute(route, HttpRequest.GET("?a=23")) + .assertStatusCode(200) + .assertEntity("Negated: -23"); + } + + public static RouteResult squared(RequestContext ctx, int a) { + return ctx.complete("Squared: " + (a * a)); + } + @Test + public void testStaticReflectiveHandler() { + Route route = handleWith(HandlerBindingTest.class, "squared", Parameters.integer("a")); + runRoute(route, HttpRequest.GET("?a=23")) + .assertStatusCode(200) + .assertEntity("Squared: 529"); + } +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/PathDirectivesTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/PathDirectivesTest.java new file mode 100644 index 0000000000..61ba4571dc --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/PathDirectivesTest.java @@ -0,0 +1,222 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import static akka.http.javadsl.server.Directives.*; +import akka.http.javadsl.testkit.*; +import akka.http.scaladsl.model.HttpRequest; +import org.junit.Test; + +import java.util.List; +import java.util.UUID; + +public class PathDirectivesTest extends JUnitRouteTest { + @Test + public void testPathPrefixAndPath() { + TestRoute route = + testRoute( + pathPrefix("pet").route( + path("cat").route(complete("The cat!")), + path("dog").route(complete("The dog!")), + pathSingleSlash().route(complete("Here are only pets.")) + ) + ); + + route.run(HttpRequest.GET("/pet/")) + .assertEntity("Here are only pets."); + + route.run(HttpRequest.GET("/pet")) // missing trailing slash + .assertStatusCode(404); + + route.run(HttpRequest.GET("/pet/cat")) + .assertEntity("The cat!"); + + route.run(HttpRequest.GET("/pet/dog")) + .assertEntity("The dog!"); + } + + @Test + public void testRawPathPrefix() { + TestRoute route1 = + testRoute( + rawPathPrefix(PathMatchers.SLASH(), "pet", PathMatchers.SLASH(), "", PathMatchers.SLASH(), "cat").route( + complete("The cat!") + ) + ); + + route1.run(HttpRequest.GET("/pet//cat")) + .assertEntity("The cat!"); + + // any suffix allowed + route1.run(HttpRequest.GET("/pet//cat/abcdefg")) + .assertEntity("The cat!"); + + TestRoute route2 = + testRoute( + rawPathPrefix(PathMatchers.SLASH(), "pet", PathMatchers.SLASH(), "", PathMatchers.SLASH(), "cat", PathMatchers.END()).route( + complete("The cat!") + ) + ); + + route2.run(HttpRequest.GET("/pet//cat")) + .assertEntity("The cat!"); + + route2.run(HttpRequest.GET("/pet//cat/abcdefg")) + .assertStatusCode(404); + } + + @Test + public void testSegment() { + PathMatcher name = PathMatchers.segment(); + + TestRoute route = + testRoute( + path("hey", name).route(toStringEcho(name)) + ); + + route.run(HttpRequest.GET("/hey/jude")) + .assertEntity("jude"); + } + + @Test + public void testSingleSlash() { + TestRoute route = + testRoute( + pathSingleSlash().route(complete("Ok")) + ); + + route.run(HttpRequest.GET("/")) + .assertEntity("Ok"); + + route.run(HttpRequest.GET("/abc")) + .assertStatusCode(404); + } + + @Test + public void testIntegerMatcher() { + PathMatcher age = PathMatchers.integerNumber(); + + TestRoute route = + testRoute( + path("age", age).route(toStringEcho(age)) + ); + + route.run(HttpRequest.GET("/age/38")) + .assertEntity("38"); + + route.run(HttpRequest.GET("/age/abc")) + .assertStatusCode(404); + } + + @Test + public void testTwoVals() { + // tests that `x` and `y` have different identities which is important for + // retrieving the values + PathMatcher x = PathMatchers.integerNumber(); + PathMatcher y = PathMatchers.integerNumber(); + + TestRoute route = + testRoute( + path("multiply", x, y).route( + handleWith(x, y, new Handler2() { + @Override + public RouteResult handle(RequestContext ctx, Integer x, Integer y) { + return ctx.complete(String.format("%d * %d = %d", x, y, x * y)); + } + }) + ) + ); + + route.run(HttpRequest.GET("/multiply/3/6")) + .assertEntity("3 * 6 = 18"); + } + + @Test + public void testHexIntegerMatcher() { + PathMatcher color = PathMatchers.hexIntegerNumber(); + + TestRoute route = + testRoute( + path("color", color).route(toStringEcho(color)) + ); + + route.run(HttpRequest.GET("/color/a0c2ef")) + .assertEntity(Integer.toString(0xa0c2ef)); + } + + @Test + public void testLongMatcher() { + PathMatcher bigAge = PathMatchers.longNumber(); + + TestRoute route = + testRoute( + path("bigage", bigAge).route(toStringEcho(bigAge)) + ); + + route.run(HttpRequest.GET("/bigage/12345678901")) + .assertEntity("12345678901"); + } + + @Test + public void testHexLongMatcher() { + PathMatcher code = PathMatchers.hexLongNumber(); + + TestRoute route = + testRoute( + path("code", code).route(toStringEcho(code)) + ); + + route.run(HttpRequest.GET("/code/a0b1c2d3e4f5")) + .assertEntity(Long.toString(0xa0b1c2d3e4f5L)); + } + + @Test + public void testRestMatcher() { + PathMatcher rest = PathMatchers.rest(); + + TestRoute route = + testRoute( + path("pets", rest).route(toStringEcho(rest)) + ); + + route.run(HttpRequest.GET("/pets/afdaoisd/asda/sfasfasf/asf")) + .assertEntity("afdaoisd/asda/sfasfasf/asf"); + } + + @Test + public void testUUIDMatcher() { + PathMatcher uuid = PathMatchers.uuid(); + + TestRoute route = + testRoute( + path("by-uuid", uuid).route(toStringEcho(uuid)) + ); + + route.run(HttpRequest.GET("/by-uuid/6ba7b811-9dad-11d1-80b4-00c04fd430c8")) + .assertEntity("6ba7b811-9dad-11d1-80b4-00c04fd430c8"); + } + + @Test + public void testSegmentsMatcher() { + PathMatcher> segments = PathMatchers.segments(); + + TestRoute route = + testRoute( + path("pets", segments).route(toStringEcho(segments)) + ); + + route.run(HttpRequest.GET("/pets/cat/dog")) + .assertEntity("[cat, dog]"); + } + + private Route toStringEcho(RequestVal value) { + return handleWith(value, new Handler1() { + @Override + public RouteResult handle(RequestContext ctx, T t) { + return ctx.complete(t.toString()); + } + }); + } +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/petstore/PetStoreAPITest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/petstore/PetStoreAPITest.java new file mode 100644 index 0000000000..15b8d4bbe9 --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/petstore/PetStoreAPITest.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.petstore; + +import akka.http.javadsl.marshallers.jackson.Jackson; +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.model.MediaTypes; +import akka.http.javadsl.testkit.*; +import static org.junit.Assert.*; + +import akka.http.javadsl.testkit.TestRoute; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +public class PetStoreAPITest extends JUnitRouteTest { + @Test + public void testGetPet() { + TestResponse response = createRoute().run(HttpRequest.GET("/pet/1")); + + response + .assertStatusCode(200) + .assertMediaType("application/json"); + + Pet pet = response.entityAs(Jackson.jsonAs(Pet.class)); + assertEquals("cat", pet.getName()); + assertEquals(1, pet.getId()); + } + @Test + public void testGetMissingPet() { + createRoute().run(HttpRequest.GET("/pet/999")) + .assertStatusCode(404); + } + @Test + public void testPutPet() { + HttpRequest request = + HttpRequest.PUT("/pet/1") + .withEntity(MediaTypes.APPLICATION_JSON.toContentType(), "{\"id\": 1, \"name\": \"giraffe\"}"); + + TestResponse response = createRoute().run(request); + + response.assertStatusCode(200); + + Pet pet = response.entityAs(Jackson.jsonAs(Pet.class)); + assertEquals("giraffe", pet.getName()); + assertEquals(1, pet.getId()); + } + @Test + public void testDeletePet() { + Map data = createData(); + + HttpRequest request = HttpRequest.DELETE("/pet/0"); + + createRoute(data).run(request) + .assertStatusCode(200); + + // test actual deletion from data store + assertFalse(data.containsKey(0)); + } + + private TestRoute createRoute() { + return createRoute(createData()); + } + private TestRoute createRoute(Map pets) { + return testRoute(PetStoreExample.appRoute(pets)); + } + private Map createData() { + Map pets = new HashMap(); + Pet dog = new Pet(0, "dog"); + Pet cat = new Pet(1, "cat"); + pets.put(0, dog); + pets.put(1, cat); + + return pets; + } +} diff --git a/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/simple/SimpleServerTest.java b/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/simple/SimpleServerTest.java new file mode 100644 index 0000000000..6c5e23fa5a --- /dev/null +++ b/akka-http-tests/src/test/java/akka/http/javadsl/server/examples/simple/SimpleServerTest.java @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.examples.simple; + +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.testkit.*; +import org.junit.Test; + +public class SimpleServerTest extends JUnitRouteTest { + TestRoute route = testRoute(new SimpleServerApp().createRoute()); + + @Test + public void testAdd() { + TestResponse response = route.run(HttpRequest.GET("/add?x=42&y=23")); + + response + .assertStatusCode(200) + .assertEntity("42 + 23 = 65"); + } +} diff --git a/akka-http-tests/src/test/resources/sample.html b/akka-http-tests/src/test/resources/sample.html new file mode 100644 index 0000000000..10dbdec8c5 --- /dev/null +++ b/akka-http-tests/src/test/resources/sample.html @@ -0,0 +1 @@ +

Lorem ipsum!

\ No newline at end of file diff --git a/akka-http-tests/src/test/resources/sample.xyz b/akka-http-tests/src/test/resources/sample.xyz new file mode 100644 index 0000000000..ce42064770 --- /dev/null +++ b/akka-http-tests/src/test/resources/sample.xyz @@ -0,0 +1 @@ +XyZ \ No newline at end of file diff --git a/akka-http-tests/src/test/resources/someDir/fileA.txt b/akka-http-tests/src/test/resources/someDir/fileA.txt new file mode 100644 index 0000000000..d800886d9c --- /dev/null +++ b/akka-http-tests/src/test/resources/someDir/fileA.txt @@ -0,0 +1 @@ +123 \ No newline at end of file diff --git a/akka-http-tests/src/test/resources/someDir/fileB.xml b/akka-http-tests/src/test/resources/someDir/fileB.xml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/akka-http-tests/src/test/resources/someDir/sub/file.html b/akka-http-tests/src/test/resources/someDir/sub/file.html new file mode 100644 index 0000000000..e69de29bb2 diff --git a/akka-http-tests/src/test/resources/subDirectory/empty.pdf b/akka-http-tests/src/test/resources/subDirectory/empty.pdf new file mode 100644 index 0000000000..d800886d9c --- /dev/null +++ b/akka-http-tests/src/test/resources/subDirectory/empty.pdf @@ -0,0 +1 @@ +123 \ No newline at end of file diff --git a/akka-http-tests/src/test/resources/subDirectory/fileA.txt b/akka-http-tests/src/test/resources/subDirectory/fileA.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/FormDataSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/FormDataSpec.scala new file mode 100644 index 0000000000..d4e569bb30 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/FormDataSpec.scala @@ -0,0 +1,42 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl + +import scala.concurrent.duration._ +import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpec } +import org.scalatest.concurrent.ScalaFutures +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.http.scaladsl.marshalling.Marshal +import akka.http.scaladsl.model._ + +class FormDataSpec extends WordSpec with Matchers with ScalaFutures with BeforeAndAfterAll { + implicit val system = ActorSystem(getClass.getSimpleName) + implicit val materializer = ActorFlowMaterializer() + import system.dispatcher + + val formData = FormData(Map("surname" -> "Smith", "age" -> "42")) + + "The FormData infrastructure" should { + "properly round-trip the fields of www-urlencoded forms" in { + Marshal(formData).to[HttpEntity] + .flatMap(Unmarshal(_).to[FormData]).futureValue shouldEqual formData + } + + "properly marshal www-urlencoded forms containing special chars" in { + Marshal(FormData(Map("name" -> "Smith&Wesson"))).to[HttpEntity] + .flatMap(Unmarshal(_).to[String]).futureValue shouldEqual "name=Smith%26Wesson" + + Marshal(FormData(Map("name" -> "Smith+Wesson; hopefully!"))).to[HttpEntity] + .flatMap(Unmarshal(_).to[String]).futureValue shouldEqual "name=Smith%2BWesson%3B+hopefully%21" + } + } + + override def afterAll() = { + system.shutdown() + system.awaitTermination(10.seconds) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CodecSpecSupport.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CodecSpecSupport.scala new file mode 100644 index 0000000000..a5c46cfe93 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CodecSpecSupport.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import scala.concurrent.duration._ +import org.scalatest.{ Suite, BeforeAndAfterAll, Matchers } +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.util.ByteString + +trait CodecSpecSupport extends Matchers with BeforeAndAfterAll { self: Suite ⇒ + + def readAs(string: String, charset: String = "UTF8") = equal(string).matcher[String] compose { (_: ByteString).decodeString(charset) } + def hexDump(bytes: ByteString) = bytes.map("%02x".format(_)).mkString + def fromHexDump(dump: String) = dump.grouped(2).toArray.map(chars ⇒ Integer.parseInt(new String(chars), 16).toByte) + + def printBytes(i: Int, id: String) = { + def byte(i: Int) = (i & 0xFF).toHexString + println(id + ": " + byte(i) + ":" + byte(i >> 8) + ":" + byte(i >> 16) + ":" + byte(i >> 24)) + i + } + + lazy val smallTextBytes = ByteString(smallText, "UTF8") + lazy val largeTextBytes = ByteString(largeText, "UTF8") + + val smallText = "Yeah!" + val largeText = + """Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore +magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd +gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing +elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos +et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor +sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et +dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd +gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. + +Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu feugiat +nulla facilisis at vero eros et accumsan et iusto odio dignissim qui blandit praesent luptatum zzril delenit augue duis +dolore te feugait nulla facilisi. Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh +euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. + +Ut wisi enim ad minim veniam, quis nostrud exerci tation ullamcorper suscipit lobortis nisl ut aliquip ex ea commodo +consequat. Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu +feugiat nulla facilisis at vero eros et accumsan et iusto odio dignissim qui blandit praesent luptatum zzril delenit +augue duis dolore te feugait nulla facilisi. + +Nam liber tempor cum soluta nobis eleifend option congue nihil imperdiet doming id quod mazim placerat facer possim +assum. Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet +dolore magna aliquam erat volutpat. Ut wisi enim ad minim veniam, quis nostrud exerci tation ullamcorper suscipit +lobortis nisl ut aliquip ex ea commodo consequat. + +Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu feugiat +nulla facilisis. + +At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem +ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt +ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. +Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, +consetetur sadipscing elitr, At accusam aliquyam diam diam dolore dolores duo eirmod eos erat, et nonumy sed tempor et +et invidunt justo labore Stet clita ea et gubergren, kasd magna no rebum. sanctus sea sed takimata ut vero voluptua. +est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor +invidunt ut labore et dolore magna aliquyam erat. + +Consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam +voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus +est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy e""".replace("\r\n", "\n") + + implicit val system = ActorSystem(getClass.getSimpleName) + implicit val materializer = ActorFlowMaterializer() + + override def afterAll() = { + system.shutdown() + system.awaitTermination(10.seconds) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CoderSpec.scala new file mode 100644 index 0000000000..32f9448bc7 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/CoderSpec.scala @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import java.io.{ OutputStream, InputStream, ByteArrayInputStream, ByteArrayOutputStream } +import java.util +import java.util.zip.DataFormatException +import scala.annotation.tailrec +import scala.concurrent.duration._ +import scala.concurrent.Await +import scala.concurrent.ExecutionContext.Implicits.global +import scala.util.control.NoStackTrace +import org.scalatest.{ Inspectors, WordSpec } +import akka.util.ByteString +import akka.stream.scaladsl.{ Sink, Source } +import akka.http.scaladsl.model.{ HttpEntity, HttpRequest } +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.impl.util._ + +abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors { + protected def Coder: Coder with StreamDecoder + protected def newDecodedInputStream(underlying: InputStream): InputStream + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream + + case object AllDataAllowed extends Exception with NoStackTrace + protected def corruptInputCheck: Boolean = true + + def extraTests(): Unit = {} + + s"The ${Coder.encoding.value} codec" should { + "properly encode a small string" in { + streamDecode(ourEncode(smallTextBytes)) should readAs(smallText) + } + "properly decode a small string" in { + ourDecode(streamEncode(smallTextBytes)) should readAs(smallText) + } + "properly round-trip encode/decode a small string" in { + ourDecode(ourEncode(smallTextBytes)) should readAs(smallText) + } + "properly encode a large string" in { + streamDecode(ourEncode(largeTextBytes)) should readAs(largeText) + } + "properly decode a large string" in { + ourDecode(streamEncode(largeTextBytes)) should readAs(largeText) + } + "properly round-trip encode/decode a large string" in { + ourDecode(ourEncode(largeTextBytes)) should readAs(largeText) + } + "properly round-trip encode/decode an HttpRequest" in { + val request = HttpRequest(POST, entity = HttpEntity(largeText)) + Coder.decode(Coder.encode(request)).toStrict(1.second).awaitResult(1.second) should equal(request) + } + + if (corruptInputCheck) { + "throw an error on corrupt input" in { + (the[RuntimeException] thrownBy { + ourDecode(corruptContent) + }).getCause should be(a[DataFormatException]) + } + } + + "not throw an error if a subsequent block is corrupt" in { + pending // FIXME: should we read as long as possible and only then report an error, that seems somewhat arbitrary + ourDecode(Seq(encode("Hello,"), encode(" dear "), corruptContent).join) should readAs("Hello, dear ") + } + "decompress in very small chunks" in { + val compressed = encode("Hello") + + decodeChunks(Source(Vector(compressed.take(10), compressed.drop(10)))) should readAs("Hello") + } + "support chunked round-trip encoding/decoding" in { + val chunks = largeTextBytes.grouped(512).toVector + val comp = Coder.newCompressor + val compressedChunks = chunks.map { chunk ⇒ comp.compressAndFlush(chunk) } :+ comp.finish() + val uncompressed = decodeFromIterator(() ⇒ compressedChunks.iterator) + + uncompressed should readAs(largeText) + } + "works for any split in prefix + suffix" in { + val compressed = streamEncode(smallTextBytes) + def tryWithPrefixOfSize(prefixSize: Int): Unit = { + val prefix = compressed.take(prefixSize) + val suffix = compressed.drop(prefixSize) + + decodeChunks(Source(prefix :: suffix :: Nil)) should readAs(smallText) + } + (0 to compressed.size).foreach(tryWithPrefixOfSize) + } + "works for chunked compressed data of sizes just above 1024" in { + val comp = Coder.newCompressor + val inputBytes = ByteString("""{"baseServiceURL":"http://www.acme.com","endpoints":{"assetSearchURL":"/search","showsURL":"/shows","mediaContainerDetailURL":"/container","featuredTapeURL":"/tape","assetDetailURL":"/asset","moviesURL":"/movies","recentlyAddedURL":"/recent","topicsURL":"/topics","scheduleURL":"/schedule"},"urls":{"aboutAweURL":"www.foobar.com"},"channelName":"Cool Stuff","networkId":"netId","slotProfile":"slot_1","brag":{"launchesUntilPrompt":10,"daysUntilPrompt":5,"launchesUntilReminder":5,"daysUntilReminder":2},"feedbackEmailAddress":"feedback@acme.com","feedbackEmailSubject":"Commends from User","splashSponsor":[],"adProvider":{"adProviderProfile":"","adProviderProfileAndroid":"","adProviderNetworkID":0,"adProviderSiteSectionNetworkID":0,"adProviderVideoAssetNetworkID":0,"adProviderSiteSectionCustomID":{},"adProviderServerURL":"","adProviderLiveVideoAssetID":""},"update":[{"forPlatform":"ios","store":{"iTunes":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"},{"forPlatform":"android","store":{"amazon":"www.something.com","play":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"}],"tvRatingPolicies":[{"type":"sometype","imageKey":"tv_rating_small","durationMS":15000,"precedence":1},{"type":"someothertype","imageKey":"tv_rating_big","durationMS":15000,"precedence":2}],"exts":{"adConfig":{"globals":{"#{adNetworkID}":"2620","#{ssid}":"usa_tveapp"},"iPad":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPadRetina":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhoneRetina":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Tablet":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"TabletHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Phone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"PhoneHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}}}}}""", "utf8") + val compressed = comp.compressAndFinish(inputBytes) + + ourDecode(compressed) should equal(inputBytes) + } + + extraTests() + + "shouldn't produce huge ByteStrings for some input" in { + val array = new Array[Byte](10) // FIXME + util.Arrays.fill(array, 1.toByte) + val compressed = streamEncode(ByteString(array)) + val limit = 10000 + val resultBs = + Source.single(compressed) + .via(Coder.withMaxBytesPerChunk(limit).decoderFlow) + .grouped(4200).runWith(Sink.head) + .awaitResult(1.second) + + forAll(resultBs) { bs ⇒ + bs.length should be < limit + bs.forall(_ == 1) should equal(true) + } + } + } + + def encode(s: String) = ourEncode(ByteString(s, "UTF8")) + def ourEncode(bytes: ByteString): ByteString = Coder.encode(bytes) + def ourDecode(bytes: ByteString): ByteString = Coder.decode(bytes).awaitResult(1.second) + + lazy val corruptContent = { + val content = encode(largeText).toArray + content(14) = 26.toByte + ByteString(content) + } + + def streamEncode(bytes: ByteString): ByteString = { + val output = new ByteArrayOutputStream() + val gos = newEncodedOutputStream(output); gos.write(bytes.toArray); gos.close() + ByteString(output.toByteArray) + } + + def streamDecode(bytes: ByteString): ByteString = { + val output = new ByteArrayOutputStream() + val input = newDecodedInputStream(new ByteArrayInputStream(bytes.toArray)) + + val buffer = new Array[Byte](500) + @tailrec def copy(from: InputStream, to: OutputStream): Unit = { + val read = from.read(buffer) + if (read >= 0) { + to.write(buffer, 0, read) + copy(from, to) + } + } + + copy(input, output) + ByteString(output.toByteArray) + } + + def decodeChunks(input: Source[ByteString, _]): ByteString = + input.via(Coder.decoderFlow).join.awaitResult(3.seconds) + + def decodeFromIterator(iterator: () ⇒ Iterator[ByteString]): ByteString = + Await.result(Source(iterator).via(Coder.decoderFlow).join, 3.seconds) +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala new file mode 100644 index 0000000000..8b3dd3d102 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DecoderSpec.scala @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import scala.concurrent.duration._ +import org.scalatest.WordSpec +import akka.util.ByteString +import akka.stream.stage.{ SyncDirective, Context, PushStage, Stage } +import akka.http.scaladsl.model._ +import akka.http.impl.util._ +import headers._ +import HttpMethods.POST + +class DecoderSpec extends WordSpec with CodecSpecSupport { + + "A Decoder" should { + "not transform the message if it doesn't contain a Content-Encoding header" in { + val request = HttpRequest(POST, entity = HttpEntity(smallText)) + DummyDecoder.decode(request) shouldEqual request + } + "correctly transform the message if it contains a Content-Encoding header" in { + val request = HttpRequest(POST, entity = HttpEntity(smallText), headers = List(`Content-Encoding`(DummyDecoder.encoding))) + val decoded = DummyDecoder.decode(request) + decoded.headers shouldEqual Nil + decoded.entity.toStrict(1.second).awaitResult(1.second) shouldEqual HttpEntity(dummyDecompress(smallText)) + } + } + + def dummyDecompress(s: String): String = dummyDecompress(ByteString(s, "UTF8")).decodeString("UTF8") + def dummyDecompress(bytes: ByteString): ByteString = DummyDecoder.decode(bytes).awaitResult(1.second) + + case object DummyDecoder extends StreamDecoder { + val encoding = HttpEncodings.compress + + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + () ⇒ new PushStage[ByteString, ByteString] { + def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = + ctx.push(elem ++ ByteString("compressed")) + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DeflateSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DeflateSpec.scala new file mode 100644 index 0000000000..798b9320c0 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/DeflateSpec.scala @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.util.ByteString + +import java.io.{ InputStream, OutputStream } +import java.util.zip._ + +class DeflateSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = Deflate + + protected def newDecodedInputStream(underlying: InputStream): InputStream = + new InflaterInputStream(underlying) + + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = + new DeflaterOutputStream(underlying) + + override def extraTests(): Unit = { + "throw early if header is corrupt" in { + (the[RuntimeException] thrownBy { + ourDecode(ByteString(0, 1, 2, 3, 4)) + }).getCause should be(a[DataFormatException]) + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/EncoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/EncoderSpec.scala new file mode 100644 index 0000000000..bcbd4bd5d2 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/EncoderSpec.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.util.ByteString +import org.scalatest.WordSpec +import akka.http.scaladsl.model._ +import headers._ +import HttpMethods.POST +import scala.concurrent.duration._ +import akka.http.impl.util._ + +class EncoderSpec extends WordSpec with CodecSpecSupport { + + "An Encoder" should { + "not transform the message if messageFilter returns false" in { + val request = HttpRequest(POST, entity = HttpEntity(smallText.getBytes("UTF8"))) + DummyEncoder.encode(request) shouldEqual request + } + "correctly transform the HttpMessage if messageFilter returns true" in { + val request = HttpRequest(POST, entity = HttpEntity(smallText)) + val encoded = DummyEncoder.encode(request) + encoded.headers shouldEqual List(`Content-Encoding`(DummyEncoder.encoding)) + encoded.entity.toStrict(1.second).awaitResult(1.second) shouldEqual HttpEntity(dummyCompress(smallText)) + } + } + + def dummyCompress(s: String): String = dummyCompress(ByteString(s, "UTF8")).utf8String + def dummyCompress(bytes: ByteString): ByteString = DummyCompressor.compressAndFinish(bytes) + + case object DummyEncoder extends Encoder { + val messageFilter = Encoder.DefaultFilter + val encoding = HttpEncodings.compress + def newCompressor = DummyCompressor + } + + case object DummyCompressor extends Compressor { + def compress(input: ByteString) = input ++ ByteString("compressed") + def flush() = ByteString.empty + def finish() = ByteString.empty + + def compressAndFlush(input: ByteString): ByteString = compress(input) + def compressAndFinish(input: ByteString): ByteString = compress(input) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala new file mode 100644 index 0000000000..f95dad83d9 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/GzipSpec.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.http.impl.util._ + +import java.io.{ InputStream, OutputStream } +import java.util.zip.{ ZipException, GZIPInputStream, GZIPOutputStream } + +import akka.util.ByteString + +class GzipSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = Gzip + + protected def newDecodedInputStream(underlying: InputStream): InputStream = + new GZIPInputStream(underlying) + + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = + new GZIPOutputStream(underlying) + + override def extraTests(): Unit = { + "decode concatenated compressions" in { + pending // FIXME: unbreak + ourDecode(Seq(encode("Hello, "), encode("dear "), encode("User!")).join) should readAs("Hello, dear User!") + } + "provide a better compression ratio than the standard Gzip/Gunzip streams" in { + ourEncode(largeTextBytes).length should be < streamEncode(largeTextBytes).length + } + "throw an error on truncated input" in { + pending // FIXME: unbreak + val ex = the[RuntimeException] thrownBy ourDecode(streamEncode(smallTextBytes).dropRight(5)) + ex.getCause.getMessage should equal("Truncated GZIP stream") + } + "throw an error if compressed data is just missing the trailer at the end" in { + def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8")) + + val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl")) + ex.getCause.getMessage should equal("Truncated GZIP stream") + } + "throw early if header is corrupt" in { + val cause = (the[RuntimeException] thrownBy ourDecode(ByteString(0, 1, 2, 3, 4))).getCause + cause should (be(a[ZipException]) and have message "Not in GZIP format") + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/NoCodingSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/NoCodingSpec.scala new file mode 100644 index 0000000000..e1671a5410 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/coding/NoCodingSpec.scala @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import java.io.{ OutputStream, InputStream } + +class NoCodingSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = NoCoding + + override protected def corruptInputCheck = false + + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = underlying + protected def newDecodedInputStream(underlying: InputStream): InputStream = underlying +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/JsonSupportSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/JsonSupportSpec.scala new file mode 100644 index 0000000000..fef3487c72 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/JsonSupportSpec.scala @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshallers + +import akka.http.scaladsl.marshalling.ToEntityMarshaller +import akka.http.scaladsl.model.{ HttpCharsets, HttpEntity, MediaTypes } +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller +import akka.http.impl.util._ +import org.scalatest.{ Matchers, WordSpec } + +case class Employee(fname: String, name: String, age: Int, id: Long, boardMember: Boolean) { + require(!boardMember || age > 40, "Board members must be older than 40") +} + +object Employee { + val simple = Employee("Frank", "Smith", 42, 12345, false) + val json = """{"fname":"Frank","name":"Smith","age":42,"id":12345,"boardMember":false}""" + + val utf8 = Employee("Fränk", "Smi√", 42, 12345, false) + val utf8json = + """{ + | "fname": "Fränk", + | "name": "Smi√", + | "age": 42, + | "id": 12345, + | "boardMember": false + |}""".stripMargin.getBytes(HttpCharsets.`UTF-8`.nioCharset) + + val illegalEmployeeJson = """{"fname":"Little Boy","name":"Smith","age":7,"id":12345,"boardMember":true}""" +} + +/** Common infrastructure needed for several json support subprojects */ +abstract class JsonSupportSpec extends WordSpec with Matchers with ScalatestRouteTest { + require(getClass.getSimpleName.endsWith("Spec")) + // assuming that the classname ends with "Spec" + def name: String = getClass.getSimpleName.dropRight(4) + implicit def marshaller: ToEntityMarshaller[Employee] + implicit def unmarshaller: FromEntityUnmarshaller[Employee] + + "The " + name should { + "provide unmarshalling support for a case class" in { + HttpEntity(MediaTypes.`application/json`, Employee.json) should unmarshalToValue(Employee.simple) + } + "provide marshalling support for a case class" in { + val marshalled = marshal(Employee.simple) + + marshalled.data.utf8String shouldEqual + """{ + | "age": 42, + | "boardMember": false, + | "fname": "Frank", + | "id": 12345, + | "name": "Smith" + |}""".stripMarginWithNewline("\n") + } + "use UTF-8 as the default charset for JSON source decoding" in { + HttpEntity(MediaTypes.`application/json`, Employee.utf8json) should unmarshalToValue(Employee.utf8) + } + "provide proper error messages for requirement errors" in { + val result = unmarshal(HttpEntity(MediaTypes.`application/json`, Employee.illegalEmployeeJson)) + + result.isFailure shouldEqual true + val ex = result.failed.get + ex.getMessage shouldEqual "requirement failed: Board members must be older than 40" + } + } +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/sprayjson/SprayJsonSupportSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/sprayjson/SprayJsonSupportSpec.scala new file mode 100644 index 0000000000..6a98de2252 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/sprayjson/SprayJsonSupportSpec.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshallers.sprayjson + +import java.lang.StringBuilder + +import akka.http.scaladsl.marshallers.{ JsonSupportSpec, Employee } +import akka.http.scaladsl.marshalling.ToEntityMarshaller +import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller +import spray.json.{ JsValue, PrettyPrinter, JsonPrinter, DefaultJsonProtocol } + +import scala.collection.immutable.ListMap + +class SprayJsonSupportSpec extends JsonSupportSpec { + object EmployeeJsonProtocol extends DefaultJsonProtocol { + implicit val employeeFormat = jsonFormat5(Employee.apply) + } + import EmployeeJsonProtocol._ + + implicit val orderedFieldPrint: JsonPrinter = new PrettyPrinter { + override protected def printObject(members: Map[String, JsValue], sb: StringBuilder, indent: Int): Unit = + super.printObject(ListMap(members.toSeq.sortBy(_._1): _*), sb, indent) + } + + implicit def marshaller: ToEntityMarshaller[Employee] = SprayJsonSupport.sprayJsonMarshaller[Employee] + implicit def unmarshaller: FromEntityUnmarshaller[Employee] = SprayJsonSupport.sprayJsonUnmarshaller[Employee] +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala new file mode 100644 index 0000000000..8592390e76 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshallers.xml + +import scala.xml.NodeSeq +import org.scalatest.{ Matchers, WordSpec } +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.http.scaladsl.unmarshalling.{ Unmarshaller, Unmarshal } +import akka.http.scaladsl.model._ +import HttpCharsets._ +import MediaTypes._ + +class ScalaXmlSupportSpec extends WordSpec with Matchers with ScalatestRouteTest { + import ScalaXmlSupport._ + + "ScalaXmlSupport" should { + "NodeSeqMarshaller should marshal xml snippets to `text/xml` content in UTF-8" in { + marshal(Ha“llo) shouldEqual + HttpEntity(ContentType(`text/xml`, `UTF-8`), "Ha“llo") + } + "nodeSeqUnmarshaller should unmarshal `text/xml` content in UTF-8 to NodeSeqs" in { + Unmarshal(HttpEntity(`text/xml`, "Hällö")).to[NodeSeq].map(_.text) should evaluateTo("Hällö") + } + "nodeSeqUnmarshaller should reject `application/octet-stream`" in { + Unmarshal(HttpEntity(`application/octet-stream`, "Hällö")).to[NodeSeq].map(_.text) should + haveFailedWith(Unmarshaller.UnsupportedContentTypeException(nodeSeqContentTypeRanges: _*)) + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/ContentNegotiationSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/ContentNegotiationSpec.scala new file mode 100644 index 0000000000..885b1d81af --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/ContentNegotiationSpec.scala @@ -0,0 +1,138 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.Await +import scala.concurrent.duration._ +import org.scalatest.{ Matchers, FreeSpec } +import akka.http.scaladsl.util.FastFuture._ +import akka.http.scaladsl.model._ +import MediaTypes._ +import HttpCharsets._ + +class ContentNegotiationSpec extends FreeSpec with Matchers { + + "Content Negotiation should work properly for requests with header(s)" - { + + "(without headers)" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-16`) should select(`text/plain`, `UTF-16`) + } + + "Accept: */*" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-16`) should select(`text/plain`, `UTF-16`) + } + + "Accept: */*;q=.8" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-16`) should select(`text/plain`, `UTF-16`) + } + + "Accept: text/*" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/xml` withCharset `UTF-16`) should select(`text/xml`, `UTF-16`) + accept(`audio/ogg`) should reject + } + + "Accept: text/*;q=.8" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/xml` withCharset `UTF-16`) should select(`text/xml`, `UTF-16`) + accept(`audio/ogg`) should reject + } + + "Accept: text/*;q=0" in test { accept ⇒ + accept(`text/plain`) should reject + accept(`text/xml` withCharset `UTF-16`) should reject + accept(`audio/ogg`) should reject + } + + "Accept-Charset: UTF-16" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-16`) + accept(`text/plain` withCharset `UTF-8`) should reject + } + + "Accept-Charset: UTF-16, UTF-8" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-16`) should select(`text/plain`, `UTF-16`) + } + + "Accept-Charset: UTF-8;q=.2, UTF-16" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-16`) + accept(`text/plain` withCharset `UTF-8`) should select(`text/plain`, `UTF-8`) + } + + "Accept-Charset: UTF-8;q=.2" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `ISO-8859-1`) + accept(`text/plain` withCharset `UTF-8`) should select(`text/plain`, `UTF-8`) + } + + "Accept-Charset: latin1;q=.1, UTF-8;q=.2" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-8`) should select(`text/plain`, `UTF-8`) + } + + "Accept-Charset: *" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `UTF-8`) + accept(`text/plain` withCharset `UTF-16`) should select(`text/plain`, `UTF-16`) + } + + "Accept-Charset: *;q=0" in test { accept ⇒ + accept(`text/plain`) should reject + accept(`text/plain` withCharset `UTF-16`) should reject + } + + "Accept-Charset: us;q=0.1,*;q=0" in test { accept ⇒ + accept(`text/plain`) should select(`text/plain`, `US-ASCII`) + accept(`text/plain` withCharset `UTF-8`) should reject + } + + "Accept: text/xml, text/html;q=.5" in test { accept ⇒ + accept(`text/plain`) should reject + accept(`text/xml`) should select(`text/xml`, `UTF-8`) + accept(`text/html`) should select(`text/html`, `UTF-8`) + accept(`text/html`, `text/xml`) should select(`text/xml`, `UTF-8`) + accept(`text/xml`, `text/html`) should select(`text/xml`, `UTF-8`) + accept(`text/plain`, `text/xml`) should select(`text/xml`, `UTF-8`) + accept(`text/plain`, `text/html`) should select(`text/html`, `UTF-8`) + } + + """Accept: text/html, text/plain;q=0.8, application/*;q=.5, *;q= .2 + Accept-Charset: UTF-16""" in test { accept ⇒ + accept(`text/plain`, `text/html`, `audio/ogg`) should select(`text/html`, `UTF-16`) + accept(`text/plain`, `text/html` withCharset `UTF-8`, `audio/ogg`) should select(`text/plain`, `UTF-16`) + accept(`audio/ogg`, `application/javascript`, `text/plain` withCharset `UTF-8`) should select(`application/javascript`, `UTF-16`) + accept(`image/gif`, `application/javascript`) should select(`application/javascript`, `UTF-16`) + accept(`image/gif`, `audio/ogg`) should select(`image/gif`, `UTF-16`) + } + } + + def test[U](body: ((ContentType*) ⇒ Option[ContentType]) ⇒ U): String ⇒ U = { example ⇒ + val headers = + if (example != "(without headers)") { + example.split('\n').toList map { rawHeader ⇒ + val Array(name, value) = rawHeader.split(':') + HttpHeader.parse(name.trim, value) match { + case HttpHeader.ParsingResult.Ok(header, Nil) ⇒ header + case result ⇒ fail(result.errors.head.formatPretty) + } + } + } else Nil + val request = HttpRequest(headers = headers) + body { contentTypes ⇒ + import scala.concurrent.ExecutionContext.Implicits.global + implicit val marshallers = contentTypes map { + case ct @ ContentType(mt, Some(cs)) ⇒ Marshaller.withFixedCharset(mt, cs)((s: String) ⇒ HttpEntity(ct, s)) + case ContentType(mt, None) ⇒ Marshaller.withOpenCharset(mt)((s: String, cs) ⇒ HttpEntity(ContentType(mt, cs), s)) + } + Await.result(Marshal("foo").toResponseFor(request) + .fast.map(response ⇒ Some(response.entity.contentType)) + .fast.recover { case _: Marshal.UnacceptableResponseContentTypeException ⇒ None }, 1.second) + } + } + + def reject = equal(None) + def select(mediaType: MediaType, charset: HttpCharset) = equal(Some(ContentType(mediaType, charset))) +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/MarshallingSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/MarshallingSpec.scala new file mode 100644 index 0000000000..505f7babc8 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshalling/MarshallingSpec.scala @@ -0,0 +1,149 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import akka.http.scaladsl.testkit.MarshallingTestUtils +import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._ + +import scala.collection.immutable.ListMap +import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers } +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl.Source +import akka.http.impl.util._ +import akka.http.scaladsl.model._ +import headers._ +import HttpCharsets._ +import MediaTypes._ + +class MarshallingSpec extends FreeSpec with Matchers with BeforeAndAfterAll with MultipartMarshallers with MarshallingTestUtils { + implicit val system = ActorSystem(getClass.getSimpleName) + implicit val materializer = ActorFlowMaterializer() + import system.dispatcher + + "The PredefinedToEntityMarshallers." - { + "StringMarshaller should marshal strings to `text/plain` content in UTF-8" in { + marshal("Ha“llo") shouldEqual HttpEntity("Ha“llo") + } + "CharArrayMarshaller should marshal char arrays to `text/plain` content in UTF-8" in { + marshal("Ha“llo".toCharArray) shouldEqual HttpEntity("Ha“llo") + } + "FormDataMarshaller should marshal FormData instances to application/x-www-form-urlencoded content" in { + marshal(FormData(Map("name" -> "Bob", "pass" -> "hällo", "admin" -> ""))) shouldEqual + HttpEntity(ContentType(`application/x-www-form-urlencoded`, `UTF-8`), "name=Bob&pass=h%C3%A4llo&admin=") + } + } + + "The GenericMarshallers." - { + "optionMarshaller should enable marshalling of Option[T]" in { + + marshal(Some("Ha“llo")) shouldEqual HttpEntity("Ha“llo") + marshal(None: Option[String]) shouldEqual HttpEntity.Empty + } + "eitherMarshaller should enable marshalling of Either[A, B]" in { + marshal[Either[Array[Char], String]](Right("right")) shouldEqual HttpEntity("right") + marshal[Either[Array[Char], String]](Left("left".toCharArray)) shouldEqual HttpEntity("left") + } + } + + "The MultipartMarshallers." - { + "multipartMarshaller should correctly marshal multipart content with" - { + "one empty part" in { + marshal(Multipart.General(`multipart/mixed`, Multipart.General.BodyPart.Strict(""))) shouldEqual HttpEntity( + contentType = ContentType(`multipart/mixed` withBoundary randomBoundary), + string = s"""--$randomBoundary + |Content-Type: text/plain; charset=UTF-8 + | + | + |--$randomBoundary--""".stripMarginWithNewline("\r\n")) + } + "one non-empty part" in { + marshal(Multipart.General(`multipart/alternative`, Multipart.General.BodyPart.Strict( + entity = HttpEntity(ContentType(`text/plain`, `UTF-8`), "test@there.com"), + headers = `Content-Disposition`(ContentDispositionTypes.`form-data`, Map("name" -> "email")) :: Nil))) shouldEqual + HttpEntity( + contentType = ContentType(`multipart/alternative` withBoundary randomBoundary), + string = s"""--$randomBoundary + |Content-Type: text/plain; charset=UTF-8 + |Content-Disposition: form-data; name=email + | + |test@there.com + |--$randomBoundary--""".stripMarginWithNewline("\r\n")) + } + "two different parts" in { + marshal(Multipart.General(`multipart/related`, + Multipart.General.BodyPart.Strict(HttpEntity(ContentType(`text/plain`, Some(`US-ASCII`)), "first part, with a trailing linebreak\r\n")), + Multipart.General.BodyPart.Strict( + HttpEntity(ContentType(`application/octet-stream`), "filecontent"), + RawHeader("Content-Transfer-Encoding", "binary") :: Nil))) shouldEqual + HttpEntity( + contentType = ContentType(`multipart/related` withBoundary randomBoundary), + string = s"""--$randomBoundary + |Content-Type: text/plain; charset=US-ASCII + | + |first part, with a trailing linebreak + | + |--$randomBoundary + |Content-Type: application/octet-stream + |Content-Transfer-Encoding: binary + | + |filecontent + |--$randomBoundary--""".stripMarginWithNewline("\r\n")) + } + } + + "multipartFormDataMarshaller should correctly marshal 'multipart/form-data' content with" - { + "two fields" in { + marshal(Multipart.FormData(ListMap( + "surname" -> HttpEntity("Mike"), + "age" -> marshal(42)))) shouldEqual + HttpEntity( + contentType = ContentType(`multipart/form-data` withBoundary randomBoundary), + string = s"""--$randomBoundary + |Content-Type: text/plain; charset=UTF-8 + |Content-Disposition: form-data; name=surname + | + |Mike + |--$randomBoundary + |Content-Type: text/xml; charset=UTF-8 + |Content-Disposition: form-data; name=age + | + |42 + |--$randomBoundary--""".stripMarginWithNewline("\r\n")) + } + + "two fields having a custom `Content-Disposition`" in { + marshal(Multipart.FormData(Source(List( + Multipart.FormData.BodyPart("attachment[0]", HttpEntity(`text/csv`, "name,age\r\n\"John Doe\",20\r\n"), + Map("filename" -> "attachment.csv")), + Multipart.FormData.BodyPart("attachment[1]", HttpEntity("naice!".getBytes), + Map("filename" -> "attachment2.csv"), List(RawHeader("Content-Transfer-Encoding", "binary"))))))) shouldEqual + HttpEntity( + contentType = ContentType(`multipart/form-data` withBoundary randomBoundary), + string = s"""--$randomBoundary + |Content-Type: text/csv + |Content-Disposition: form-data; filename=attachment.csv; name="attachment[0]" + | + |name,age + |"John Doe",20 + | + |--$randomBoundary + |Content-Type: application/octet-stream + |Content-Disposition: form-data; filename=attachment2.csv; name="attachment[1]" + |Content-Transfer-Encoding: binary + | + |naice! + |--$randomBoundary--""".stripMarginWithNewline("\r\n")) + } + } + } + + override def afterAll() = system.shutdown() + + protected class FixedRandom extends java.util.Random { + override def nextBytes(array: Array[Byte]): Unit = "my-stable-boundary".getBytes("UTF-8").copyToArray(array) + } + override protected val multipartBoundaryRandom = new FixedRandom // fix for stable value +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala new file mode 100644 index 0000000000..ba8c3e5d42 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import akka.http.scaladsl.model +import model.HttpMethods._ +import model.StatusCodes + +class BasicRouteSpecs extends RoutingSpec { + + "routes created by the concatenation operator '~'" should { + "yield the first sub route if it succeeded" in { + Get() ~> { + get { complete("first") } ~ get { complete("second") } + } ~> check { responseAs[String] shouldEqual "first" } + } + "yield the second sub route if the first did not succeed" in { + Get() ~> { + post { complete("first") } ~ get { complete("second") } + } ~> check { responseAs[String] shouldEqual "second" } + } + "collect rejections from both sub routes" in { + Delete() ~> { + get { completeOk } ~ put { completeOk } + } ~> check { rejections shouldEqual Seq(MethodRejection(GET), MethodRejection(PUT)) } + } + "clear rejections that have already been 'overcome' by previous directives" in { + pending + /*Put() ~> { + put { parameter('yeah) { echoComplete } } ~ + get { completeOk } + } ~> check { rejection shouldEqual MissingQueryParamRejection("yeah") }*/ + } + } + + "Route conjunction" should { + val stringDirective = provide("The cat") + val intDirective = provide(42) + val doubleDirective = provide(23.0) + + val dirStringInt = stringDirective & intDirective + val dirStringIntDouble = dirStringInt & doubleDirective + val dirDoubleStringInt = doubleDirective & dirStringInt + val dirStringIntStringInt = dirStringInt & dirStringInt + + "work for two elements" in { + Get("/abc") ~> { + dirStringInt { (str, i) ⇒ + complete(s"$str ${i + 1}") + } + } ~> check { responseAs[String] shouldEqual "The cat 43" } + } + "work for 2 + 1" in { + Get("/abc") ~> { + dirStringIntDouble { (str, i, d) ⇒ + complete(s"$str ${i + 1} ${d + 0.1}") + } + } ~> check { responseAs[String] shouldEqual "The cat 43 23.1" } + } + "work for 1 + 2" in { + Get("/abc") ~> { + dirDoubleStringInt { (d, str, i) ⇒ + complete(s"$str ${i + 1} ${d + 0.1}") + } + } ~> check { responseAs[String] shouldEqual "The cat 43 23.1" } + } + "work for 2 + 2" in { + Get("/abc") ~> { + dirStringIntStringInt { (str, i, str2, i2) ⇒ + complete(s"$str ${i + i2} $str2") + } + } ~> check { responseAs[String] shouldEqual "The cat 84 The cat" } + } + } + "Route disjunction" should { + "work in the happy case" in { + val route = Route.seal((path("abc") | path("def")) { + completeOk + }) + + Get("/abc") ~> route ~> check { + status shouldEqual StatusCodes.OK + } + Get("/def") ~> route ~> check { + status shouldEqual StatusCodes.OK + } + Get("/ghi") ~> route ~> check { + status shouldEqual StatusCodes.NotFound + } + } + "don't apply alternative if inner route rejects" in { + object MyRejection extends Rejection + val route = (path("abc") | post) { + reject(MyRejection) + } + Get("/abc") ~> route ~> check { + rejection shouldEqual MyRejection + } + } + } + "Case class extraction with Directive.as" should { + "extract one argument" in { + case class MyNumber(i: Int) + + val abcPath = path("abc" / IntNumber).as(MyNumber)(echoComplete) + + Get("/abc/5") ~> abcPath ~> check { + responseAs[String] shouldEqual "MyNumber(5)" + } + } + "extract two arguments" in { + case class Person(name: String, age: Int) + + val personPath = path("person" / Segment / IntNumber).as(Person)(echoComplete) + + Get("/person/john/38") ~> personPath ~> check { + responseAs[String] shouldEqual "Person(john,38)" + } + } + } + "Dynamic execution of inner routes of Directive0" should { + "re-execute inner routes every time" in { + var a = "" + val dynamicRoute = get { a += "x"; complete(a) } + def expect(route: Route, s: String) = Get() ~> route ~> check { responseAs[String] shouldEqual s } + + expect(dynamicRoute, "x") + expect(dynamicRoute, "xx") + expect(dynamicRoute, "xxx") + expect(dynamicRoute, "xxxx") + } + } + + case object MyException extends RuntimeException + "Route sealing" should { + "catch route execution exceptions" in { + Get("/abc") ~> Route.seal { + get { ctx ⇒ + throw MyException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + } + } + "catch route building exceptions" in { + Get("/abc") ~> Route.seal { + get { + throw MyException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + } + } + "convert all rejections to responses" in { + object MyRejection extends Rejection + Get("/abc") ~> Route.seal { + get { + reject(MyRejection) + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + } + } + "always prioritize MethodRejections over AuthorizationFailedRejections" in { + Get("/abc") ~> Route.seal { + post { completeOk } ~ + authorize(false) { completeOk } + } ~> check { + status shouldEqual StatusCodes.MethodNotAllowed + responseAs[String] shouldEqual "HTTP method not allowed, supported methods: POST" + } + + Get("/abc") ~> Route.seal { + authorize(false) { completeOk } ~ + post { completeOk } + } ~> check { status shouldEqual StatusCodes.MethodNotAllowed } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/RoutingSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/RoutingSpec.scala new file mode 100644 index 0000000000..acf41818c1 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/RoutingSpec.scala @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import org.scalatest.{ WordSpec, Suite, Matchers } +import akka.http.scaladsl.model.HttpResponse +import akka.http.scaladsl.testkit.ScalatestRouteTest + +trait GenericRoutingSpec extends Matchers with Directives with ScalatestRouteTest { this: Suite ⇒ + val Ok = HttpResponse() + val completeOk = complete(Ok) + + def echoComplete[T]: T ⇒ Route = { x ⇒ complete(x.toString) } + def echoComplete2[T, U]: (T, U) ⇒ Route = { (x, y) ⇒ complete(s"$x $y") } +} + +abstract class RoutingSpec extends WordSpec with GenericRoutingSpec diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/TestServer.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/TestServer.scala new file mode 100644 index 0000000000..42ff66c297 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/TestServer.scala @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport +import akka.http.scaladsl.server.directives.UserCredentials +import com.typesafe.config.{ ConfigFactory, Config } +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.http.scaladsl.Http + +object TestServer extends App { + val testConf: Config = ConfigFactory.parseString(""" + akka.loglevel = INFO + akka.log-dead-letters = off""") + implicit val system = ActorSystem("ServerTest", testConf) + import system.dispatcher + implicit val materializer = ActorFlowMaterializer() + + import ScalaXmlSupport._ + import Directives._ + + def auth: AuthenticatorPF[String] = { + case p @ UserCredentials.Provided(name) if p.verifySecret(name + "-password") ⇒ name + } + + val bindingFuture = Http().bindAndHandle({ + get { + path("") { + complete(index) + } ~ + path("secure") { + authenticateBasicPF("My very secure site", auth) { user ⇒ + complete(Hello { user }. Access has been granted!) + } + } ~ + path("ping") { + complete("PONG!") + } ~ + path("crash") { + complete(sys.error("BOOM!")) + } + } + }, interface = "localhost", port = 8080) + + println(s"Server online at http://localhost:8080/\nPress RETURN to stop...") + Console.readLine() + + bindingFuture.flatMap(_.unbind()).onComplete(_ ⇒ system.shutdown()) + + lazy val index = + + +

Say hello to akka-http-core!

+

Defined resources:

+ + + +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/BasicDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/BasicDirectivesSpec.scala new file mode 100644 index 0000000000..44f60b1a66 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/BasicDirectivesSpec.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +class BasicDirectivesSpec extends RoutingSpec { + + "The `mapUnmatchedPath` directive" should { + "map the unmatched path" in { + Get("/abc") ~> { + mapUnmatchedPath(_ / "def") { + path("abc" / "def") { completeOk } + } + } ~> check { response shouldEqual Ok } + } + } + + "The `extract` directive" should { + "extract from the RequestContext" in { + Get("/abc") ~> { + extract(_.request.method.value) { + echoComplete + } + } ~> check { responseAs[String] shouldEqual "GET" } + } + } +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CacheConditionDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CacheConditionDirectivesSpec.scala new file mode 100644 index 0000000000..4c7cd631dd --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CacheConditionDirectivesSpec.scala @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import StatusCodes._ +import headers._ + +class CacheConditionDirectivesSpec extends RoutingSpec { + + "the `conditional` directive" should { + val timestamp = DateTime.now - 2000 + val ifUnmodifiedSince = `If-Unmodified-Since`(timestamp) + val ifModifiedSince = `If-Modified-Since`(timestamp) + val tag = EntityTag("fresh") + val responseHeaders = List(ETag(tag), `Last-Modified`(timestamp)) + + def taggedAndTimestamped = conditional(tag, timestamp) { completeOk } + def weak = conditional(tag.copy(weak = true), timestamp) { completeOk } + + "return OK for new resources" in { + Get() ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + } + + "return OK for non-matching resources" in { + Get() ~> `If-None-Match`(EntityTag("old")) ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-Modified-Since`(timestamp - 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(EntityTag("old")) ~> `If-Modified-Since`(timestamp - 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + } + + "ignore If-Modified-Since if If-None-Match is defined" in { + Get() ~> `If-None-Match`(tag) ~> `If-Modified-Since`(timestamp - 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + } + Get() ~> `If-None-Match`(EntityTag("old")) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual OK + } + } + + "return PreconditionFailed for matched but unsafe resources" in { + Put() ~> `If-None-Match`(tag) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual PreconditionFailed + headers shouldEqual Nil + } + } + + "return NotModified for matching resources" in { + Get() ~> `If-None-Match`.`*` ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(tag) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(tag) ~> `If-Modified-Since`(timestamp + 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(tag.copy(weak = true)) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(tag, EntityTag("some"), EntityTag("other")) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + } + + "return NotModified when only one matching header is set" in { + Get() ~> `If-None-Match`.`*` ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> `If-None-Match`(tag) ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + Get() ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (responseHeaders) + } + } + + "return NotModified for matching weak resources" in { + val weakTag = tag.copy(weak = true) + Get() ~> `If-None-Match`(tag) ~> weak ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (List(ETag(weakTag), `Last-Modified`(timestamp))) + } + Get() ~> `If-None-Match`(weakTag) ~> weak ~> check { + status shouldEqual NotModified + headers should contain theSameElementsAs (List(ETag(weakTag), `Last-Modified`(timestamp))) + } + } + + "return normally for matching If-Match/If-Unmodified" in { + Put() ~> `If-Match`.`*` ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + Put() ~> `If-Match`(tag) ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + Put() ~> ifUnmodifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual OK + headers should contain theSameElementsAs (responseHeaders) + } + } + + "return PreconditionFailed for non-matching If-Match/If-Unmodified" in { + Put() ~> `If-Match`(EntityTag("old")) ~> taggedAndTimestamped ~> check { + status shouldEqual PreconditionFailed + headers shouldEqual Nil + } + Put() ~> `If-Unmodified-Since`(timestamp - 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual PreconditionFailed + headers shouldEqual Nil + } + } + + "ignore If-Unmodified-Since if If-Match is defined" in { + Put() ~> `If-Match`(tag) ~> `If-Unmodified-Since`(timestamp - 1000) ~> taggedAndTimestamped ~> check { + status shouldEqual OK + } + Put() ~> `If-Match`(EntityTag("old")) ~> ifModifiedSince ~> taggedAndTimestamped ~> check { + status shouldEqual PreconditionFailed + } + } + + "not filter out a `Range` header if `If-Range` does match the timestamp" in { + Get() ~> `If-Range`(timestamp) ~> Range(ByteRange(0, 10)) ~> { + (conditional(tag, timestamp) & optionalHeaderValueByType[Range]()) { echoComplete } + } ~> check { + status shouldEqual OK + responseAs[String] should startWith("Some") + } + } + + "filter out a `Range` header if `If-Range` doesn't match the timestamp" in { + Get() ~> `If-Range`(timestamp - 1000) ~> Range(ByteRange(0, 10)) ~> { + (conditional(tag, timestamp) & optionalHeaderValueByType[Range]()) { echoComplete } + } ~> check { + status shouldEqual OK + responseAs[String] shouldEqual "None" + } + } + + "not filter out a `Range` header if `If-Range` does match the ETag" in { + Get() ~> `If-Range`(tag) ~> Range(ByteRange(0, 10)) ~> { + (conditional(tag, timestamp) & optionalHeaderValueByType[Range]()) { echoComplete } + } ~> check { + status shouldEqual OK + responseAs[String] should startWith("Some") + } + } + + "filter out a `Range` header if `If-Range` doesn't match the ETag" in { + Get() ~> `If-Range`(EntityTag("other")) ~> Range(ByteRange(0, 10)) ~> { + (conditional(tag, timestamp) & optionalHeaderValueByType[Range]()) { echoComplete } + } ~> check { + status shouldEqual OK + responseAs[String] shouldEqual "None" + } + } + } + +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CodingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CodingDirectivesSpec.scala new file mode 100644 index 0000000000..f6cd19e114 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CodingDirectivesSpec.scala @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import org.scalatest.matchers.Matcher +import akka.util.ByteString +import akka.stream.scaladsl.Source +import akka.http.impl.util._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.coding._ +import headers._ +import HttpEntity.{ ChunkStreamPart, Chunk } +import HttpCharsets._ +import HttpEncodings._ +import MediaTypes._ +import StatusCodes._ + +import scala.concurrent.duration._ + +class CodingDirectivesSpec extends RoutingSpec { + + val echoRequestContent: Route = { ctx ⇒ ctx.complete(ctx.request.entity.dataBytes.utf8String) } + + val yeah = complete("Yeah!") + lazy val yeahGzipped = compress("Yeah!", Gzip) + lazy val yeahDeflated = compress("Yeah!", Deflate) + + lazy val helloGzipped = compress("Hello", Gzip) + lazy val helloDeflated = compress("Hello", Deflate) + + "the NoEncoding decoder" should { + "decode the request content if it has encoding 'identity'" in { + Post("/", "yes") ~> `Content-Encoding`(identity) ~> { + decodeRequestWith(NoCoding) { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "yes" } + } + "reject requests with content encoded with 'deflate'" in { + Post("/", "yes") ~> `Content-Encoding`(deflate) ~> { + decodeRequestWith(NoCoding) { echoRequestContent } + } ~> check { rejection shouldEqual UnsupportedRequestEncodingRejection(identity) } + } + "decode the request content if no Content-Encoding header is present" in { + Post("/", "yes") ~> decodeRequestWith(NoCoding) { echoRequestContent } ~> check { responseAs[String] shouldEqual "yes" } + } + "leave request without content unchanged" in { + Post() ~> decodeRequestWith(Gzip) { completeOk } ~> check { response shouldEqual Ok } + } + } + + "the Gzip decoder" should { + "decode the request content if it has encoding 'gzip'" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> { + decodeRequestWith(Gzip) { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "Hello" } + } + "reject the request content if it has encoding 'gzip' but is corrupt" in { + Post("/", fromHexDump("000102")) ~> `Content-Encoding`(gzip) ~> { + decodeRequestWith(Gzip) { echoRequestContent } + } ~> check { + status shouldEqual BadRequest + responseAs[String] shouldEqual "The request's encoding is corrupt" + } + } + "reject truncated gzip request content" in { + Post("/", helloGzipped.dropRight(2)) ~> `Content-Encoding`(gzip) ~> { + decodeRequestWith(Gzip) { echoRequestContent } + } ~> check { + status shouldEqual BadRequest + responseAs[String] shouldEqual "The request's encoding is corrupt" + } + } + "reject requests with content encoded with 'deflate'" in { + Post("/", "Hello") ~> `Content-Encoding`(deflate) ~> { + decodeRequestWith(Gzip) { completeOk } + } ~> check { rejection shouldEqual UnsupportedRequestEncodingRejection(gzip) } + } + "reject requests without Content-Encoding header" in { + Post("/", "Hello") ~> { + decodeRequestWith(Gzip) { completeOk } + } ~> check { rejection shouldEqual UnsupportedRequestEncodingRejection(gzip) } + } + "leave request without content unchanged" in { + Post() ~> { + decodeRequestWith(Gzip) { completeOk } + } ~> check { response shouldEqual Ok } + } + } + + "a (decodeRequestWith(Gzip) | decodeRequestWith(NoEncoding)) compound directive" should { + lazy val decodeWithGzipOrNoEncoding = decodeRequestWith(Gzip) | decodeRequestWith(NoCoding) + "decode the request content if it has encoding 'gzip'" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> { + decodeWithGzipOrNoEncoding { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "Hello" } + } + "decode the request content if it has encoding 'identity'" in { + Post("/", "yes") ~> `Content-Encoding`(identity) ~> { + decodeWithGzipOrNoEncoding { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "yes" } + } + "decode the request content if no Content-Encoding header is present" in { + Post("/", "yes") ~> decodeWithGzipOrNoEncoding { echoRequestContent } ~> check { responseAs[String] shouldEqual "yes" } + } + "reject requests with content encoded with 'deflate'" in { + Post("/", "yes") ~> `Content-Encoding`(deflate) ~> { + decodeWithGzipOrNoEncoding { echoRequestContent } + } ~> check { + rejections shouldEqual Seq( + UnsupportedRequestEncodingRejection(gzip), + UnsupportedRequestEncodingRejection(identity)) + } + } + } + + "the Gzip encoder" should { + "encode the response content with GZIP if the client accepts it with a dedicated Accept-Encoding header" in { + Post() ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Gzip) { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "encode the response content with GZIP if the request has no Accept-Encoding header" in { + Post() ~> { + encodeResponseWith(Gzip) { yeah } + } ~> check { strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) } + } + "reject the request if the client does not accept GZIP encoding" in { + Post() ~> `Accept-Encoding`(identity) ~> { + encodeResponseWith(Gzip) { completeOk } + } ~> check { rejection shouldEqual UnacceptedResponseEncodingRejection(gzip) } + } + "leave responses without content unchanged" in { + Post() ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Gzip) { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "leave responses with an already set Content-Encoding header unchanged" in { + Post() ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Gzip) { + RespondWithDirectives.respondWithHeader(`Content-Encoding`(identity)) { completeOk } + } + } ~> check { response shouldEqual Ok.withHeaders(`Content-Encoding`(identity)) } + } + "correctly encode the chunk stream produced by a chunked response" in { + val text = "This is a somewhat lengthy text that is being chunked by the autochunk directive!" + val textChunks = + () ⇒ text.grouped(8).map { chars ⇒ + Chunk(chars.mkString): ChunkStreamPart + } + val chunkedTextEntity = HttpEntity.Chunked(MediaTypes.`text/plain`, Source(textChunks)) + + Post() ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Gzip) { + complete(chunkedTextEntity) + } + } ~> check { + response should haveContentEncoding(gzip) + chunks.size shouldEqual (11 + 1) // 11 regular + the last one + val bytes = chunks.foldLeft(ByteString.empty)(_ ++ _.data) + Gzip.decode(bytes).awaitResult(1.second) should readAs(text) + } + } + } + + "the encodeResponseWith(NoEncoding) directive" should { + "produce a response if no Accept-Encoding is present in the request" in { + Post() ~> encodeResponseWith(NoCoding) { completeOk } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "produce a not encoded response if the client only accepts non matching encodings" in { + Post() ~> `Accept-Encoding`(gzip, identity) ~> { + encodeResponseWith(NoCoding) { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + + Post() ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Deflate, NoCoding) { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "reject the request if the request has an 'Accept-Encoding: identity; q=0' header" in { + Post() ~> `Accept-Encoding`(identity.withQValue(0f)) ~> { + encodeResponseWith(NoCoding) { completeOk } + } ~> check { rejection shouldEqual UnacceptedResponseEncodingRejection(identity) } + } + } + + "a (encodeResponse(Gzip) | encodeResponse(NoEncoding)) compound directive" should { + lazy val encodeGzipOrIdentity = encodeResponseWith(Gzip) | encodeResponseWith(NoCoding) + "produce a not encoded response if the request has no Accept-Encoding header" in { + Post() ~> { + encodeGzipOrIdentity { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "produce a GZIP encoded response if the request has an `Accept-Encoding: deflate;q=0.5, gzip` header" in { + Post() ~> `Accept-Encoding`(deflate.withQValue(.5f), gzip) ~> { + encodeGzipOrIdentity { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "produce a non-encoded response if the request has an `Accept-Encoding: identity` header" in { + Post() ~> `Accept-Encoding`(identity) ~> { + encodeGzipOrIdentity { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "produce a non-encoded response if the request has an `Accept-Encoding: deflate` header" in { + Post() ~> `Accept-Encoding`(deflate) ~> { + encodeGzipOrIdentity { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + } + + "the encodeResponse directive" should { + "produce a non-encoded response if the request has no Accept-Encoding header" in { + Get("/") ~> { + encodeResponse { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "produce a GZIP encoded response if the request has an `Accept-Encoding: gzip, deflate` header" in { + Get("/") ~> `Accept-Encoding`(gzip, deflate) ~> { + encodeResponse { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "produce a Deflate encoded response if the request has an `Accept-Encoding: deflate` header" in { + Get("/") ~> `Accept-Encoding`(deflate) ~> { + encodeResponse { yeah } + } ~> check { + response should haveContentEncoding(deflate) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahDeflated) + } + } + } + + "the encodeResponseWith directive" should { + "produce a response encoded with the specified Encoder if the request has a matching Accept-Encoding header" in { + Get("/") ~> `Accept-Encoding`(gzip) ~> { + encodeResponseWith(Gzip) { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "produce a response encoded with one of the specified Encoders if the request has a matching Accept-Encoding header" in { + Get("/") ~> `Accept-Encoding`(deflate) ~> { + encodeResponseWith(Gzip, Deflate) { yeah } + } ~> check { + response should haveContentEncoding(deflate) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahDeflated) + } + } + "produce a response encoded with the first of the specified Encoders if the request has no Accept-Encoding header" in { + Get("/") ~> { + encodeResponseWith(Gzip, Deflate) { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "produce a response with no encoding if the request has an empty Accept-Encoding header" in { + Get("/") ~> `Accept-Encoding`() ~> { + encodeResponseWith(Gzip, Deflate, NoCoding) { completeOk } + } ~> check { + response shouldEqual Ok + response should haveNoContentEncoding + } + } + "negotiate the correct content encoding" in { + Get("/") ~> `Accept-Encoding`(identity.withQValue(.5f), deflate.withQValue(0f), gzip) ~> { + encodeResponseWith(NoCoding, Deflate, Gzip) { yeah } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), yeahGzipped) + } + } + "reject the request if it has an Accept-Encoding header with an encoding that doesn't match" in { + Get("/") ~> `Accept-Encoding`(deflate) ~> { + encodeResponseWith(Gzip) { yeah } + } ~> check { + rejection shouldEqual UnacceptedResponseEncodingRejection(gzip) + } + } + "reject the request if it has an Accept-Encoding header with an encoding that matches but is blacklisted" in { + Get("/") ~> `Accept-Encoding`(gzip.withQValue(0f)) ~> { + encodeResponseWith(Gzip) { yeah } + } ~> check { + rejection shouldEqual UnacceptedResponseEncodingRejection(gzip) + } + } + } + + "the decodeRequest directive" should { + "decode the request content if it has a `Content-Encoding: gzip` header and the content is gzip encoded" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> { + decodeRequest { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "Hello" } + } + "decode the request content if it has a `Content-Encoding: deflate` header and the content is deflate encoded" in { + Post("/", helloDeflated) ~> `Content-Encoding`(deflate) ~> { + decodeRequest { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "Hello" } + } + "decode the request content if it has a `Content-Encoding: identity` header and the content is not encoded" in { + Post("/", "yes") ~> `Content-Encoding`(identity) ~> { + decodeRequest { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "yes" } + } + "decode the request content using NoEncoding if no Content-Encoding header is present" in { + Post("/", "yes") ~> decodeRequest { echoRequestContent } ~> check { responseAs[String] shouldEqual "yes" } + } + "reject the request if it has a `Content-Encoding: deflate` header but the request is encoded with Gzip" in { + Post("/", helloGzipped) ~> `Content-Encoding`(deflate) ~> + decodeRequest { echoRequestContent } ~> check { + status shouldEqual BadRequest + responseAs[String] shouldEqual "The request's encoding is corrupt" + } + } + } + + "the decodeRequestWith directive" should { + "decode the request content if its `Content-Encoding` header matches the specified encoder" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> { + decodeRequestWith(Gzip) { echoRequestContent } + } ~> check { responseAs[String] shouldEqual "Hello" } + } + "reject the request if its `Content-Encoding` header doesn't match the specified encoder" in { + Post("/", helloGzipped) ~> `Content-Encoding`(deflate) ~> { + decodeRequestWith(Gzip) { echoRequestContent } + } ~> check { + rejection shouldEqual UnsupportedRequestEncodingRejection(gzip) + } + } + "reject the request when decodeing with GZIP and no Content-Encoding header is present" in { + Post("/", "yes") ~> decodeRequestWith(Gzip) { echoRequestContent } ~> check { + rejection shouldEqual UnsupportedRequestEncodingRejection(gzip) + } + } + } + + "the (decodeRequest & encodeResponse) compound directive" should { + lazy val decodeEncode = decodeRequest & encodeResponse + "decode a GZIP encoded request and produce a none encoded response if the request has no Accept-Encoding header" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> { + decodeEncode { echoRequestContent } + } ~> check { + response should haveNoContentEncoding + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), "Hello") + } + } + "decode a GZIP encoded request and produce a Deflate encoded response if the request has an `Accept-Encoding: deflate` header" in { + Post("/", helloGzipped) ~> `Content-Encoding`(gzip) ~> `Accept-Encoding`(deflate) ~> { + decodeEncode { echoRequestContent } + } ~> check { + response should haveContentEncoding(deflate) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), helloDeflated) + } + } + "decode an unencoded request and produce a GZIP encoded response if the request has an `Accept-Encoding: gzip` header" in { + Post("/", "Hello") ~> `Accept-Encoding`(gzip) ~> { + decodeEncode { echoRequestContent } + } ~> check { + response should haveContentEncoding(gzip) + strictify(responseEntity) shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), helloGzipped) + } + } + } + + def compress(input: String, encoder: Encoder): ByteString = { + val compressor = encoder.newCompressor + compressor.compressAndFlush(ByteString(input)) ++ compressor.finish() + } + + def hexDump(bytes: Array[Byte]) = bytes.map("%02x" format _).mkString + def fromHexDump(dump: String) = dump.grouped(2).toArray.map(chars ⇒ Integer.parseInt(new String(chars), 16).toByte) + + def haveNoContentEncoding: Matcher[HttpResponse] = be(None) compose { (_: HttpResponse).header[`Content-Encoding`] } + def haveContentEncoding(encoding: HttpEncoding): Matcher[HttpResponse] = + be(Some(`Content-Encoding`(encoding))) compose { (_: HttpResponse).header[`Content-Encoding`] } + + def readAs(string: String, charset: String = "UTF8") = be(string) compose { (_: ByteString).decodeString(charset) } + + def strictify(entity: HttpEntity) = entity.toStrict(1.second).awaitResult(1.second) +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CookieDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CookieDirectivesSpec.scala new file mode 100644 index 0000000000..dab2344529 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CookieDirectivesSpec.scala @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import StatusCodes.OK +import headers._ + +class CookieDirectivesSpec extends RoutingSpec { + + val deletedTimeStamp = DateTime.fromIsoDateTimeString("1800-01-01T00:00:00") + + "The 'cookie' directive" should { + "extract the respectively named cookie" in { + Get() ~> addHeader(Cookie(HttpCookie("fancy", "pants"))) ~> { + cookie("fancy") { echoComplete } + } ~> check { responseAs[String] shouldEqual "fancy=pants" } + } + "reject the request if the cookie is not present" in { + Get() ~> { + cookie("fancy") { echoComplete } + } ~> check { rejection shouldEqual MissingCookieRejection("fancy") } + } + "properly pass through inner rejections" in { + Get() ~> addHeader(Cookie(HttpCookie("fancy", "pants"))) ~> { + cookie("fancy") { c ⇒ reject(ValidationRejection("Dont like " + c.content)) } + } ~> check { rejection shouldEqual ValidationRejection("Dont like pants") } + } + } + + "The 'deleteCookie' directive" should { + "add a respective Set-Cookie headers to successful responses" in { + Get() ~> { + deleteCookie("myCookie", "test.com") { completeOk } + } ~> check { + status shouldEqual OK + header[`Set-Cookie`] shouldEqual Some(`Set-Cookie`(HttpCookie("myCookie", "deleted", expires = deletedTimeStamp, + domain = Some("test.com")))) + } + } + + "support deleting multiple cookies at a time" in { + Get() ~> { + deleteCookie(HttpCookie("myCookie", "test.com"), HttpCookie("myCookie2", "foobar.com")) { completeOk } + } ~> check { + status shouldEqual OK + headers.collect { case `Set-Cookie`(x) ⇒ x } shouldEqual List( + HttpCookie("myCookie", "deleted", expires = deletedTimeStamp), + HttpCookie("myCookie2", "deleted", expires = deletedTimeStamp)) + } + } + } + + "The 'optionalCookie' directive" should { + "produce a `Some(cookie)` extraction if the cookie is present" in { + Get() ~> Cookie(HttpCookie("abc", "123")) ~> { + optionalCookie("abc") { echoComplete } + } ~> check { responseAs[String] shouldEqual "Some(abc=123)" } + } + "produce a `None` extraction if the cookie is not present" in { + Get() ~> optionalCookie("abc") { echoComplete } ~> check { responseAs[String] shouldEqual "None" } + } + "let rejections from its inner route pass through" in { + Get() ~> { + optionalCookie("test-cookie") { _ ⇒ + validate(false, "ouch") { completeOk } + } + } ~> check { rejection shouldEqual ValidationRejection("ouch") } + } + } + + "The 'setCookie' directive" should { + "add a respective Set-Cookie headers to successful responses" in { + Get() ~> { + setCookie(HttpCookie("myCookie", "test.com")) { completeOk } + } ~> check { + status shouldEqual OK + header[`Set-Cookie`] shouldEqual Some(`Set-Cookie`(HttpCookie("myCookie", "test.com"))) + } + } + + "support setting multiple cookies at a time" in { + Get() ~> { + setCookie(HttpCookie("myCookie", "test.com"), HttpCookie("myCookie2", "foobar.com")) { completeOk } + } ~> check { + status shouldEqual OK + headers.collect { case `Set-Cookie`(x) ⇒ x } shouldEqual List( + HttpCookie("myCookie", "test.com"), HttpCookie("myCookie2", "foobar.com")) + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/DebuggingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/DebuggingDirectivesSpec.scala new file mode 100644 index 0000000000..27d174e322 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/DebuggingDirectivesSpec.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.event.LoggingAdapter +import akka.http.impl.util._ + +class DebuggingDirectivesSpec extends RoutingSpec { + var debugMsg = "" + + def resetDebugMsg(): Unit = { debugMsg = "" } + + val log = new LoggingAdapter { + def isErrorEnabled = true + def isWarningEnabled = true + def isInfoEnabled = true + def isDebugEnabled = true + + def notifyError(message: String): Unit = {} + def notifyError(cause: Throwable, message: String): Unit = {} + def notifyWarning(message: String): Unit = {} + def notifyInfo(message: String): Unit = {} + def notifyDebug(message: String): Unit = { debugMsg += message + '\n' } + } + + "The 'logRequest' directive" should { + "produce a proper log message for incoming requests" in { + val route = + withLog(log)( + logRequest("1")( + completeOk)) + + resetDebugMsg() + Get("/hello") ~> route ~> check { + response shouldEqual Ok + debugMsg shouldEqual "1: HttpRequest(HttpMethod(GET),http://example.com/hello,List(),HttpEntity.Strict(none/none,ByteString()),HttpProtocol(HTTP/1.1))\n" + } + } + } + + "The 'logResponse' directive" should { + "produce a proper log message for outgoing responses" in { + val route = + withLog(log)( + logResult("2")( + completeOk)) + + resetDebugMsg() + Get("/hello") ~> route ~> check { + response shouldEqual Ok + debugMsg shouldEqual "2: Complete(HttpResponse(200 OK,List(),HttpEntity.Strict(none/none,ByteString()),HttpProtocol(HTTP/1.1)))\n" + } + } + } + + "The 'logRequestResponse' directive" should { + "produce proper log messages for outgoing responses, thereby showing the corresponding request" in { + val route = + withLog(log)( + logRequestResult("3")( + completeOk)) + + resetDebugMsg() + Get("/hello") ~> route ~> check { + response shouldEqual Ok + debugMsg shouldEqual """|3: Response for + | Request : HttpRequest(HttpMethod(GET),http://example.com/hello,List(),HttpEntity.Strict(none/none,ByteString()),HttpProtocol(HTTP/1.1)) + | Response: Complete(HttpResponse(200 OK,List(),HttpEntity.Strict(none/none,ByteString()),HttpProtocol(HTTP/1.1))) + |""".stripMarginWithNewline("\n") + } + } + } + +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala new file mode 100644 index 0000000000..a9e82279f4 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.{ MediaTypes, MediaRanges, StatusCodes } +import akka.http.scaladsl.model.headers._ + +import scala.concurrent.Future + +class ExecutionDirectivesSpec extends RoutingSpec { + object MyException extends RuntimeException + val handler = + ExceptionHandler { + case MyException ⇒ complete(500, "Pling! Plong! Something went wrong!!!") + } + + "The `handleExceptions` directive" should { + "handle an exception strictly thrown in the inner route with the supplied exception handler" in { + exceptionShouldBeHandled { + handleExceptions(handler) { ctx ⇒ + throw MyException + } + } + } + "handle an Future.failed RouteResult with the supplied exception handler" in { + exceptionShouldBeHandled { + handleExceptions(handler) { ctx ⇒ + Future.failed(MyException) + } + } + } + "handle an eventually failed Future[RouteResult] with the supplied exception handler" in { + exceptionShouldBeHandled { + handleExceptions(handler) { ctx ⇒ + Future { + Thread.sleep(100) + throw MyException + } + } + } + } + "handle an exception happening during route building" in { + exceptionShouldBeHandled { + get { + handleExceptions(handler) { + throw MyException + } + } + } + } + "not interfere with alternative routes" in { + Get("/abc") ~> + get { + handleExceptions(handler)(reject) ~ { ctx ⇒ + throw MyException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + } + "not handle other exceptions" in { + Get("/abc") ~> + get { + handleExceptions(handler) { + throw new RuntimeException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + } + "always fall back to a default content type" in { + Get("/abc") ~> Accept(MediaTypes.`application/json`) ~> + get { + handleExceptions(handler) { + throw new RuntimeException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + + Get("/abc") ~> Accept(MediaTypes.`text/xml`, MediaRanges.`*/*`.withQValue(0f)) ~> + get { + handleExceptions(handler) { + throw new RuntimeException + } + } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + } + } + + def exceptionShouldBeHandled(route: Route) = + Get("/abc") ~> route ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "Pling! Plong! Something went wrong!!!" + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala new file mode 100644 index 0000000000..fb745e77df --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala @@ -0,0 +1,370 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import java.io.{ File, FileOutputStream } +import scala.concurrent.duration._ +import scala.concurrent.{ ExecutionContext, Future } +import scala.util.Properties +import org.scalatest.matchers.Matcher +import org.scalatest.{ Inside, Inspectors } +import akka.http.scaladsl.model.MediaTypes._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.impl.util._ + +class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Inside { + + override def testConfigSource = + """akka.http.scaladsl.routing { + | file-chunking-threshold-size = 16 + | file-chunking-chunk-size = 8 + | range-coalescing-threshold = 1 + |}""".stripMargin + + "getFromFile" should { + "reject non-GET requests" in { + Put() ~> getFromFile("some") ~> check { handled shouldEqual (false) } + } + "reject requests to non-existing files" in { + Get() ~> getFromFile("nonExistentFile") ~> check { handled shouldEqual (false) } + } + "reject requests to directories" in { + Get() ~> getFromFile(Properties.javaHome) ~> check { handled shouldEqual (false) } + } + "return the file content with the MediaType matching the file extension" in { + val file = File.createTempFile("akkaHttpTest", ".PDF") + try { + writeAllText("This is PDF", file) + Get() ~> getFromFile(file.getPath) ~> check { + mediaType shouldEqual `application/pdf` + definedCharset shouldEqual None + responseAs[String] shouldEqual "This is PDF" + headers should contain(`Last-Modified`(DateTime(file.lastModified))) + } + } finally file.delete + } + "return the file content with MediaType 'application/octet-stream' on unknown file extensions" in { + val file = File.createTempFile("akkaHttpTest", null) + try { + writeAllText("Some content", file) + Get() ~> getFromFile(file) ~> check { + mediaType shouldEqual `application/octet-stream` + responseAs[String] shouldEqual "Some content" + } + } finally file.delete + } + + "return a single range from a file" in { + val file = File.createTempFile("partialTest", null) + try { + writeAllText("ABCDEFGHIJKLMNOPQRSTUVWXYZ", file) + Get() ~> addHeader(Range(ByteRange(0, 10))) ~> getFromFile(file) ~> check { + status shouldEqual StatusCodes.PartialContent + headers should contain(`Content-Range`(ContentRange(0, 10, 26))) + responseAs[String] shouldEqual "ABCDEFGHIJK" + } + } finally file.delete + } + + "return multiple ranges from a file at once" in { + pending // FIXME: reactivate + val file = File.createTempFile("partialTest", null) + try { + writeAllText("ABCDEFGHIJKLMNOPQRSTUVWXYZ", file) + val rangeHeader = Range(ByteRange(1, 10), ByteRange.suffix(10)) + Get() ~> addHeader(rangeHeader) ~> getFromFile(file, ContentTypes.`text/plain`) ~> check { + status shouldEqual StatusCodes.PartialContent + header[`Content-Range`] shouldEqual None + mediaType.withParams(Map.empty) shouldEqual `multipart/byteranges` + + val parts = responseAs[Multipart.ByteRanges].toStrict(1.second).awaitResult(3.seconds).strictParts + parts.size shouldEqual 2 + parts(0).entity.data.utf8String shouldEqual "BCDEFGHIJK" + parts(1).entity.data.utf8String shouldEqual "QRSTUVWXYZ" + } + } finally file.delete + } + } + + "getFromResource" should { + "reject non-GET requests" in { + Put() ~> getFromResource("some") ~> check { handled shouldEqual (false) } + } + "reject requests to non-existing resources" in { + Get() ~> getFromResource("nonExistingResource") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources" in { + Get() ~> getFromResource("someDir") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources with trailing slash" in { + Get() ~> getFromResource("someDir/") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources from an Archive " in { + Get() ~> getFromResource("com/typesafe/config") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources from an Archive with trailing slash" in { + Get() ~> getFromResource("com/typesafe/config/") ~> check { handled shouldEqual (false) } + } + "return the resource content with the MediaType matching the file extension" in { + val route = getFromResource("sample.html") + + def runCheck() = + Get() ~> route ~> check { + mediaType shouldEqual `text/html` + forAtLeast(1, headers) { h ⇒ + inside(h) { + case `Last-Modified`(dt) ⇒ + DateTime(2011, 7, 1) should be < dt + dt.clicks should be < System.currentTimeMillis() + } + } + responseAs[String] shouldEqual "

Lorem ipsum!

" + } + + runCheck() + runCheck() // additional test to check that no internal state is kept + } + "return the resource content from an Archive" in { + Get() ~> getFromResource("com/typesafe/config/Config.class") ~> check { + mediaType shouldEqual `application/octet-stream` + responseEntity.toStrict(1.second).awaitResult(1.second).data.asByteBuffer.getInt shouldEqual 0xCAFEBABE + } + } + "return the file content with MediaType 'application/octet-stream' on unknown file extensions" in { + Get() ~> getFromResource("sample.xyz") ~> check { + mediaType shouldEqual `application/octet-stream` + responseAs[String] shouldEqual "XyZ" + } + } + } + + "getFromResourceDirectory" should { + "reject requests to non-existing resources" in { + Get("not/found") ~> getFromResourceDirectory("subDirectory") ~> check { handled shouldEqual (false) } + } + val verify = check { + mediaType shouldEqual `application/pdf` + responseAs[String] shouldEqual "123" + } + "return the resource content with the MediaType matching the file extension - example 1" in { + Get("empty.pdf") ~> getFromResourceDirectory("subDirectory") ~> verify + } + "return the resource content with the MediaType matching the file extension - example 2" in { + Get("empty.pdf") ~> getFromResourceDirectory("subDirectory/") ~> verify + } + "return the resource content with the MediaType matching the file extension - example 3" in { + Get("subDirectory/empty.pdf") ~> getFromResourceDirectory("") ~> verify + } + "return the resource content from an Archive" in { + Get("Config.class") ~> getFromResourceDirectory("com/typesafe/config") ~> check { + mediaType shouldEqual `application/octet-stream` + responseEntity.toStrict(1.second).awaitResult(1.second).data.asByteBuffer.getInt shouldEqual 0xCAFEBABE + } + } + "reject requests to directory resources" in { + Get() ~> getFromResourceDirectory("subDirectory") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources with trailing slash" in { + Get() ~> getFromResourceDirectory("subDirectory/") ~> check { handled shouldEqual (false) } + } + "reject requests to sub directory resources" in { + Get("sub") ~> getFromResourceDirectory("someDir") ~> check { handled shouldEqual (false) } + } + "reject requests to sub directory resources with trailing slash" in { + Get("sub/") ~> getFromResourceDirectory("someDir") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources from an Archive" in { + Get() ~> getFromResourceDirectory("com/typesafe/config") ~> check { handled shouldEqual (false) } + } + "reject requests to directory resources from an Archive with trailing slash" in { + Get() ~> getFromResourceDirectory("com/typesafe/config/") ~> check { handled shouldEqual (false) } + } + } + + "listDirectoryContents" should { + val base = new File(getClass.getClassLoader.getResource("").toURI).getPath + new File(base, "subDirectory/emptySub").mkdir() + def eraseDateTime(s: String) = s.replaceAll("""\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d""", "xxxx-xx-xx xx:xx:xx") + implicit val settings = RoutingSettings.default.copy(renderVanityFooter = false) + + "properly render a simple directory" in { + Get() ~> listDirectoryContents(base + "/someDir") ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of / + | + |

Index of /

+ |
+ |
+            |sub/             xxxx-xx-xx xx:xx:xx
+            |fileA.txt        xxxx-xx-xx xx:xx:xx            3  B
+            |fileB.xml        xxxx-xx-xx xx:xx:xx            0  B
+            |
+ |
+ | + | + |""" + } + } + } + "properly render a sub directory" in { + Get("/sub/") ~> listDirectoryContents(base + "/someDir") ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of /sub/ + | + |

Index of /sub/

+ |
+ |
+            |../
+            |file.html        xxxx-xx-xx xx:xx:xx            0  B
+            |
+ |
+ | + | + |""" + } + } + } + "properly render the union of several directories" in { + Get() ~> listDirectoryContents(base + "/someDir", base + "/subDirectory") ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of / + | + |

Index of /

+ |
+ |
+            |emptySub/        xxxx-xx-xx xx:xx:xx
+            |sub/             xxxx-xx-xx xx:xx:xx
+            |empty.pdf        xxxx-xx-xx xx:xx:xx            3  B
+            |fileA.txt        xxxx-xx-xx xx:xx:xx            3  B
+            |fileB.xml        xxxx-xx-xx xx:xx:xx            0  B
+            |
+ |
+ | + | + |""" + } + } + } + "properly render an empty sub directory with vanity footer" in { + val settings = 0 // shadow implicit + Get("/emptySub/") ~> listDirectoryContents(base + "/subDirectory") ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of /emptySub/ + | + |

Index of /emptySub/

+ |
+ |
+            |../
+            |
+ |
+ |
+ |rendered by Akka Http on xxxx-xx-xx xx:xx:xx + |
+ | + | + |""" + } + } + } + "properly render an empty top-level directory" in { + Get() ~> listDirectoryContents(base + "/subDirectory/emptySub") ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of / + | + |

Index of /

+ |
+ |
+            |(no files)
+            |
+ |
+ | + | + |""" + } + } + } + "properly render a simple directory with a path prefix" in { + Get("/files/") ~> pathPrefix("files")(listDirectoryContents(base + "/someDir")) ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of /files/ + | + |

Index of /files/

+ |
+ |
+            |sub/             xxxx-xx-xx xx:xx:xx
+            |fileA.txt        xxxx-xx-xx xx:xx:xx            3  B
+            |fileB.xml        xxxx-xx-xx xx:xx:xx            0  B
+            |
+ |
+ | + | + |""" + } + } + } + "properly render a sub directory with a path prefix" in { + Get("/files/sub/") ~> pathPrefix("files")(listDirectoryContents(base + "/someDir")) ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of /files/sub/ + | + |

Index of /files/sub/

+ |
+ |
+            |../
+            |file.html        xxxx-xx-xx xx:xx:xx            0  B
+            |
+ |
+ | + | + |""" + } + } + } + "properly render an empty top-level directory with a path prefix" in { + Get("/files/") ~> pathPrefix("files")(listDirectoryContents(base + "/subDirectory/emptySub")) ~> check { + eraseDateTime(responseAs[String]) shouldEqual prep { + """ + |Index of /files/ + | + |

Index of /files/

+ |
+ |
+            |(no files)
+            |
+ |
+ | + | + |""" + } + } + } + "reject requests to file resources" in { + Get() ~> listDirectoryContents(base + "subDirectory/empty.pdf") ~> check { handled shouldEqual (false) } + } + } + + def prep(s: String) = s.stripMarginWithNewline("\n") + + def writeAllText(text: String, file: File): Unit = { + val fos = new FileOutputStream(file) + try { + fos.write(text.getBytes("UTF-8")) + } finally fos.close() + } + + def evaluateTo[T](t: T, atMost: Duration = 100.millis)(implicit ec: ExecutionContext): Matcher[Future[T]] = + be(t).compose[Future[T]] { fut ⇒ + fut.awaitResult(atMost) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FormFieldDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FormFieldDirectivesSpec.scala new file mode 100644 index 0000000000..7b115b3b08 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FormFieldDirectivesSpec.scala @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.common.StrictForm +import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport +import akka.http.scaladsl.unmarshalling.Unmarshaller.HexInt +import akka.http.scaladsl.model._ +import MediaTypes._ + +class FormFieldDirectivesSpec extends RoutingSpec { + implicit val nodeSeqUnmarshaller = + ScalaXmlSupport.nodeSeqUnmarshaller(`text/xml`, `text/html`, `text/plain`) + + val nodeSeq: xml.NodeSeq = yes + val urlEncodedForm = FormData(Map("firstName" -> "Mike", "age" -> "42")) + val urlEncodedFormWithVip = FormData(Map("firstName" -> "Mike", "age" -> "42", "VIP" -> "true", "super" -> "no")) + val multipartForm = Multipart.FormData { + Map( + "firstName" -> HttpEntity("Mike"), + "age" -> HttpEntity(`text/xml`, "42"), + "VIPBoolean" -> HttpEntity("true")) + } + val multipartFormWithTextHtml = Multipart.FormData { + Map( + "firstName" -> HttpEntity("Mike"), + "age" -> HttpEntity(`text/xml`, "42"), + "VIP" -> HttpEntity(`text/html`, "yes"), + "super" -> HttpEntity("no")) + } + val multipartFormWithFile = Multipart.FormData( + Multipart.FormData.BodyPart.Strict("file", HttpEntity(MediaTypes.`text/xml`, "42"), + Map("filename" -> "age.xml"))) + + "The 'formFields' extraction directive" should { + "properly extract the value of www-urlencoded form fields" in { + Post("/", urlEncodedForm) ~> { + formFields('firstName, "age".as[Int], 'sex?, "VIP" ? false) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { responseAs[String] shouldEqual "Mike42Nonefalse" } + } + "properly extract the value of www-urlencoded form fields when an explicit unmarshaller is given" in { + Post("/", urlEncodedForm) ~> { + formFields('firstName, "age".as(HexInt), 'sex?, "VIP" ? false) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { responseAs[String] shouldEqual "Mike66Nonefalse" } + } + "properly extract the value of multipart form fields" in { + Post("/", multipartForm) ~> { + formFields('firstName, "age", 'sex?, "VIP" ? nodeSeq) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { responseAs[String] shouldEqual "Mike42Noneyes" } + } + "extract StrictForm.FileData from a multipart part" in { + Post("/", multipartFormWithFile) ~> { + formFields('file.as[StrictForm.FileData]) { + case StrictForm.FileData(name, HttpEntity.Strict(ct, data)) ⇒ + complete(s"type ${ct.mediaType} length ${data.length} filename ${name.get}") + } + } ~> check { responseAs[String] shouldEqual "type text/xml length 13 filename age.xml" } + } + "reject the request with a MissingFormFieldRejection if a required form field is missing" in { + Post("/", urlEncodedForm) ~> { + formFields('firstName, "age", 'sex, "VIP" ? false) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { rejection shouldEqual MissingFormFieldRejection("sex") } + } + "properly extract the value if only a urlencoded deserializer is available for a multipart field that comes without a" + + "Content-Type (or text/plain)" in { + Post("/", multipartForm) ~> { + formFields('firstName, "age", 'sex?, "VIPBoolean" ? false) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { + responseAs[String] shouldEqual "Mike42Nonetrue" + } + } + "work even if only a FromStringUnmarshaller is available for a multipart field with custom Content-Type" in { + Post("/", multipartFormWithTextHtml) ~> { + formFields(('firstName, "age", 'super ? false)) { (firstName, age, vip) ⇒ + complete(firstName + age + vip) + } + } ~> check { + responseAs[String] shouldEqual "Mike42false" + } + } + "work even if only a FromEntityUnmarshaller is available for a www-urlencoded field" in { + Post("/", urlEncodedFormWithVip) ~> { + formFields('firstName, "age", 'sex?, "super" ? nodeSeq) { (firstName, age, sex, vip) ⇒ + complete(firstName + age + sex + vip) + } + } ~> check { + responseAs[String] shouldEqual "Mike42Noneno" + } + } + } + "The 'formField' requirement directive" should { + "block requests that do not contain the required formField" in { + Post("/", urlEncodedForm) ~> { + formFields('name ! "Mr. Mike") { completeOk } + } ~> check { handled shouldEqual false } + } + "block requests that contain the required parameter but with an unmatching value" in { + Post("/", urlEncodedForm) ~> { + formFields('firstName ! "Pete") { completeOk } + } ~> check { handled shouldEqual false } + } + "let requests pass that contain the required parameter with its required value" in { + Post("/", urlEncodedForm) ~> { + formFields('firstName ! "Mike") { completeOk } + } ~> check { response shouldEqual Ok } + } + } + + "The 'formField' requirement with explicit unmarshaller directive" should { + "block requests that do not contain the required formField" in { + Post("/", urlEncodedForm) ~> { + formFields('oldAge.as(HexInt) ! 78) { completeOk } + } ~> check { handled shouldEqual false } + } + "block requests that contain the required parameter but with an unmatching value" in { + Post("/", urlEncodedForm) ~> { + formFields('age.as(HexInt) ! 78) { completeOk } + } ~> check { handled shouldEqual false } + } + "let requests pass that contain the required parameter with its required value" in { + Post("/", urlEncodedForm) ~> { + formFields('age.as(HexInt) ! 66 /* hex! */ ) { completeOk } + } ~> check { response shouldEqual Ok } + } + } + +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala new file mode 100644 index 0000000000..6616d9eca0 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.StatusCodes + +import scala.concurrent.Future + +class FutureDirectivesSpec extends RoutingSpec { + + class TestException(msg: String) extends Exception(msg) + object TestException extends Exception("XXX") + def throwTestException[T](msgPrefix: String): T ⇒ Nothing = t ⇒ throw new TestException(msgPrefix + t) + + implicit val exceptionHandler = ExceptionHandler { + case e: TestException ⇒ complete(StatusCodes.InternalServerError, "Oops. " + e) + } + + "The `onComplete` directive" should { + "unwrap a Future in the success case" in { + var i = 0 + def nextNumber() = { i += 1; i } + val route = onComplete(Future.successful(nextNumber())) { echoComplete } + Get() ~> route ~> check { + responseAs[String] shouldEqual "Success(1)" + } + Get() ~> route ~> check { + responseAs[String] shouldEqual "Success(2)" + } + } + "unwrap a Future in the failure case" in { + Get() ~> onComplete(Future.failed[String](new RuntimeException("no"))) { echoComplete } ~> check { + responseAs[String] shouldEqual "Failure(java.lang.RuntimeException: no)" + } + } + "catch an exception in the success case" in { + Get() ~> onComplete(Future.successful("ok")) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$TestException: EX when Success(ok)" + } + } + "catch an exception in the failure case" in { + Get() ~> onComplete(Future.failed[String](new RuntimeException("no"))) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$TestException: EX when Failure(java.lang.RuntimeException: no)" + } + } + } + + "The `onSuccess` directive" should { + "unwrap a Future in the success case" in { + Get() ~> onSuccess(Future.successful("yes")) { echoComplete } ~> check { + responseAs[String] shouldEqual "yes" + } + } + "propagate the exception in the failure case" in { + Get() ~> onSuccess(Future.failed(TestException)) { echoComplete } ~> check { + status shouldEqual StatusCodes.InternalServerError + } + } + "catch an exception in the success case" in { + Get() ~> onSuccess(Future.successful("ok")) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$TestException: EX when ok" + } + } + "catch an exception in the failure case" in { + Get() ~> onSuccess(Future.failed(TestException)) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + } + } + + "The `completeOrRecoverWith` directive" should { + "complete the request with the Future's value if the future succeeds" in { + Get() ~> completeOrRecoverWith(Future.successful("yes")) { echoComplete } ~> check { + responseAs[String] shouldEqual "yes" + } + } + "don't call the inner route if the Future succeeds" in { + Get() ~> completeOrRecoverWith(Future.successful("ok")) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.OK + responseAs[String] shouldEqual "ok" + } + } + "recover using the inner route if the Future fails" in { + val route = completeOrRecoverWith(Future.failed[String](TestException)) { + case e ⇒ complete(s"Exception occurred: ${e.getMessage}") + } + + Get() ~> route ~> check { + responseAs[String] shouldEqual "Exception occurred: XXX" + } + } + "catch an exception during recovery" in { + Get() ~> completeOrRecoverWith(Future.failed[String](TestException)) { throwTestException("EX when ") } ~> check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$TestException: EX when akka.http.scaladsl.server.directives.FutureDirectivesSpec$TestException$: XXX" + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala new file mode 100644 index 0000000000..48cc6f592b --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HeaderDirectivesSpec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import akka.http.scaladsl.model._ +import headers._ +import akka.http.scaladsl.server._ +import org.scalatest.Inside + +class HeaderDirectivesSpec extends RoutingSpec with Inside { + + "The headerValuePF directive" should { + lazy val myHeaderValue = headerValuePF { case Connection(tokens) ⇒ tokens.head } + + "extract the respective header value if a matching request header is present" in { + Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check { + responseAs[String] shouldEqual "close" + } + } + + "reject with an empty rejection set if no matching request header is present" in { + Get("/abc") ~> myHeaderValue { echoComplete } ~> check { rejections shouldEqual Nil } + } + + "reject with a MalformedHeaderRejection if the extract function throws an exception" in { + Get("/abc") ~> addHeader(Connection("close")) ~> { + (headerValuePF { case _ ⇒ sys.error("Naah!") }) { echoComplete } + } ~> check { + inside(rejection) { case MalformedHeaderRejection("Connection", "Naah!", _) ⇒ } + } + } + } + + "The headerValueByType directive" should { + lazy val route = + headerValueByType[Origin]() { origin ⇒ + complete(s"The first origin was ${origin.origins.head}") + } + "extract a header if the type is matching" in { + val originHeader = Origin(HttpOrigin("http://localhost:8080")) + Get("abc") ~> originHeader ~> route ~> check { + responseAs[String] shouldEqual "The first origin was http://localhost:8080" + } + } + "reject a request if no header of the given type is present" in { + Get("abc") ~> route ~> check { + inside(rejection) { + case MissingHeaderRejection("Origin") ⇒ + } + } + } + } + + "The optionalHeaderValue directive" should { + lazy val myHeaderValue = optionalHeaderValue { + case Connection(tokens) ⇒ Some(tokens.head) + case _ ⇒ None + } + + "extract the respective header value if a matching request header is present" in { + Get("/abc") ~> addHeader(Connection("close")) ~> myHeaderValue { echoComplete } ~> check { + responseAs[String] shouldEqual "Some(close)" + } + } + + "extract None if no matching request header is present" in { + Get("/abc") ~> myHeaderValue { echoComplete } ~> check { responseAs[String] shouldEqual "None" } + } + + "reject with a MalformedHeaderRejection if the extract function throws an exception" in { + Get("/abc") ~> addHeader(Connection("close")) ~> { + val myHeaderValue = optionalHeaderValue { case _ ⇒ sys.error("Naaah!") } + myHeaderValue { echoComplete } + } ~> check { + inside(rejection) { case MalformedHeaderRejection("Connection", "Naaah!", _) ⇒ } + } + } + } + + "The optionalHeaderValueByType directive" should { + val route = + optionalHeaderValueByType[Origin]() { + case Some(origin) ⇒ complete(s"The first origin was ${origin.origins.head}") + case None ⇒ complete("No Origin header found.") + } + "extract Some(header) if the type is matching" in { + val originHeader = Origin(HttpOrigin("http://localhost:8080")) + Get("abc") ~> originHeader ~> route ~> check { + responseAs[String] shouldEqual "The first origin was http://localhost:8080" + } + } + "extract None if no header of the given type is present" in { + Get("abc") ~> route ~> check { + responseAs[String] shouldEqual "No Origin header found." + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HostDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HostDirectivesSpec.scala new file mode 100644 index 0000000000..008857897a --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/HostDirectivesSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.headers.Host +import org.scalatest.FreeSpec + +class HostDirectivesSpec extends FreeSpec with GenericRoutingSpec { + "The 'host' directive" - { + "in its simple String form should" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("spray.com") { completeOk } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass" in { + Get() ~> Host("spray.io") ~> { + host("spray.com", "spray.io") { completeOk } + } ~> check { response shouldEqual Ok } + } + } + + "in its simple RegEx form" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("hairspray.*".r) { echoComplete } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass and extract the full host" in { + Get() ~> Host("spray.io") ~> { + host("spra.*".r) { echoComplete } + } ~> check { responseAs[String] shouldEqual "spray.io" } + } + } + + "in its group RegEx form" - { + "block requests to unmatched hosts" in { + Get() ~> Host("spray.io") ~> { + host("hairspray(.*)".r) { echoComplete } + } ~> check { handled shouldEqual false } + } + + "let requests to matching hosts pass and extract the full host" in { + Get() ~> Host("spray.io") ~> { + host("spra(.*)".r) { echoComplete } + } ~> check { responseAs[String] shouldEqual "y.io" } + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala new file mode 100644 index 0000000000..f27650a7fc --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.xml.NodeSeq +import org.scalatest.Inside +import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport +import akka.http.scaladsl.unmarshalling._ +import akka.http.scaladsl.marshalling._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ +import MediaTypes._ +import HttpCharsets._ +import headers._ +import spray.json.DefaultJsonProtocol._ + +class MarshallingDirectivesSpec extends RoutingSpec with Inside { + import ScalaXmlSupport._ + + private val iso88592 = HttpCharsets.getForKey("iso-8859-2").get + implicit val IntUnmarshaller: FromEntityUnmarshaller[Int] = + nodeSeqUnmarshaller(ContentTypeRange(`text/xml`, iso88592), `text/html`) map { + case NodeSeq.Empty ⇒ throw Unmarshaller.NoContentException + case x ⇒ { val i = x.text.toInt; require(i >= 0); i } + } + + implicit val IntMarshaller: ToEntityMarshaller[Int] = + Marshaller.oneOf(ContentType(`application/xhtml+xml`), ContentType(`text/xml`, `UTF-8`)) { contentType ⇒ + nodeSeqMarshaller(contentType).wrap(contentType) { (i: Int) ⇒ { i } } + } + + "The 'entityAs' directive" should { + "extract an object from the requests entity using the in-scope Unmarshaller" in { + Put("/",

cool

) ~> { + entity(as[NodeSeq]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "

cool

" } + } + "return a RequestEntityExpectedRejection rejection if the request has no entity" in { + Put() ~> { + entity(as[Int]) { echoComplete } + } ~> check { rejection shouldEqual RequestEntityExpectedRejection } + } + "return an UnsupportedRequestContentTypeRejection if no matching unmarshaller is in scope" in { + Put("/", HttpEntity(`text/css`, "

cool

")) ~> { + entity(as[NodeSeq]) { echoComplete } + } ~> check { + rejection shouldEqual UnsupportedRequestContentTypeRejection(Set(`text/xml`, `application/xml`, `text/html`, `application/xhtml+xml`)) + } + Put("/", HttpEntity(ContentType(`text/xml`, `UTF-16`), "26")) ~> { + entity(as[Int]) { echoComplete } + } ~> check { + rejection shouldEqual UnsupportedRequestContentTypeRejection(Set(ContentTypeRange(`text/xml`, iso88592), `text/html`)) + } + } + "cancel UnsupportedRequestContentTypeRejections if a subsequent `entity` directive succeeds" in { + Put("/", HttpEntity(`text/plain`, "yeah")) ~> { + entity(as[NodeSeq]) { _ ⇒ completeOk } ~ + entity(as[String]) { _ ⇒ validate(false, "Problem") { completeOk } } + } ~> check { rejection shouldEqual ValidationRejection("Problem") } + } + "return a ValidationRejection if the request entity is semantically invalid (IllegalArgumentException)" in { + Put("/", HttpEntity(ContentType(`text/xml`, iso88592), "-3")) ~> { + entity(as[Int]) { _ ⇒ completeOk } + } ~> check { + inside(rejection) { + case ValidationRejection("requirement failed", Some(_: IllegalArgumentException)) ⇒ + } + } + } + "return a MalformedRequestContentRejection if unmarshalling failed due to a not further classified error" in { + Put("/", HttpEntity(`text/xml`, " { + entity(as[NodeSeq]) { _ ⇒ completeOk } + } ~> check { + rejection shouldEqual MalformedRequestContentRejection( + "XML document structures must start and end within the same entity.", None) + } + } + "extract an Option[T] from the requests entity using the in-scope Unmarshaller" in { + Put("/",

cool

) ~> { + entity(as[Option[NodeSeq]]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "Some(

cool

)" } + } + "extract an Option[T] as None if the request has no entity" in { + Put() ~> { + entity(as[Option[Int]]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "None" } + } + "return an UnsupportedRequestContentTypeRejection if no matching unmarshaller is in scope (for Option[T]s)" in { + Put("/", HttpEntity(`text/css`, "

cool

")) ~> { + entity(as[Option[NodeSeq]]) { echoComplete } + } ~> check { + rejection shouldEqual UnsupportedRequestContentTypeRejection(Set(`text/xml`, `application/xml`, `text/html`, `application/xhtml+xml`)) + } + } + "properly extract with a super-unmarshaller" in { + case class Person(name: String) + val jsonUnmarshaller: FromEntityUnmarshaller[Person] = jsonFormat1(Person) + val xmlUnmarshaller: FromEntityUnmarshaller[Person] = + ScalaXmlSupport.nodeSeqUnmarshaller(`text/xml`).map(seq ⇒ Person(seq.text)) + + implicit val unmarshaller = Unmarshaller.firstOf(jsonUnmarshaller, xmlUnmarshaller) + + val route = entity(as[Person]) { echoComplete } + + Put("/", HttpEntity(`text/xml`, "Peter Xml")) ~> route ~> check { + responseAs[String] shouldEqual "Person(Peter Xml)" + } + Put("/", HttpEntity(`application/json`, """{ "name": "Paul Json" }""")) ~> route ~> check { + responseAs[String] shouldEqual "Person(Paul Json)" + } + Put("/", HttpEntity(`text/plain`, """name = Sir Text }""")) ~> route ~> check { + rejection shouldEqual UnsupportedRequestContentTypeRejection(Set(`application/json`, `text/xml`)) + } + } + } + + "The 'completeWith' directive" should { + "provide a completion function converting custom objects to an HttpEntity using the in-scope marshaller" in { + Get() ~> completeWith(instanceOf[Int]) { prod ⇒ prod(42) } ~> check { + responseEntity shouldEqual HttpEntity(ContentType(`application/xhtml+xml`, `UTF-8`), "42") + } + } + "return a UnacceptedResponseContentTypeRejection rejection if no acceptable marshaller is in scope" in { + Get() ~> Accept(`text/css`) ~> completeWith(instanceOf[Int]) { prod ⇒ prod(42) } ~> check { + rejection shouldEqual UnacceptedResponseContentTypeRejection(Set(`application/xhtml+xml`, ContentType(`text/xml`, `UTF-8`))) + } + } + "convert the response content to an accepted charset" in { + Get() ~> `Accept-Charset`(`UTF-8`) ~> completeWith(instanceOf[String]) { prod ⇒ prod("Hällö") } ~> check { + responseEntity shouldEqual HttpEntity(ContentType(`text/plain`, `UTF-8`), "Hällö") + } + } + } + + "The 'handleWith' directive" should { + def times2(x: Int) = x * 2 + + "support proper round-trip content unmarshalling/marshalling to and from a function" in ( + Put("/", HttpEntity(`text/html`, "42")) ~> Accept(`text/xml`) ~> handleWith(times2) + ~> check { responseEntity shouldEqual HttpEntity(ContentType(`text/xml`, `UTF-8`), "84") }) + + "result in UnsupportedRequestContentTypeRejection rejection if there is no unmarshaller supporting the requests charset" in ( + Put("/", HttpEntity(`text/xml`, "42")) ~> Accept(`text/xml`) ~> handleWith(times2) + ~> check { + rejection shouldEqual UnsupportedRequestContentTypeRejection(Set(ContentTypeRange(`text/xml`, iso88592), `text/html`)) + }) + + "result in an UnacceptedResponseContentTypeRejection rejection if there is no marshaller supporting the requests Accept-Charset header" in ( + Put("/", HttpEntity(`text/html`, "42")) ~> addHeaders(Accept(`text/xml`), `Accept-Charset`(`UTF-16`)) ~> + handleWith(times2) ~> check { + rejection shouldEqual UnacceptedResponseContentTypeRejection(Set(`application/xhtml+xml`, ContentType(`text/xml`, `UTF-8`))) + }) + } + + "The marshalling infrastructure for JSON" should { + import spray.json._ + case class Foo(name: String) + implicit val fooFormat = jsonFormat1(Foo) + val foo = Foo("Hällö") + + "render JSON with UTF-8 encoding if no `Accept-Charset` request header is present" in { + Get() ~> complete(foo) ~> check { + responseEntity shouldEqual HttpEntity(ContentType(`application/json`, `UTF-8`), foo.toJson.prettyPrint) + } + } + "reject JSON rendering if an `Accept-Charset` request header requests a non-UTF-8 encoding" in { + Get() ~> `Accept-Charset`(`ISO-8859-1`) ~> complete(foo) ~> check { + rejection shouldEqual UnacceptedResponseContentTypeRejection(Set(ContentType(`application/json`))) + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MethodDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MethodDirectivesSpec.scala new file mode 100644 index 0000000000..a8d84eae14 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MethodDirectivesSpec.scala @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import akka.http.scaladsl.model.{ StatusCodes, HttpMethods } +import akka.http.scaladsl.server._ + +class MethodDirectivesSpec extends RoutingSpec { + + "get | put" should { + lazy val getOrPut = (get | put) { completeOk } + + "block POST requests" in { + Post() ~> getOrPut ~> check { handled shouldEqual false } + } + "let GET requests pass" in { + Get() ~> getOrPut ~> check { response shouldEqual Ok } + } + "let PUT requests pass" in { + Put() ~> getOrPut ~> check { response shouldEqual Ok } + } + } + + "two failed `get` directives" should { + "only result in a single Rejection" in { + Put() ~> { + get { completeOk } ~ + get { completeOk } + } ~> check { + rejections shouldEqual List(MethodRejection(HttpMethods.GET)) + } + } + } + + "overrideMethodWithParameter" should { + "change the request method" in { + Get("/?_method=put") ~> overrideMethodWithParameter("_method") { + get { complete("GET") } ~ + put { complete("PUT") } + } ~> check { responseAs[String] shouldEqual "PUT" } + } + "not affect the request when not specified" in { + Get() ~> overrideMethodWithParameter("_method") { + get { complete("GET") } ~ + put { complete("PUT") } + } ~> check { responseAs[String] shouldEqual "GET" } + } + "complete with 501 Not Implemented when not a valid method" in { + Get("/?_method=hallo") ~> overrideMethodWithParameter("_method") { + get { complete("GET") } ~ + put { complete("PUT") } + } ~> check { status shouldEqual StatusCodes.NotImplemented } + } + } + + "MethodRejections under a successful match" should { + "be cancelled if the match happens after the rejection" in { + Put() ~> { + get { completeOk } ~ + put { reject(RequestEntityExpectedRejection) } + } ~> check { + rejections shouldEqual List(RequestEntityExpectedRejection) + } + } + "be cancelled if the match happens after the rejection (example 2)" in { + Put() ~> { + (get & complete(Ok)) ~ (put & reject(RequestEntityExpectedRejection)) + } ~> check { + rejections shouldEqual List(RequestEntityExpectedRejection) + } + } + "be cancelled if the match happens before the rejection" in { + Put() ~> { + put { reject(RequestEntityExpectedRejection) } ~ get { completeOk } + } ~> check { + rejections shouldEqual List(RequestEntityExpectedRejection) + } + } + "be cancelled if the match happens before the rejection (example 2)" in { + Put() ~> { + (put & reject(RequestEntityExpectedRejection)) ~ (get & complete(Ok)) + } ~> check { + rejections shouldEqual List(RequestEntityExpectedRejection) + } + } + } +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MiscDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MiscDirectivesSpec.scala new file mode 100644 index 0000000000..3c4b066e40 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MiscDirectivesSpec.scala @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import headers._ +import HttpMethods._ +import MediaTypes._ +import Uri._ + +class MiscDirectivesSpec extends RoutingSpec { + + "the extractClientIP directive" should { + "extract from a X-Forwarded-For header" in { + Get() ~> addHeaders(`X-Forwarded-For`("2.3.4.5"), RawHeader("x-real-ip", "1.2.3.4")) ~> { + extractClientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "2.3.4.5" } + } + "extract from a Remote-Address header" in { + Get() ~> addHeaders(RawHeader("x-real-ip", "1.2.3.4"), `Remote-Address`(RemoteAddress("5.6.7.8"))) ~> { + extractClientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "5.6.7.8" } + } + "extract from a X-Real-IP header" in { + Get() ~> addHeader(RawHeader("x-real-ip", "1.2.3.4")) ~> { + extractClientIP { echoComplete } + } ~> check { responseAs[String] shouldEqual "1.2.3.4" } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala new file mode 100644 index 0000000000..eded50e662 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import org.scalatest.{ FreeSpec, Inside } +import akka.http.scaladsl.unmarshalling.Unmarshaller.HexInt + +class ParameterDirectivesSpec extends FreeSpec with GenericRoutingSpec with Inside { + + "when used with 'as[Int]' the parameter directive should" - { + "extract a parameter value as Int" in { + Get("/?amount=123") ~> { + parameter('amount.as[Int]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "123" } + } + "cause a MalformedQueryParamRejection on illegal Int values" in { + Get("/?amount=1x3") ~> { + parameter('amount.as[Int]) { echoComplete } + } ~> check { + inside(rejection) { + case MalformedQueryParamRejection("amount", "'1x3' is not a valid 32-bit signed integer value", Some(_)) ⇒ + } + } + } + "supply typed default values" in { + Get() ~> { + parameter('amount ? 45) { echoComplete } + } ~> check { responseAs[String] shouldEqual "45" } + } + "create typed optional parameters that" - { + "extract Some(value) when present" in { + Get("/?amount=12") ~> { + parameter("amount".as[Int]?) { echoComplete } + } ~> check { responseAs[String] shouldEqual "Some(12)" } + } + "extract None when not present" in { + Get() ~> { + parameter("amount".as[Int]?) { echoComplete } + } ~> check { responseAs[String] shouldEqual "None" } + } + "cause a MalformedQueryParamRejection on illegal Int values" in { + Get("/?amount=x") ~> { + parameter("amount".as[Int]?) { echoComplete } + } ~> check { + inside(rejection) { + case MalformedQueryParamRejection("amount", "'x' is not a valid 32-bit signed integer value", Some(_)) ⇒ + } + } + } + } + } + + "when used with 'as(HexInt)' the parameter directive should" - { + "extract parameter values as Int" in { + Get("/?amount=1f") ~> { + parameter('amount.as(HexInt)) { echoComplete } + } ~> check { responseAs[String] shouldEqual "31" } + } + "cause a MalformedQueryParamRejection on illegal Int values" in { + Get("/?amount=1x3") ~> { + parameter('amount.as(HexInt)) { echoComplete } + } ~> check { + inside(rejection) { + case MalformedQueryParamRejection("amount", "'1x3' is not a valid 32-bit hexadecimal integer value", Some(_)) ⇒ + } + } + } + "supply typed default values" in { + Get() ~> { + parameter('amount.as(HexInt) ? 45) { echoComplete } + } ~> check { responseAs[String] shouldEqual "45" } + } + "create typed optional parameters that" - { + "extract Some(value) when present" in { + Get("/?amount=A") ~> { + parameter("amount".as(HexInt)?) { echoComplete } + } ~> check { responseAs[String] shouldEqual "Some(10)" } + } + "extract None when not present" in { + Get() ~> { + parameter("amount".as(HexInt)?) { echoComplete } + } ~> check { responseAs[String] shouldEqual "None" } + } + "cause a MalformedQueryParamRejection on illegal Int values" in { + Get("/?amount=x") ~> { + parameter("amount".as(HexInt)?) { echoComplete } + } ~> check { + inside(rejection) { + case MalformedQueryParamRejection("amount", "'x' is not a valid 32-bit hexadecimal integer value", Some(_)) ⇒ + } + } + } + } + } + + "when used with 'as[Boolean]' the parameter directive should" - { + "extract parameter values as Boolean" in { + Get("/?really=true") ~> { + parameter('really.as[Boolean]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "true" } + Get("/?really=no") ~> { + parameter('really.as[Boolean]) { echoComplete } + } ~> check { responseAs[String] shouldEqual "false" } + } + "extract optional parameter values as Boolean" in { + Get() ~> { + parameter('really.as[Boolean] ? false) { echoComplete } + } ~> check { responseAs[String] shouldEqual "false" } + } + "cause a MalformedQueryParamRejection on illegal Boolean values" in { + Get("/?really=absolutely") ~> { + parameter('really.as[Boolean]) { echoComplete } + } ~> check { + inside(rejection) { + case MalformedQueryParamRejection("really", "'absolutely' is not a valid Boolean value", None) ⇒ + } + } + } + } + + "The 'parameters' extraction directive should" - { + "extract the value of given parameters" in { + Get("/?name=Parsons&FirstName=Ellen") ~> { + parameters("name", 'FirstName) { (name, firstName) ⇒ + complete(firstName + name) + } + } ~> check { responseAs[String] shouldEqual "EllenParsons" } + } + "correctly extract an optional parameter" in { + Get("/?foo=bar") ~> parameters('foo ?) { echoComplete } ~> check { responseAs[String] shouldEqual "Some(bar)" } + Get("/?foo=bar") ~> parameters('baz ?) { echoComplete } ~> check { responseAs[String] shouldEqual "None" } + } + "ignore additional parameters" in { + Get("/?name=Parsons&FirstName=Ellen&age=29") ~> { + parameters("name", 'FirstName) { (name, firstName) ⇒ + complete(firstName + name) + } + } ~> check { responseAs[String] shouldEqual "EllenParsons" } + } + "reject the request with a MissingQueryParamRejection if a required parameter is missing" in { + Get("/?name=Parsons&sex=female") ~> { + parameters('name, 'FirstName, 'age) { (name, firstName, age) ⇒ + completeOk + } + } ~> check { rejection shouldEqual MissingQueryParamRejection("FirstName") } + } + "supply the default value if an optional parameter is missing" in { + Get("/?name=Parsons&FirstName=Ellen") ~> { + parameters("name"?, 'FirstName, 'age ? "29", 'eyes?) { (name, firstName, age, eyes) ⇒ + complete(firstName + name + age + eyes) + } + } ~> check { responseAs[String] shouldEqual "EllenSome(Parsons)29None" } + } + } + + "The 'parameter' requirement directive should" - { + "block requests that do not contain the required parameter" in { + Get("/person?age=19") ~> { + parameter('nose ! "large") { completeOk } + } ~> check { handled shouldEqual false } + } + "block requests that contain the required parameter but with an unmatching value" in { + Get("/person?age=19&nose=small") ~> { + parameter('nose ! "large") { completeOk } + } ~> check { handled shouldEqual false } + } + "let requests pass that contain the required parameter with its required value" in { + Get("/person?nose=large&eyes=blue") ~> { + parameter('nose ! "large") { completeOk } + } ~> check { response shouldEqual Ok } + } + "be useable for method tunneling" in { + val route = { + (post | parameter('method ! "post")) { complete("POST") } ~ + get { complete("GET") } + } + Get("/?method=post") ~> route ~> check { responseAs[String] shouldEqual "POST" } + Post() ~> route ~> check { responseAs[String] shouldEqual "POST" } + Get() ~> route ~> check { responseAs[String] shouldEqual "GET" } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/PathDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/PathDirectivesSpec.scala new file mode 100644 index 0000000000..812167e6d9 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/PathDirectivesSpec.scala @@ -0,0 +1,339 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import akka.http.scaladsl.server._ +import org.scalatest.Inside + +class PathDirectivesSpec extends RoutingSpec with Inside { + val echoUnmatchedPath = extractUnmatchedPath { echoComplete } + def echoCaptureAndUnmatchedPath[T]: T ⇒ Route = + capture ⇒ ctx ⇒ ctx.complete(capture.toString + ":" + ctx.unmatchedPath) + + """path("foo")""" should { + val test = testFor(path("foo") { echoUnmatchedPath }) + "reject [/bar]" in test() + "reject [/foobar]" in test() + "reject [/foo/bar]" in test() + "accept [/foo] and clear the unmatchedPath" in test("") + "reject [/foo/]" in test() + } + + """path("foo" /)""" should { + val test = testFor(path("foo" /) { echoUnmatchedPath }) + "reject [/foo]" in test() + "accept [/foo/] and clear the unmatchedPath" in test("") + } + + """path("")""" should { + val test = testFor(path("") { echoUnmatchedPath }) + "reject [/foo]" in test() + "accept [/] and clear the unmatchedPath" in test("") + } + + """pathPrefix("foo")""" should { + val test = testFor(pathPrefix("foo") { echoUnmatchedPath }) + "reject [/bar]" in test() + "accept [/foobar]" in test("bar") + "accept [/foo/bar]" in test("/bar") + "accept [/foo] and clear the unmatchedPath" in test("") + "accept [/foo/] and clear the unmatchedPath" in test("/") + } + + """pathPrefix("foo" / "bar")""" should { + val test = testFor(pathPrefix("foo" / "bar") { echoUnmatchedPath }) + "reject [/bar]" in test() + "accept [/foo/bar]" in test("") + "accept [/foo/bar/baz]" in test("/baz") + } + + """pathPrefix("ab[cd]+".r)""" should { + val test = testFor(pathPrefix("ab[cd]+".r) { echoCaptureAndUnmatchedPath }) + "reject [/bar]" in test() + "reject [/ab/cd]" in test() + "reject [/abcdef]" in test("abcd:ef") + "reject [/abcdd/ef]" in test("abcdd:/ef") + } + + """pathPrefix("ab(cd)".r)""" should { + val test = testFor(pathPrefix("ab(cd)+".r) { echoCaptureAndUnmatchedPath }) + "reject [/bar]" in test() + "reject [/ab/cd]" in test() + "reject [/abcdef]" in test("cd:ef") + "reject [/abcde/fg]" in test("cd:e/fg") + } + + "pathPrefix(regex)" should { + "fail when the regex contains more than one group" in { + an[IllegalArgumentException] must be thrownBy path("a(b+)(c+)".r) { echoCaptureAndUnmatchedPath } + } + } + + "pathPrefix(IntNumber)" should { + val test = testFor(pathPrefix(IntNumber) { echoCaptureAndUnmatchedPath }) + "accept [/23]" in test("23:") + "accept [/12345yes]" in test("12345:yes") + "reject [/]" in test() + "reject [/abc]" in test() + "reject [/2147483648]" in test() // > Int.MaxValue + } + + "pathPrefix(CustomShortNumber)" should { + object CustomShortNumber extends NumberMatcher[Short](Short.MaxValue, 10) { + def fromChar(c: Char) = fromDecimalChar(c) + } + + val test = testFor(pathPrefix(CustomShortNumber) { echoCaptureAndUnmatchedPath }) + "accept [/23]" in test("23:") + "accept [/12345yes]" in test("12345:yes") + "reject [/]" in test() + "reject [/abc]" in test() + "reject [/33000]" in test() // > Short.MaxValue + } + + "pathPrefix(JavaUUID)" should { + val test = testFor(pathPrefix(JavaUUID) { echoCaptureAndUnmatchedPath }) + "accept [/bdea8652-f26c-40ca-8157-0b96a2a8389d]" in test("bdea8652-f26c-40ca-8157-0b96a2a8389d:") + "accept [/bdea8652-f26c-40ca-8157-0b96a2a8389dyes]" in test("bdea8652-f26c-40ca-8157-0b96a2a8389d:yes") + "reject [/]" in test() + "reject [/abc]" in test() + } + + "pathPrefix(Map(\"red\" -> 1, \"green\" -> 2, \"blue\" -> 3))" should { + val test = testFor(pathPrefix(Map("red" -> 1, "green" -> 2, "blue" -> 3)) { echoCaptureAndUnmatchedPath }) + "accept [/green]" in test("2:") + "accept [/redsea]" in test("1:sea") + "reject [/black]" in test() + } + + "pathPrefix(Map.empty)" should { + val test = testFor(pathPrefix(Map[String, Int]()) { echoCaptureAndUnmatchedPath }) + "reject [/black]" in test() + } + + "pathPrefix(Segment)" should { + val test = testFor(pathPrefix(Segment) { echoCaptureAndUnmatchedPath }) + "accept [/abc]" in test("abc:") + "accept [/abc/]" in test("abc:/") + "accept [/abc/def]" in test("abc:/def") + "reject [/]" in test() + } + + "pathPrefix(Segments)" should { + val test = testFor(pathPrefix(Segments) { echoCaptureAndUnmatchedPath }) + "accept [/]" in test("List():") + "accept [/a/b/c]" in test("List(a, b, c):") + "accept [/a/b/c/]" in test("List(a, b, c):/") + } + + """pathPrefix(separateOnSlashes("a/b"))""" should { + val test = testFor(pathPrefix(separateOnSlashes("a/b")) { echoUnmatchedPath }) + "accept [/a/b]" in test("") + "accept [/a/b/]" in test("/") + "accept [/a/c]" in test() + } + """pathPrefix(separateOnSlashes("abc"))""" should { + val test = testFor(pathPrefix(separateOnSlashes("abc")) { echoUnmatchedPath }) + "accept [/abc]" in test("") + "accept [/abcdef]" in test("def") + "accept [/ab]" in test() + } + + """pathPrefixTest("a" / Segment ~ Slash)""" should { + val test = testFor(pathPrefixTest("a" / Segment ~ Slash) { echoCaptureAndUnmatchedPath }) + "accept [/a/bc/]" in test("bc:/a/bc/") + "accept [/a/bc]" in test() + "accept [/a/]" in test() + } + + """pathSuffix("edit" / Segment)""" should { + val test = testFor(pathSuffix("edit" / Segment) { echoCaptureAndUnmatchedPath }) + "accept [/orders/123/edit]" in test("123:/orders/") + "accept [/orders/123/ed]" in test() + "accept [/edit]" in test() + } + + """pathSuffix("foo" / "bar" ~ "baz")""" should { + val test = testFor(pathSuffix("foo" / "bar" ~ "baz") { echoUnmatchedPath }) + "accept [/orders/barbaz/foo]" in test("/orders/") + "accept [/orders/bazbar/foo]" in test() + } + + "pathSuffixTest(Slash)" should { + val test = testFor(pathSuffixTest(Slash) { echoUnmatchedPath }) + "accept [/]" in test("/") + "accept [/foo/]" in test("/foo/") + "accept [/foo]" in test() + } + + """pathPrefix("foo" | "bar")""" should { + val test = testFor(pathPrefix("foo" | "bar") { echoUnmatchedPath }) + "accept [/foo]" in test("") + "accept [/foops]" in test("ps") + "accept [/bar]" in test("") + "reject [/baz]" in test() + } + + """pathSuffix(!"foo")""" should { + val test = testFor(pathSuffix(!"foo") { echoUnmatchedPath }) + "accept [/bar]" in test("/bar") + "reject [/foo]" in test() + } + + "pathPrefix(IntNumber?)" should { + val test = testFor(pathPrefix(IntNumber?) { echoCaptureAndUnmatchedPath }) + "accept [/12]" in test("Some(12):") + "accept [/12a]" in test("Some(12):a") + "accept [/foo]" in test("None:foo") + } + + """pathPrefix("foo"?)""" should { + val test = testFor(pathPrefix("foo"?) { echoUnmatchedPath }) + "accept [/foo]" in test("") + "accept [/fool]" in test("l") + "accept [/bar]" in test("bar") + } + + """pathPrefix("foo") & pathEnd""" should { + val test = testFor((pathPrefix("foo") & pathEnd) { echoUnmatchedPath }) + "reject [/foobar]" in test() + "reject [/foo/bar]" in test() + "accept [/foo] and clear the unmatchedPath" in test("") + "reject [/foo/]" in test() + } + + """pathPrefix("foo") & pathEndOrSingleSlash""" should { + val test = testFor((pathPrefix("foo") & pathEndOrSingleSlash) { echoUnmatchedPath }) + "reject [/foobar]" in test() + "reject [/foo/bar]" in test() + "accept [/foo] and clear the unmatchedPath" in test("") + "accept [/foo/] and clear the unmatchedPath" in test("") + } + + """pathPrefix(IntNumber.repeat(separator = "."))""" should { + { + val test = testFor(pathPrefix(IntNumber.repeat(min = 2, max = 5, separator = ".")) { echoCaptureAndUnmatchedPath }) + "reject [/foo]" in test() + "reject [/1foo]" in test() + "reject [/1.foo]" in test() + "accept [/1.2foo]" in test("List(1, 2):foo") + "accept [/1.2.foo]" in test("List(1, 2):.foo") + "accept [/1.2.3foo]" in test("List(1, 2, 3):foo") + "accept [/1.2.3.foo]" in test("List(1, 2, 3):.foo") + "accept [/1.2.3.4foo]" in test("List(1, 2, 3, 4):foo") + "accept [/1.2.3.4.foo]" in test("List(1, 2, 3, 4):.foo") + "accept [/1.2.3.4.5foo]" in test("List(1, 2, 3, 4, 5):foo") + "accept [/1.2.3.4.5.foo]" in test("List(1, 2, 3, 4, 5):.foo") + "accept [/1.2.3.4.5.6foo]" in test("List(1, 2, 3, 4, 5):.6foo") + "accept [/1.2.3.]" in test("List(1, 2, 3):.") + "accept [/1.2.3/]" in test("List(1, 2, 3):/") + "accept [/1.2.3./]" in test("List(1, 2, 3):./") + } + { + val test = testFor(pathPrefix(IntNumber.repeat(2, ".")) { echoCaptureAndUnmatchedPath }) + "reject [/bar]" in test() + "reject [/1bar]" in test() + "reject [/1.bar]" in test() + "accept [/1.2bar]" in test("List(1, 2):bar") + "accept [/1.2.bar]" in test("List(1, 2):.bar") + "accept [/1.2.3bar]" in test("List(1, 2):.3bar") + } + } + + "PathMatchers" should { + { + val test = testFor(path(Rest.tmap { case Tuple1(s) ⇒ Tuple1(s.split('-').toList) }) { echoComplete }) + "support the hmap modifier in accept [/yes-no]" in test("List(yes, no)") + } + { + val test = testFor(path(Rest.map(_.split('-').toList)) { echoComplete }) + "support the map modifier in accept [/yes-no]" in test("List(yes, no)") + } + { + val test = testFor(path(Rest.tflatMap { case Tuple1(s) ⇒ Some(s).filter("yes" ==).map(x ⇒ Tuple1(x)) }) { echoComplete }) + "support the hflatMap modifier in accept [/yes]" in test("yes") + "support the hflatMap modifier in reject [/blub]" in test() + } + { + val test = testFor(path(Rest.flatMap(s ⇒ Some(s).filter("yes" ==))) { echoComplete }) + "support the flatMap modifier in accept [/yes]" in test("yes") + "support the flatMap modifier reject [/blub]" in test() + } + } + + implicit class WithIn(str: String) { + def in(f: String ⇒ Unit) = convertToWordSpecStringWrapper(str) in f(str) + def in(body: ⇒ Unit) = convertToWordSpecStringWrapper(str) in body + } + + case class testFor(route: Route) { + def apply(expectedResponse: String = null): String ⇒ Unit = exampleString ⇒ + "\\[([^\\]]+)\\]".r.findFirstMatchIn(exampleString) match { + case Some(uri) ⇒ Get(uri.group(1)) ~> route ~> check { + if (expectedResponse eq null) handled shouldEqual false + else responseAs[String] shouldEqual expectedResponse + } + case None ⇒ failTest("Example '" + exampleString + "' doesn't contain a test uri") + } + } + + import akka.http.scaladsl.model.StatusCodes._ + + "the `redirectToTrailingSlashIfMissing` directive" should { + val route = redirectToTrailingSlashIfMissing(Found) { completeOk } + + "pass if the request path already has a trailing slash" in { + Get("/foo/bar/") ~> route ~> check { response shouldEqual Ok } + } + + "redirect if the request path doesn't have a trailing slash" in { + Get("/foo/bar") ~> route ~> checkRedirectTo("/foo/bar/") + } + + "preserves the query and the frag when redirect" in { + Get("/foo/bar?query#frag") ~> route ~> checkRedirectTo("/foo/bar/?query#frag") + } + + "redirect with the given redirection status code" in { + Get("/foo/bar") ~> + redirectToTrailingSlashIfMissing(MovedPermanently) { completeOk } ~> + check { status shouldEqual MovedPermanently } + } + } + + "the `redirectToNoTrailingSlashIfPresent` directive" should { + val route = redirectToNoTrailingSlashIfPresent(Found) { completeOk } + + "pass if the request path already doesn't have a trailing slash" in { + Get("/foo/bar") ~> route ~> check { response shouldEqual Ok } + } + + "redirect if the request path has a trailing slash" in { + Get("/foo/bar/") ~> route ~> checkRedirectTo("/foo/bar") + } + + "preserves the query and the frag when redirect" in { + Get("/foo/bar/?query#frag") ~> route ~> checkRedirectTo("/foo/bar?query#frag") + } + + "redirect with the given redirection status code" in { + Get("/foo/bar/") ~> + redirectToNoTrailingSlashIfPresent(MovedPermanently) { completeOk } ~> + check { status shouldEqual MovedPermanently } + } + } + + import akka.http.scaladsl.model.headers.Location + import akka.http.scaladsl.model.Uri + + private def checkRedirectTo(expectedUri: Uri) = + check { + status shouldBe a[Redirection] + inside(header[Location]) { + case Some(Location(uri)) ⇒ + (if (expectedUri.isAbsolute) uri else uri.toRelative) shouldEqual expectedUri + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RangeDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RangeDirectivesSpec.scala new file mode 100644 index 0000000000..dc678809ef --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RangeDirectivesSpec.scala @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.Await +import scala.concurrent.duration._ +import akka.http.scaladsl.model.StatusCodes._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.impl.util._ +import akka.stream.scaladsl.{ Sink, Source } +import akka.util.ByteString +import org.scalatest.{ Inside, Inspectors } + +class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { + lazy val wrs = + mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) & + withRangeSupport + + def bytes(length: Byte) = Array.tabulate[Byte](length)(_.toByte) + + "The `withRangeSupport` directive" should { + def completeWithRangedBytes(length: Byte) = wrs(complete(bytes(length))) + + "return an Accept-Ranges(bytes) header for GET requests" in { + Get() ~> { wrs { complete("any") } } ~> check { + headers should contain(`Accept-Ranges`(RangeUnits.Bytes)) + } + } + + "not return an Accept-Ranges(bytes) header for non-GET requests" in { + Put() ~> { wrs { complete("any") } } ~> check { + headers should not contain `Accept-Ranges`(RangeUnits.Bytes) + } + } + + "return a Content-Range header for a ranged request with a single range" in { + Get() ~> addHeader(Range(ByteRange(0, 1))) ~> completeWithRangedBytes(10) ~> check { + headers should contain(`Content-Range`(ContentRange(0, 1, 10))) + status shouldEqual PartialContent + responseAs[Array[Byte]] shouldEqual bytes(2) + } + } + + "return a partial response for a ranged request with a single range with undefined lastBytePosition" in { + Get() ~> addHeader(Range(ByteRange.fromOffset(5))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual Array[Byte](5, 6, 7, 8, 9) + } + } + + "return a partial response for a ranged request with a single suffix range" in { + Get() ~> addHeader(Range(ByteRange.suffix(1))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual Array[Byte](9) + } + } + + "return a partial response for a ranged request with a overlapping suffix range" in { + Get() ~> addHeader(Range(ByteRange.suffix(100))) ~> completeWithRangedBytes(10) ~> check { + responseAs[Array[Byte]] shouldEqual bytes(10) + } + } + + "be transparent to non-GET requests" in { + Post() ~> addHeader(Range(ByteRange(1, 2))) ~> completeWithRangedBytes(5) ~> check { + responseAs[Array[Byte]] shouldEqual bytes(5) + } + } + + "be transparent to non-200 responses" in { + Get() ~> addHeader(Range(ByteRange(1, 2))) ~> Route.seal(wrs(reject())) ~> check { + status == NotFound + headers.exists { case `Content-Range`(_, _) ⇒ true; case _ ⇒ false } shouldEqual false + } + } + + "reject an unsatisfiable single range" in { + Get() ~> addHeader(Range(ByteRange(100, 200))) ~> completeWithRangedBytes(10) ~> check { + rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange(100, 200)), 10) + } + } + + "reject an unsatisfiable single suffix range with length 0" in { + Get() ~> addHeader(Range(ByteRange.suffix(0))) ~> completeWithRangedBytes(42) ~> check { + rejection shouldEqual UnsatisfiableRangeRejection(Seq(ByteRange.suffix(0)), 42) + } + } + + "return a mediaType of 'multipart/byteranges' for a ranged request with multiple ranges" in { + Get() ~> addHeader(Range(ByteRange(0, 10), ByteRange(0, 10))) ~> completeWithRangedBytes(10) ~> check { + mediaType.withParams(Map.empty) shouldEqual MediaTypes.`multipart/byteranges` + } + } + + "return a 'multipart/byteranges' for a ranged request with multiple coalesced ranges and expect ranges in ascending order" in { + Get() ~> addHeader(Range(ByteRange(5, 10), ByteRange(0, 1), ByteRange(1, 2))) ~> { + wrs { complete("Some random and not super short entity.") } + } ~> check { + header[`Content-Range`] should be(None) + val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second) + parts.size shouldEqual 2 + inside(parts(0)) { + case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒ + range shouldEqual ContentRange.Default(0, 2, Some(39)) + unit shouldEqual RangeUnits.Bytes + Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "Som" + } + inside(parts(1)) { + case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒ + range shouldEqual ContentRange.Default(5, 10, Some(39)) + unit shouldEqual RangeUnits.Bytes + Await.result(entity.dataBytes.utf8String, 100.millis) shouldEqual "random" + } + } + } + + "return a 'multipart/byteranges' for a ranged request with multiple ranges if entity data source isn't reusable" in { + val content = "Some random and not super short entity." + def entityData() = StreamUtils.oneTimeSource(Source.single(ByteString(content))) + + Get() ~> addHeader(Range(ByteRange(5, 10), ByteRange(0, 1), ByteRange(1, 2))) ~> { + wrs { complete(HttpEntity.Default(MediaTypes.`text/plain`, content.length, entityData())) } + } ~> check { + header[`Content-Range`] should be(None) + val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second) + parts.size shouldEqual 2 + } + } + + "reject a request with too many requested ranges" in { + val ranges = (1 to 20).map(a ⇒ ByteRange.fromOffset(a)) + Get() ~> addHeader(Range(ranges)) ~> completeWithRangedBytes(100) ~> check { + rejection shouldEqual TooManyRangesRejection(10) + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RespondWithDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RespondWithDirectivesSpec.scala new file mode 100644 index 0000000000..0287aa2f60 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RespondWithDirectivesSpec.scala @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import akka.http.scaladsl.model._ +import MediaTypes._ +import headers._ +import StatusCodes._ + +import akka.http.scaladsl.server._ + +class RespondWithDirectivesSpec extends RoutingSpec { + + "overrideStatusCode" should { + "set the given status on successful responses" in { + Get() ~> { + overrideStatusCode(Created) { completeOk } + } ~> check { response shouldEqual HttpResponse(Created) } + } + "leave rejections unaffected" in { + Get() ~> { + overrideStatusCode(Created) { reject } + } ~> check { rejections shouldEqual Nil } + } + } + + val customHeader = RawHeader("custom", "custom") + val customHeader2 = RawHeader("custom2", "custom2") + val existingHeader = RawHeader("custom", "existing") + + "respondWithHeader" should { + val customHeader = RawHeader("custom", "custom") + "add the given header to successful responses" in { + Get() ~> { + respondWithHeader(customHeader) { completeOk } + } ~> check { response shouldEqual HttpResponse(headers = customHeader :: Nil) } + } + } + "respondWithHeaders" should { + "add the given headers to successful responses" in { + Get() ~> { + respondWithHeaders(customHeader, customHeader2) { completeOk } + } ~> check { response shouldEqual HttpResponse(headers = customHeader :: customHeader2 :: Nil) } + } + } + "respondWithDefaultHeader" should { + def route(extraHeaders: HttpHeader*) = respondWithDefaultHeader(customHeader) { + respondWithHeaders(extraHeaders: _*) { + completeOk + } + } + + "add the given header to a response if the header was missing before" in { + Get() ~> route() ~> check { response shouldEqual HttpResponse(headers = customHeader :: Nil) } + } + "not change a response if the header already existed" in { + Get() ~> route(existingHeader) ~> check { response shouldEqual HttpResponse(headers = existingHeader :: Nil) } + } + } + "respondWithDefaultHeaders" should { + def route(extraHeaders: HttpHeader*) = respondWithDefaultHeaders(customHeader, customHeader2) { + respondWithHeaders(extraHeaders: _*) { + completeOk + } + } + + "add the given headers to a response if the header was missing before" in { + Get() ~> route() ~> check { response shouldEqual HttpResponse(headers = customHeader :: customHeader2 :: Nil) } + } + "not update an existing header" in { + Get() ~> route(existingHeader) ~> check { + response shouldEqual HttpResponse(headers = List(customHeader2, existingHeader)) + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala new file mode 100644 index 0000000000..fea5a6cf16 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import org.scalatest.FreeSpec + +import scala.concurrent.Promise +import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._ +import akka.http.scaladsl.marshalling._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.model._ +import akka.http.impl.util._ +import headers._ +import StatusCodes._ +import MediaTypes._ + +class RouteDirectivesSpec extends FreeSpec with GenericRoutingSpec { + + "The `complete` directive should" - { + "by chainable with the `&` operator" in { + Get() ~> (get & complete("yeah")) ~> check { responseAs[String] shouldEqual "yeah" } + } + "be lazy in its argument evaluation, independently of application style" in { + var i = 0 + Put() ~> { + get { complete { i += 1; "get" } } ~ + put { complete { i += 1; "put" } } ~ + (post & complete { i += 1; "post" }) + } ~> check { + responseAs[String] shouldEqual "put" + i shouldEqual 1 + } + } + "support completion from response futures" - { + "simple case without marshaller" in { + Get() ~> { + get & complete(Promise.successful(HttpResponse(entity = "yup")).future) + } ~> check { responseAs[String] shouldEqual "yup" } + } + "for successful futures and marshalling" in { + Get() ~> complete(Promise.successful("yes").future) ~> check { responseAs[String] shouldEqual "yes" } + } + "for failed futures and marshalling" in { + object TestException extends RuntimeException + Get() ~> complete(Promise.failed[String](TestException).future) ~> + check { + status shouldEqual StatusCodes.InternalServerError + responseAs[String] shouldEqual "There was an internal server error." + } + } + "for futures failed with a RejectionError" in { + Get() ~> complete(Promise.failed[String](RejectionError(AuthorizationFailedRejection)).future) ~> + check { + rejection shouldEqual AuthorizationFailedRejection + } + } + } + "allow easy handling of futured ToResponseMarshallers" in pending /*{ + trait RegistrationStatus + case class Registered(name: String) extends RegistrationStatus + case object AlreadyRegistered extends RegistrationStatus + + val route = + get { + path("register" / Segment) { name ⇒ + def registerUser(name: String): Future[RegistrationStatus] = Future.successful { + name match { + case "otto" ⇒ AlreadyRegistered + case _ ⇒ Registered(name) + } + } + complete { + registerUser(name).map[ToResponseMarshallable] { + case Registered(_) ⇒ HttpEntity.Empty + case AlreadyRegistered ⇒ + import spray.json.DefaultJsonProtocol._ + import spray.httpx.SprayJsonSupport._ + (StatusCodes.BadRequest, Map("error" -> "User already Registered")) + } + } + } + } + + Get("/register/otto") ~> route ~> check { + status shouldEqual StatusCodes.BadRequest + } + Get("/register/karl") ~> route ~> check { + status shouldEqual StatusCodes.OK + entity shouldEqual HttpEntity.Empty + } + }*/ + "do Content-Type negotiation for multi-marshallers" in pendingUntilFixed { + val route = get & complete(Data("Ida", 83)) + + import akka.http.scaladsl.model.headers.Accept + Get().withHeaders(Accept(MediaTypes.`application/json`)) ~> route ~> check { + responseAs[String] shouldEqual + """{ + | "name": "Ida", + | "age": 83 + |}""".stripMarginWithNewline("\n") + } + Get().withHeaders(Accept(MediaTypes.`text/xml`)) ~> route ~> check { + responseAs[xml.NodeSeq] shouldEqual Ida83 + } + pending + /*Get().withHeaders(Accept(MediaTypes.`text/plain`)) ~> HttpService.sealRoute(route) ~> check { + status shouldEqual StatusCodes.NotAcceptable + }*/ + } + } + + "the redirect directive should" - { + "produce proper 'Found' redirections" in { + Get() ~> { + redirect("/foo", Found) + } ~> check { + response shouldEqual HttpResponse( + status = 302, + entity = HttpEntity(`text/html`, "The requested resource temporarily resides under this URI."), + headers = Location("/foo") :: Nil) + } + } + + "produce proper 'NotModified' redirections" in { + Get() ~> { + redirect("/foo", NotModified) + } ~> check { response shouldEqual HttpResponse(304, headers = Location("/foo") :: Nil) } + } + } + + case class Data(name: String, age: Int) + object Data { + //import spray.json.DefaultJsonProtocol._ + //import spray.httpx.SprayJsonSupport._ + + val jsonMarshaller: ToEntityMarshaller[Data] = FIXME // jsonFormat2(Data.apply) + val xmlMarshaller: ToEntityMarshaller[Data] = FIXME + /*Marshaller.delegate[Data, xml.NodeSeq](MediaTypes.`text/xml`) { (data: Data) ⇒ + { data.name }{ data.age } + }*/ + + implicit val dataMarshaller: ToResponseMarshaller[Data] = FIXME + //ToResponseMarshaller.oneOf(MediaTypes.`application/json`, MediaTypes.`text/xml`)(jsonMarshaller, xmlMarshaller) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SchemeDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SchemeDirectivesSpec.scala new file mode 100644 index 0000000000..df08e5861f --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SchemeDirectivesSpec.scala @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.StatusCodes._ + +class SchemeDirectivesSpec extends RoutingSpec { + "the extractScheme directive" should { + "extract the Uri scheme" in { + Put("http://localhost/", "Hello") ~> extractScheme { echoComplete } ~> check { responseAs[String] shouldEqual "http" } + } + } + + """the scheme("http") directive""" should { + "let requests with an http Uri scheme pass" in { + Put("http://localhost/", "Hello") ~> scheme("http") { completeOk } ~> check { response shouldEqual Ok } + } + "reject requests with an https Uri scheme" in { + Get("https://localhost/") ~> scheme("http") { completeOk } ~> check { rejections shouldEqual List(SchemeRejection("http")) } + } + "cancel SchemeRejection if other scheme passed" in { + val route = + scheme("https") { completeOk } ~ + scheme("http") { reject } + + Put("http://localhost/", "Hello") ~> route ~> check { + rejections should be(Nil) + } + } + } + + """the scheme("https") directive""" should { + "let requests with an https Uri scheme pass" in { + Put("https://localhost/", "Hello") ~> scheme("https") { completeOk } ~> check { response shouldEqual Ok } + } + "reject requests with an http Uri scheme" in { + Get("http://localhost/") ~> scheme("https") { completeOk } ~> check { rejections shouldEqual List(SchemeRejection("https")) } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala new file mode 100644 index 0000000000..f19ecd2251 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.Future +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.AuthenticationFailedRejection.{ CredentialsRejected, CredentialsMissing } + +class SecurityDirectivesSpec extends RoutingSpec { + val dontAuth = authenticateBasicAsync[String]("MyRealm", _ ⇒ Future.successful(None)) + val doAuth = authenticateBasicPF("MyRealm", { case UserCredentials.Provided(name) ⇒ name }) + val authWithAnonymous = doAuth.withAnonymousUser("We are Legion") + + val challenge = HttpChallenge("Basic", "MyRealm") + + "basic authentication" should { + "reject requests without Authorization header with an AuthenticationFailedRejection" in { + Get() ~> { + dontAuth { echoComplete } + } ~> check { rejection shouldEqual AuthenticationFailedRejection(CredentialsMissing, challenge) } + } + "reject unauthenticated requests with Authorization header with an AuthenticationFailedRejection" in { + Get() ~> Authorization(BasicHttpCredentials("Bob", "")) ~> { + dontAuth { echoComplete } + } ~> check { rejection shouldEqual AuthenticationFailedRejection(CredentialsRejected, challenge) } + } + "reject requests with illegal Authorization header with 401" in { + Get() ~> RawHeader("Authorization", "bob alice") ~> Route.seal { + dontAuth { echoComplete } + } ~> check { + status shouldEqual StatusCodes.Unauthorized + responseAs[String] shouldEqual "The resource requires authentication, which was not supplied with the request" + header[`WWW-Authenticate`] shouldEqual Some(`WWW-Authenticate`(challenge)) + } + } + "extract the object representing the user identity created by successful authentication" in { + Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { + doAuth { echoComplete } + } ~> check { responseAs[String] shouldEqual "Alice" } + } + "extract the object representing the user identity created for the anonymous user" in { + Get() ~> { + authWithAnonymous { echoComplete } + } ~> check { responseAs[String] shouldEqual "We are Legion" } + } + "properly handle exceptions thrown in its inner route" in { + object TestException extends RuntimeException + Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { + Route.seal { + doAuth { _ ⇒ throw TestException } + } + } ~> check { status shouldEqual StatusCodes.InternalServerError } + } + } + "authentication directives" should { + "properly stack" in { + val otherChallenge = HttpChallenge("MyAuth", "MyRealm2") + val otherAuth: Directive1[String] = authenticateOrRejectWithChallenge { (cred: Option[HttpCredentials]) ⇒ + Future.successful(Left(otherChallenge)) + } + val bothAuth = dontAuth | otherAuth + + Get() ~> Route.seal(bothAuth { echoComplete }) ~> check { + status shouldEqual StatusCodes.Unauthorized + headers.collect { + case `WWW-Authenticate`(challenge +: Nil) ⇒ challenge + } shouldEqual Seq(challenge, otherChallenge) + } + } + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala new file mode 100644 index 0000000000..5331fd0620 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.scaladsl.server.directives + +import scala.collection.immutable.Seq +import akka.http.impl.engine.ws.InternalCustomHeader +import akka.http.scaladsl.model.headers.{ UpgradeProtocol, Upgrade } +import akka.http.scaladsl.model.{ HttpRequest, StatusCodes, HttpResponse } +import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.scaladsl.server.{ Route, RoutingSpec } +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +class WebsocketDirectivesSpec extends RoutingSpec { + "the handleWebsocketMessages directive" should { + "handle websocket requests" in { + Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~> + emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> + check { + status shouldEqual StatusCodes.SwitchingProtocols + } + } + "reject non-websocket requests" in { + Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check { + status shouldEqual StatusCodes.BadRequest + responseAs[String] shouldEqual "Expected Websocket Upgrade request" + } + } + } + + /** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */ + def emulateHttpCore(req: HttpRequest): HttpRequest = + req.header[Upgrade] match { + case Some(upgrade) if upgrade.hasWebsocket ⇒ req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock) + case _ ⇒ req + } + def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = + new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { + def requestedProtocols: Seq[String] = Nil + + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = + HttpResponse(StatusCodes.SwitchingProtocols) + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/util/TupleOpsSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/util/TupleOpsSpec.scala new file mode 100644 index 0000000000..f1a36af50f --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/util/TupleOpsSpec.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import org.scalatest.{ Matchers, WordSpec } + +class TupleOpsSpec extends WordSpec with Matchers { + import TupleOps._ + + "The TupleOps" should { + + "support folding over tuples using a binary poly-function" in { + object Funky extends BinaryPolyFunc { + implicit def step1 = at[Double, Int](_ + _) + implicit def step2 = at[Double, Symbol]((d, s) ⇒ (d + s.name.tail.toInt).toByte) + implicit def step3 = at[Byte, String]((byte, s) ⇒ byte + s.toLong) + } + (1, 'X2, "3").foldLeft(0.0)(Funky) shouldEqual 6L + } + + "support joining tuples" in { + (1, 'X2, "3") join () shouldEqual (1, 'X2, "3") + () join (1, 'X2, "3") shouldEqual (1, 'X2, "3") + (1, 'X2, "3") join (4.0, 5L) shouldEqual (1, 'X2, "3", 4.0, 5L) + } + } +} \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/UnmarshallingSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/UnmarshallingSpec.scala new file mode 100644 index 0000000000..71dd8b9f5b --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/unmarshalling/UnmarshallingSpec.scala @@ -0,0 +1,289 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import scala.concurrent.duration._ +import scala.concurrent.{ Future, Await } +import org.scalatest.matchers.Matcher +import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers } +import akka.http.scaladsl.testkit.ScalatestUtils +import akka.util.ByteString +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture._ +import akka.http.impl.util._ +import headers._ +import MediaTypes._ + +class UnmarshallingSpec extends FreeSpec with Matchers with BeforeAndAfterAll with ScalatestUtils { + implicit val system = ActorSystem(getClass.getSimpleName) + implicit val materializer = ActorFlowMaterializer() + import system.dispatcher + + "The PredefinedFromEntityUnmarshallers." - { + "stringUnmarshaller should unmarshal `text/plain` content in UTF-8 to Strings" in { + Unmarshal(HttpEntity("Hällö")).to[String] should evaluateTo("Hällö") + } + "charArrayUnmarshaller should unmarshal `text/plain` content in UTF-8 to char arrays" in { + Unmarshal(HttpEntity("árvíztűrő ütvefúrógép")).to[Array[Char]] should evaluateTo("árvíztűrő ütvefúrógép".toCharArray) + } + } + + "The MultipartUnmarshallers." - { + + "multipartGeneralUnmarshaller should correctly unmarshal 'multipart/*' content with" - { + "an empty part" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", + """--XYZABC + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity.empty(ContentTypes.`text/plain(UTF-8)`))) + } + "two empty parts" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", + """--XYZABC + |--XYZABC + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity.empty(ContentTypes.`text/plain(UTF-8)`)), + Multipart.General.BodyPart.Strict(HttpEntity.empty(ContentTypes.`text/plain(UTF-8)`))) + } + "a part without entity and missing header separation CRLF" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", + """--XYZABC + |Content-type: text/xml + |Age: 12 + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity.empty(MediaTypes.`text/xml`), List(Age(12)))) + } + "an implicitly typed part (without headers)" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", + """--XYZABC + | + |Perfectly fine part content. + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity(ContentTypes.`text/plain(UTF-8)`, "Perfectly fine part content."))) + } + "one non-empty form-data part" in { + Unmarshal(HttpEntity(`multipart/form-data` withBoundary "-", + """--- + |Content-type: text/plain; charset=UTF8 + |content-disposition: form-data; name="email" + | + |test@there.com + |-----""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict( + HttpEntity(ContentTypes.`text/plain(UTF-8)`, "test@there.com"), + List(`Content-Disposition`(ContentDispositionTypes.`form-data`, Map("name" -> "email"))))) + } + "two different parts" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "12345", + """--12345 + | + |first part, with a trailing newline + | + |--12345 + |Content-Type: application/octet-stream + |Content-Transfer-Encoding: binary + | + |filecontent + |--12345--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity(ContentTypes.`text/plain(UTF-8)`, "first part, with a trailing newline\r\n")), + Multipart.General.BodyPart.Strict( + HttpEntity(`application/octet-stream`, "filecontent"), + List(RawHeader("Content-Transfer-Encoding", "binary")))) + } + "illegal headers" in ( + Unmarshal(HttpEntity(`multipart/form-data` withBoundary "XYZABC", + """--XYZABC + |Date: unknown + |content-disposition: form-data; name=email + | + |test@there.com + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict( + HttpEntity(ContentTypes.`text/plain(UTF-8)`, "test@there.com"), + List(`Content-Disposition`(ContentDispositionTypes.`form-data`, Map("name" -> "email")), + RawHeader("date", "unknown"))))) + "a full example (Strict)" in { + Unmarshal(HttpEntity(`multipart/mixed` withBoundary "12345", + """preamble and + |more preamble + |--12345 + | + |first part, implicitly typed + |--12345 + |Content-Type: application/octet-stream + | + |second part, explicitly typed + |--12345-- + |epilogue and + |more epilogue""".stripMarginWithNewline("\r\n"))).to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity(ContentTypes.`text/plain(UTF-8)`, "first part, implicitly typed")), + Multipart.General.BodyPart.Strict(HttpEntity(`application/octet-stream`, "second part, explicitly typed"))) + } + "a full example (Default)" in { + val content = """preamble and + |more preamble + |--12345 + | + |first part, implicitly typed + |--12345 + |Content-Type: application/octet-stream + | + |second part, explicitly typed + |--12345-- + |epilogue and + |more epilogue""".stripMarginWithNewline("\r\n") + val byteStrings = content.map(c ⇒ ByteString(c.toString)) // one-char ByteStrings + Unmarshal(HttpEntity.Default(`multipart/mixed` withBoundary "12345", content.length, Source(byteStrings))) + .to[Multipart.General] should haveParts( + Multipart.General.BodyPart.Strict(HttpEntity(ContentTypes.`text/plain(UTF-8)`, "first part, implicitly typed")), + Multipart.General.BodyPart.Strict(HttpEntity(`application/octet-stream`, "second part, explicitly typed"))) + } + } + + "multipartGeneralUnmarshaller should reject illegal multipart content with" - { + "an empty entity" in { + Await.result(Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", ByteString.empty)) + .to[Multipart.General].failed, 1.second).getMessage shouldEqual "Unexpected end of multipart entity" + } + "an entity without initial boundary" in { + Await.result(Unmarshal(HttpEntity(`multipart/mixed` withBoundary "XYZABC", + """this is + |just preamble text""".stripMarginWithNewline("\r\n"))) + .to[Multipart.General].failed, 1.second).getMessage shouldEqual "Unexpected end of multipart entity" + } + "a stray boundary" in { + Await.result(Unmarshal(HttpEntity(`multipart/form-data` withBoundary "ABC", + """--ABC + |Content-type: text/plain; charset=UTF8 + |--ABCContent-type: application/json + |content-disposition: form-data; name="email" + |-----""".stripMarginWithNewline("\r\n"))) + .to[Multipart.General].failed, 1.second).getMessage shouldEqual "Illegal multipart boundary in message content" + } + "duplicate Content-Type header" in { + Await.result(Unmarshal(HttpEntity(`multipart/form-data` withBoundary "-", + """--- + |Content-type: text/plain; charset=UTF8 + |Content-type: application/json + |content-disposition: form-data; name="email" + | + |test@there.com + |-----""".stripMarginWithNewline("\r\n"))) + .to[Multipart.General].failed, 1.second).getMessage shouldEqual + "multipart part must not contain more than one Content-Type header" + } + "a missing header-separating CRLF (in Strict entity)" in { + Await.result(Unmarshal(HttpEntity(`multipart/form-data` withBoundary "-", + """--- + |not good here + |-----""".stripMarginWithNewline("\r\n"))) + .to[Multipart.General].failed, 1.second).getMessage shouldEqual "Illegal character ' ' in header name" + } + "a missing header-separating CRLF (in Default entity)" in { + val content = """--- + | + |ok + |--- + |not ok + |-----""".stripMarginWithNewline("\r\n") + val byteStrings = content.map(c ⇒ ByteString(c.toString)) // one-char ByteStrings + val contentType = `multipart/form-data` withBoundary "-" + Await.result(Unmarshal(HttpEntity.Default(contentType, content.length, Source(byteStrings))) + .to[Multipart.General] + .flatMap(_ toStrict 1.second).failed, 1.second).getMessage shouldEqual "Illegal character ' ' in header name" + } + } + + "multipartByteRangesUnmarshaller should correctly unmarshal multipart/byteranges content with two different parts" in { + Unmarshal(HttpEntity(`multipart/byteranges` withBoundary "12345", + """--12345 + |Content-Range: bytes 0-2/26 + |Content-Type: text/plain + | + |ABC + |--12345 + |Content-Range: bytes 23-25/26 + |Content-Type: text/plain + | + |XYZ + |--12345--""".stripMarginWithNewline("\r\n"))).to[Multipart.ByteRanges] should haveParts( + Multipart.ByteRanges.BodyPart.Strict(ContentRange(0, 2, 26), HttpEntity(ContentTypes.`text/plain`, "ABC")), + Multipart.ByteRanges.BodyPart.Strict(ContentRange(23, 25, 26), HttpEntity(ContentTypes.`text/plain`, "XYZ"))) + } + + "multipartFormDataUnmarshaller should correctly unmarshal 'multipart/form-data' content" - { + "with one element" in { + Unmarshal(HttpEntity(`multipart/form-data` withBoundary "XYZABC", + """--XYZABC + |content-disposition: form-data; name=email + | + |test@there.com + |--XYZABC--""".stripMarginWithNewline("\r\n"))).to[Multipart.FormData] should haveParts( + Multipart.FormData.BodyPart.Strict("email", HttpEntity(ContentTypes.`application/octet-stream`, "test@there.com"))) + } + "with a file" in { + Unmarshal { + HttpEntity.Default( + contentType = `multipart/form-data` withBoundary "XYZABC", + contentLength = 1, // not verified during unmarshalling + data = Source { + List( + ByteString { + """--XYZABC + |Content-Disposition: form-data; name="email" + | + |test@there.com + |--XYZABC + |Content-Dispo""".stripMarginWithNewline("\r\n") + }, + ByteString { + """sition: form-data; name="userfile"; filename="test.dat" + |Content-Type: application/pdf + |Content-Transfer-Encoding: binary + | + |filecontent + |--XYZABC--""".stripMarginWithNewline("\r\n") + }) + }) + }.to[Multipart.FormData].flatMap(_.toStrict(1.second)) should haveParts( + Multipart.FormData.BodyPart.Strict("email", HttpEntity(ContentTypes.`application/octet-stream`, "test@there.com")), + Multipart.FormData.BodyPart.Strict("userfile", HttpEntity(MediaTypes.`application/pdf`, "filecontent"), + Map("filename" -> "test.dat"), List(RawHeader("Content-Transfer-Encoding", "binary")))) + } + // TODO: reactivate after multipart/form-data unmarshalling integrity verification is implemented + // + // "reject illegal multipart content" in { + // val Left(MalformedContent(msg, _)) = HttpEntity(`multipart/form-data` withBoundary "XYZABC", "--noboundary--").as[MultipartFormData] + // msg shouldEqual "Missing start boundary" + // } + // "reject illegal form-data content" in { + // val Left(MalformedContent(msg, _)) = HttpEntity(`multipart/form-data` withBoundary "XYZABC", + // """|--XYZABC + // |content-disposition: form-data; named="email" + // | + // |test@there.com + // |--XYZABC--""".stripMargin).as[MultipartFormData] + // msg shouldEqual "Illegal multipart/form-data content: unnamed body part (no Content-Disposition header or no 'name' parameter)" + // } + } + } + + override def afterAll() = system.shutdown() + + def haveParts[T <: Multipart](parts: Multipart.BodyPart.Strict*): Matcher[Future[T]] = + equal(parts).matcher[Seq[Multipart.BodyPart.Strict]] compose { x ⇒ + Await.result(x + .fast.flatMap { + _.parts + .mapAsync(1)(_ toStrict 1.second) + .grouped(100) + .runWith(Sink.head) + } + .fast.recover { case _: NoSuchElementException ⇒ Nil }, 1.second) + } +} diff --git a/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ApplyConverterInstances.scala.template b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ApplyConverterInstances.scala.template new file mode 100644 index 0000000000..fb20bd11ef --- /dev/null +++ b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ApplyConverterInstances.scala.template @@ -0,0 +1,17 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import akka.http.scaladsl.server.Route + +private[util] abstract class ApplyConverterInstances { + [#implicit def hac1[[#T1#]] = new ApplyConverter[Tuple1[[#T1#]]] { + type In = ([#T1#]) ⇒ Route + def apply(fn: In): (Tuple1[[#T1#]]) ⇒ Route = { + case Tuple1([#t1#]) ⇒ fn([#t1#]) + } + }# + ] +} \ No newline at end of file diff --git a/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ConstructFromTupleInstances.scala.template b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ConstructFromTupleInstances.scala.template new file mode 100644 index 0000000000..3e9428f5ee --- /dev/null +++ b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/ConstructFromTupleInstances.scala.template @@ -0,0 +1,13 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +private[util] abstract class ConstructFromTupleInstances { + [#implicit def instance1[[#T1#], R](construct: ([#T1#]) => R): ConstructFromTuple[Tuple1[[#T1#]], R] = + new ConstructFromTuple[Tuple1[[#T1#]], R] { + def apply(tup: Tuple1[[#T1#]]): R = construct([#tup._1#]) + }# + ] +} diff --git a/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleAppendOneInstances.scala.template b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleAppendOneInstances.scala.template new file mode 100644 index 0000000000..025f7f347e --- /dev/null +++ b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleAppendOneInstances.scala.template @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import TupleOps.AppendOne + +private[util] abstract class TupleAppendOneInstances { + type Aux[P, S, Out0] = AppendOne[P, S] { type Out = Out0 } + + implicit def append0[T1]: Aux[Unit, T1, Tuple1[T1]] = + new AppendOne[Unit, T1] { + type Out = Tuple1[T1] + def apply(prefix: Unit, last: T1): Tuple1[T1] = Tuple1(last) + } + + [1..21#implicit def append1[[#T1#], L]: Aux[Tuple1[[#T1#]], L, Tuple2[[#T1#], L]] = + new AppendOne[Tuple1[[#T1#]], L] { + type Out = Tuple2[[#T1#], L] + def apply(prefix: Tuple1[[#T1#]], last: L): Tuple2[[#T1#], L] = Tuple2([#prefix._1#], last) + }# + ] +} \ No newline at end of file diff --git a/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleFoldInstances.scala.template b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleFoldInstances.scala.template new file mode 100644 index 0000000000..6b82066420 --- /dev/null +++ b/akka-http/src/main/boilerplate/akka/http/scaladsl/server/util/TupleFoldInstances.scala.template @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import TupleOps.FoldLeft +import BinaryPolyFunc.Case + +private[util] abstract class TupleFoldInstances { + + type Aux[In, T, Op, Out0] = FoldLeft[In, T, Op] { type Out = Out0 } + + implicit def t0[In, Op]: Aux[In, Unit, Op, In] = + new FoldLeft[In, Unit, Op] { + type Out = In + def apply(zero: In, tuple: Unit) = zero + } + + implicit def t1[In, A, Op](implicit f: Case[In, A, Op]): Aux[In, Tuple1[A], Op, f.Out] = + new FoldLeft[In, Tuple1[A], Op] { + type Out = f.Out + def apply(zero: In, tuple: Tuple1[A]) = f(zero, tuple._1) + } + + [2..22#implicit def t1[In, [2..#T0#], X, T1, Op](implicit fold: Aux[In, Tuple0[[2..#T0#]], Op, X], f: Case[X, T1, Op]): Aux[In, Tuple1[[#T1#]], Op, f.Out] = + new FoldLeft[In, Tuple1[[#T1#]], Op] { + type Out = f.Out + def apply(zero: In, t: Tuple1[[#T1#]]) = + f(fold(zero, Tuple0([2..#t._0#])), t._1) + }# + ] +} \ No newline at end of file diff --git a/akka-http/src/main/java/akka/http/javadsl/server/AbstractDirective.java b/akka-http/src/main/java/akka/http/javadsl/server/AbstractDirective.java new file mode 100644 index 0000000000..dbc55c82c6 --- /dev/null +++ b/akka-http/src/main/java/akka/http/javadsl/server/AbstractDirective.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +/** + * Helper class to steer around SI-9013. + * + * It's currently impossible to implement a trait containing @varargs methods + * if the trait is written in Scala. Therefore, derive from this class and + * implement the method without varargs. + * FIXME: remove once SI-9013 is fixed. + * + * See https://issues.scala-lang.org/browse/SI-9013 + */ +abstract class AbstractDirective implements Directive { + @Override + public Route route(Route first, Route... others) { + return createRoute(first, others); + } + + protected abstract Route createRoute(Route first, Route[] others); +} diff --git a/akka-http/src/main/java/akka/http/javadsl/server/Coder.java b/akka-http/src/main/java/akka/http/javadsl/server/Coder.java new file mode 100644 index 0000000000..2461474c71 --- /dev/null +++ b/akka-http/src/main/java/akka/http/javadsl/server/Coder.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + +import akka.http.scaladsl.coding.Deflate$; +import akka.http.scaladsl.coding.Gzip$; +import akka.http.scaladsl.coding.NoCoding$; +import akka.stream.FlowMaterializer; +import akka.util.ByteString; +import scala.concurrent.Future; + +/** + * A coder is an implementation of the predefined encoders/decoders defined for HTTP. + */ +public enum Coder { + NoCoding(NoCoding$.MODULE$), Deflate(Deflate$.MODULE$), Gzip(Gzip$.MODULE$); + + private akka.http.scaladsl.coding.Coder underlying; + + Coder(akka.http.scaladsl.coding.Coder underlying) { + this.underlying = underlying; + } + + public ByteString encode(ByteString input) { + return underlying.encode(input); + } + public Future decode(ByteString input, FlowMaterializer mat) { + return underlying.decode(input, mat); + } + public akka.http.scaladsl.coding.Coder _underlyingScalaCoder() { + return underlying; + } +} diff --git a/akka-http/src/main/java/akka/http/javadsl/server/Directive.java b/akka-http/src/main/java/akka/http/javadsl/server/Directive.java new file mode 100644 index 0000000000..8612754f03 --- /dev/null +++ b/akka-http/src/main/java/akka/http/javadsl/server/Directive.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server; + + +/** + * A directive is the basic building block for building routes by composing + * any kind of request or response processing into the main route a request + * flows through. It is a factory that creates a route when given a sequence of + * route alternatives to be augmented with the function the directive + * represents. + * + * The `path`-Directive, for example, filters incoming requests by checking if + * the URI of the incoming request matches the pattern and only invokes its inner + * routes for those requests. + */ +public interface Directive { + /** + * Creates the Route given a sequence of inner route alternatives. + */ + Route route(Route first, Route... others); +} diff --git a/akka-http/src/main/resources/reference.conf b/akka-http/src/main/resources/reference.conf new file mode 100644 index 0000000000..eba81d60a1 --- /dev/null +++ b/akka-http/src/main/resources/reference.conf @@ -0,0 +1,43 @@ +####################################### +# akka-http Reference Config File # +####################################### + +# This is the reference config file that contains all the default settings. +# Make your edits/overrides in your application.conf. + +akka.http.routing { + # Enables/disables the returning of more detailed error messages to the + # client in the error response + # Should be disabled for browser-facing APIs due to the risk of XSS attacks + # and (probably) enabled for internal or non-browser APIs + # (Note that akka-http will always produce log messages containing the full error details) + verbose-error-messages = off + + # Enables/disables ETag and `If-Modified-Since` support for FileAndResourceDirectives + file-get-conditional = on + + # Enables/disables the rendering of the "rendered by" footer in directory listings + render-vanity-footer = yes + + # The maximum size between two requested ranges. Ranges with less space in between will be coalesced. + # + # When multiple ranges are requested, a server may coalesce any of the ranges that overlap or that are separated + # by a gap that is smaller than the overhead of sending multiple parts, regardless of the order in which the + # corresponding byte-range-spec appeared in the received Range header field. Since the typical overhead between + # parts of a multipart/byteranges payload is around 80 bytes, depending on the selected representation's + # media type and the chosen boundary parameter length, it can be less efficient to transfer many small + # disjoint parts than it is to transfer the entire selected representation. + range-coalescing-threshold = 80 + + # The maximum number of allowed ranges per request. + # Requests with more ranges will be rejected due to DOS suspicion. + range-count-limit = 16 + + # The maximum number of bytes per ByteString a decoding directive will produce + # for an entity data stream. + decode-max-bytes-per-chunk = 1m + + # Fully qualified config path which holds the dispatcher configuration + # to be used by FlowMaterialiser when creating Actors for IO operations. + file-io-dispatcher = ${akka.stream.file-io-dispatcher} +} diff --git a/akka-http/src/main/scala/akka/http/impl/server/ExtractionImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/ExtractionImpl.scala new file mode 100644 index 0000000000..a9d88a4e13 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/ExtractionImpl.scala @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.reflect.ClassTag +import akka.http.javadsl.server.{ RequestContext, RequestVal } +import akka.http.impl.util.JavaMapping.Implicits._ + +/** + * INTERNAL API + */ +private[http] trait ExtractionImplBase[T] extends RequestVal[T] { + protected[http] implicit def classTag: ClassTag[T] + def resultClass: Class[T] = classTag.runtimeClass.asInstanceOf[Class[T]] + + def get(ctx: RequestContext): T = + ctx.request.asScala.header[ExtractionMap].flatMap(_.get(this)) + .getOrElse(throw new RuntimeException(s"Value wasn't extracted! $this")) +} + +private[http] abstract class ExtractionImpl[T](implicit val classTag: ClassTag[T]) extends ExtractionImplBase[T] \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/impl/server/MarshallerImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/MarshallerImpl.scala new file mode 100644 index 0000000000..3690ca5fc2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/MarshallerImpl.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.concurrent.ExecutionContext +import akka.http.javadsl.server.Marshaller +import akka.http.scaladsl.marshalling + +/** + * INTERNAL API + */ +// FIXME: too lenient visibility, currently used to implement Java marshallers, needs proper API, see #16439 +case class MarshallerImpl[T](scalaMarshaller: ExecutionContext ⇒ marshalling.ToResponseMarshaller[T]) extends Marshaller[T] diff --git a/akka-http/src/main/scala/akka/http/impl/server/ParameterImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/ParameterImpl.scala new file mode 100644 index 0000000000..5beb3b5317 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/ParameterImpl.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import akka.http.javadsl.server.Parameter +import akka.http.scaladsl.server.directives.{ ParameterDirectives, BasicDirectives } +import akka.http.scaladsl.server.Directive1 +import akka.http.scaladsl.server.directives.ParameterDirectives.ParamMagnet + +/** + * INTERNAL API + */ +private[http] class ParameterImpl[T: ClassTag](val underlying: ExecutionContext ⇒ ParamMagnet { type Out = Directive1[T] }) + extends StandaloneExtractionImpl[T] with Parameter[T] { + + //def extract(ctx: RequestContext): Future[T] = + def directive: Directive1[T] = + BasicDirectives.extractExecutionContext.flatMap { implicit ec ⇒ + ParameterDirectives.parameter(underlying(ec)) + } +} diff --git a/akka-http/src/main/scala/akka/http/impl/server/PathMatcherImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/PathMatcherImpl.scala new file mode 100644 index 0000000000..813ba6322c --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/PathMatcherImpl.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.reflect.ClassTag +import akka.http.javadsl.server.PathMatcher +import akka.http.scaladsl.server.{ PathMatcher ⇒ ScalaPathMatcher } + +/** + * INTERNAL API + */ +private[http] class PathMatcherImpl[T: ClassTag](val matcher: ScalaPathMatcher[Tuple1[T]]) + extends ExtractionImpl[T] with PathMatcher[T] \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/impl/server/RequestContextImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/RequestContextImpl.scala new file mode 100644 index 0000000000..69b02b31ca --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/RequestContextImpl.scala @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.concurrent.Future +import akka.http.javadsl.{ model ⇒ jm } +import akka.http.impl.util.JavaMapping.Implicits._ +import akka.http.scaladsl.server.{ RequestContext ⇒ ScalaRequestContext } +import akka.http.javadsl.server._ + +/** + * INTERNAL API + */ +private[http] final case class RequestContextImpl(underlying: ScalaRequestContext) extends RequestContext { + import underlying.executionContext + + // provides auto-conversion to japi.RouteResult + import RouteResultImpl._ + + def request: jm.HttpRequest = underlying.request + def unmatchedPath: String = underlying.unmatchedPath.toString + + def completeWith(futureResult: Future[RouteResult]): RouteResult = + futureResult.flatMap { + case r: RouteResultImpl ⇒ r.underlying + } + def complete(text: String): RouteResult = underlying.complete(text) + def completeWithStatus(statusCode: Int): RouteResult = + completeWithStatus(jm.StatusCodes.get(statusCode)) + def completeWithStatus(statusCode: jm.StatusCode): RouteResult = + underlying.complete(statusCode.asScala) + def completeAs[T](marshaller: Marshaller[T], value: T): RouteResult = marshaller match { + case MarshallerImpl(m) ⇒ + implicit val marshaller = m(underlying.executionContext) + underlying.complete(value) + case _ ⇒ throw new IllegalArgumentException("Unsupported marshaller: $marshaller") + } + def complete(response: jm.HttpResponse): RouteResult = underlying.complete(response.asScala) + + def notFound(): RouteResult = underlying.reject() +} diff --git a/akka-http/src/main/scala/akka/http/impl/server/RouteImplementation.scala b/akka-http/src/main/scala/akka/http/impl/server/RouteImplementation.scala new file mode 100644 index 0000000000..7a70e44555 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/RouteImplementation.scala @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.language.implicitConversions +import scala.annotation.tailrec +import scala.collection.immutable +import akka.http.javadsl.model.ContentType +import akka.http.scaladsl.server.directives.{ UserCredentials, ContentTypeResolver } +import akka.http.scaladsl.server.directives.FileAndResourceDirectives.DirectoryRenderer +import akka.http.scaladsl.model.HttpHeader +import akka.http.scaladsl.model.headers.CustomHeader +import akka.http.scaladsl.server.{ Route ⇒ ScalaRoute, Directive0, Directives } +import akka.http.impl.util.JavaMapping.Implicits._ +import akka.http.scaladsl.server +import akka.http.javadsl.server._ +import RouteStructure._ + +/** + * INTERNAL API + */ +private[http] trait ExtractionMap extends CustomHeader { + def get[T](key: RequestVal[T]): Option[T] + def set[T](key: RequestVal[T], value: T): ExtractionMap +} +/** + * INTERNAL API + */ +private[http] object ExtractionMap { + implicit def apply(map: Map[RequestVal[_], Any]): ExtractionMap = + new ExtractionMap { + def get[T](key: RequestVal[T]): Option[T] = + map.get(key).asInstanceOf[Option[T]] + + def set[T](key: RequestVal[T], value: T): ExtractionMap = + ExtractionMap(map.updated(key, value)) + + // CustomHeader methods + override def suppressRendering: Boolean = true + def name(): String = "ExtractedValues" + def value(): String = "" + } +} + +/** + * INTERNAL API + */ +private[http] object RouteImplementation extends Directives with server.RouteConcatenation { + def apply(route: Route): ScalaRoute = route match { + case RouteAlternatives(children) ⇒ + val converted = children.map(RouteImplementation.apply) + converted.reduce(_ ~ _) + case RawPathPrefix(elements, children) ⇒ + val inner = apply(RouteAlternatives(children)) + + def one[T](matcher: PathMatcher[T]): Directive0 = + rawPathPrefix(matcher.asInstanceOf[PathMatcherImpl[T]].matcher) flatMap { value ⇒ + addExtraction(matcher, value) + } + elements.map(one(_)).reduce(_ & _).apply(inner) + + case GetFromResource(path, contentType, classLoader) ⇒ + getFromResource(path, contentType.asScala, classLoader) + case GetFromResourceDirectory(path, classLoader, resolver) ⇒ + getFromResourceDirectory(path, classLoader)(scalaResolver(resolver)) + case GetFromFile(file, contentType) ⇒ + getFromFile(file, contentType.asScala) + case GetFromDirectory(directory, true, resolver) ⇒ + extractExecutionContext { implicit ec ⇒ + getFromBrowseableDirectory(directory.getPath)(DirectoryRenderer.defaultDirectoryRenderer, scalaResolver(resolver)) + } + case FileAndResourceRouteWithDefaultResolver(constructor) ⇒ + RouteImplementation(constructor(new directives.ContentTypeResolver { + def resolve(fileName: String): ContentType = ContentTypeResolver.Default(fileName) + })) + + case MethodFilter(m, children) ⇒ + val inner = apply(RouteAlternatives(children)) + method(m.asScala).apply(inner) + + case Extract(extractions, children) ⇒ + val inner = apply(RouteAlternatives(children)) + extractRequestContext.flatMap { ctx ⇒ + extractions.map { e ⇒ + e.directive.flatMap(addExtraction(e.asInstanceOf[RequestVal[Any]], _)) + }.reduce(_ & _) + }.apply(inner) + + case BasicAuthentication(authenticator, children) ⇒ + val inner = apply(RouteAlternatives(children)) + authenticateBasicAsync(authenticator.realm, { creds ⇒ + val javaCreds = + creds match { + case UserCredentials.Missing ⇒ + new BasicUserCredentials { + def available: Boolean = false + def userName: String = throw new IllegalStateException("Credentials missing") + def verifySecret(secret: String): Boolean = throw new IllegalStateException("Credentials missing") + } + case p @ UserCredentials.Provided(name) ⇒ + new BasicUserCredentials { + def available: Boolean = true + def userName: String = name + def verifySecret(secret: String): Boolean = p.verifySecret(secret) + } + } + + authenticator.authenticate(javaCreds) + }).flatMap { user ⇒ + addExtraction(authenticator.asInstanceOf[RequestVal[Any]], user) + }.apply(inner) + + case EncodeResponse(coders, children) ⇒ + val scalaCoders = coders.map(_._underlyingScalaCoder()) + encodeResponseWith(scalaCoders.head, scalaCoders.tail: _*).apply(apply(RouteAlternatives(children))) + + case Conditional(eTag, lastModified, children) ⇒ + conditional(eTag.asScala, lastModified.asScala).apply(apply(RouteAlternatives(children))) + + case o: OpaqueRoute ⇒ + (ctx ⇒ o.handle(new RequestContextImpl(ctx)).asInstanceOf[RouteResultImpl].underlying) + + case p: Product ⇒ extractExecutionContext { implicit ec ⇒ complete(500, s"Not implemented: ${p.productPrefix}") } + } + + def addExtraction[T](key: RequestVal[T], value: T): Directive0 = { + @tailrec def addToExtractionMap(headers: immutable.Seq[HttpHeader], prefix: Vector[HttpHeader] = Vector.empty): immutable.Seq[HttpHeader] = + headers match { + case (m: ExtractionMap) +: rest ⇒ m.set(key, value) +: (prefix ++ rest) + case other +: rest ⇒ addToExtractionMap(rest, prefix :+ other) + case Nil ⇒ ExtractionMap(Map(key -> value)) +: prefix + } + mapRequest(_.mapHeaders(addToExtractionMap(_))) + } + + private def scalaResolver(resolver: directives.ContentTypeResolver): ContentTypeResolver = + ContentTypeResolver(f ⇒ resolver.resolve(f).asScala) +} diff --git a/akka-http/src/main/scala/akka/http/impl/server/RouteResultImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/RouteResultImpl.scala new file mode 100644 index 0000000000..2f94d482ab --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/RouteResultImpl.scala @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.language.implicitConversions +import scala.concurrent.Future +import akka.http.javadsl.{ server ⇒ js } +import akka.http.scaladsl.{ server ⇒ ss } + +/** + * INTERNAL API + */ +private[http] class RouteResultImpl(val underlying: Future[ss.RouteResult]) extends js.RouteResult +/** + * INTERNAL API + */ +private[http] object RouteResultImpl { + implicit def autoConvert(result: Future[ss.RouteResult]): js.RouteResult = + new RouteResultImpl(result) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/impl/server/RouteStructure.scala b/akka-http/src/main/scala/akka/http/impl/server/RouteStructure.scala new file mode 100644 index 0000000000..b1a1902b68 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/RouteStructure.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import java.io.File +import scala.language.existentials +import scala.collection.immutable +import akka.http.javadsl.model.{ DateTime, ContentType, HttpMethod } +import akka.http.javadsl.model.headers.EntityTag +import akka.http.javadsl.server.directives.ContentTypeResolver +import akka.http.javadsl.server._ + +/** + * INTERNAL API + */ +private[http] object RouteStructure { + trait DirectiveRoute extends Route { + def children: immutable.Seq[Route] + + require(children.nonEmpty) + } + case class RouteAlternatives(children: immutable.Seq[Route]) extends DirectiveRoute + + case class MethodFilter(method: HttpMethod, children: immutable.Seq[Route]) extends DirectiveRoute { + def filter(ctx: RequestContext): Boolean = ctx.request.method == method + } + + abstract case class FileAndResourceRouteWithDefaultResolver(routeConstructor: ContentTypeResolver ⇒ Route) extends Route + case class GetFromResource(resourcePath: String, contentType: ContentType, classLoader: ClassLoader) extends Route + case class GetFromResourceDirectory(resourceDirectory: String, classLoader: ClassLoader, resolver: ContentTypeResolver) extends Route + case class GetFromFile(file: File, contentType: ContentType) extends Route + case class GetFromDirectory(directory: File, browseable: Boolean, resolver: ContentTypeResolver) extends Route + + case class RawPathPrefix(pathElements: immutable.Seq[PathMatcher[_]], children: immutable.Seq[Route]) extends DirectiveRoute + case class Extract(extractions: Seq[StandaloneExtractionImpl[_]], children: immutable.Seq[Route]) extends DirectiveRoute + case class BasicAuthentication(authenticator: HttpBasicAuthenticator[_], children: immutable.Seq[Route]) extends DirectiveRoute + case class EncodeResponse(coders: immutable.Seq[Coder], children: immutable.Seq[Route]) extends DirectiveRoute + + case class Conditional(entityTag: EntityTag, lastModified: DateTime, children: immutable.Seq[Route]) extends DirectiveRoute + + abstract class OpaqueRoute(extractions: RequestVal[_]*) extends Route { + def handle(ctx: RequestContext): RouteResult + } +} + diff --git a/akka-http/src/main/scala/akka/http/impl/server/StandaloneExtractionImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/StandaloneExtractionImpl.scala new file mode 100644 index 0000000000..c6dde5e962 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/StandaloneExtractionImpl.scala @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.concurrent.Future +import scala.reflect.ClassTag +import akka.http.javadsl.server.RequestVal +import akka.http.scaladsl.server._ + +/** + * INTERNAL API + */ +private[http] abstract class StandaloneExtractionImpl[T: ClassTag] extends ExtractionImpl[T] with RequestVal[T] { + def directive: Directive1[T] +} + +/** + * INTERNAL API + */ +private[http] abstract class ExtractingStandaloneExtractionImpl[T: ClassTag] extends StandaloneExtractionImpl[T] { + def directive: Directive1[T] = Directives.extract(extract).flatMap(Directives.onSuccess(_)) + + def extract(ctx: RequestContext): Future[T] +} diff --git a/akka-http/src/main/scala/akka/http/impl/server/UnmarshallerImpl.scala b/akka-http/src/main/scala/akka/http/impl/server/UnmarshallerImpl.scala new file mode 100644 index 0000000000..3af1370348 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/server/UnmarshallerImpl.scala @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.impl.server + +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import akka.stream.FlowMaterializer +import akka.http.javadsl.server.Unmarshaller +import akka.http.scaladsl.unmarshalling.FromMessageUnmarshaller + +/** + * INTERNAL API + * + */ +// FIXME: too lenient visibility, currently used to implement Java marshallers, needs proper API, see #16439 +case class UnmarshallerImpl[T](scalaUnmarshaller: (ExecutionContext, FlowMaterializer) ⇒ FromMessageUnmarshaller[T])(implicit val classTag: ClassTag[T]) + extends Unmarshaller[T] diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala new file mode 100644 index 0000000000..e5cdf30d3d --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import akka.http.javadsl.server.directives._ +import scala.collection.immutable +import scala.annotation.varargs +import akka.http.javadsl.model.HttpMethods + +// FIXME: add support for the remaining directives, see #16436 +abstract class AllDirectives extends PathDirectives + +/** + * + */ +object Directives extends AllDirectives { + /** + * INTERNAL API + */ + private[http] def custom(f: immutable.Seq[Route] ⇒ Route): Directive = + new AbstractDirective { + def createRoute(first: Route, others: Array[Route]): Route = f(first +: others.toVector) + } +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Handler.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Handler.scala new file mode 100644 index 0000000000..e5d3823a90 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Handler.scala @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +/** + * A route Handler that handles a request (that is encapsulated in a [[RequestContext]]) + * and returns a [[RouteResult]] with the response (or the rejection). + * + * Use the methods in [[RequestContext]] to create a [[RouteResult]]. A handler mustn't + * return [[null]] as the result. + */ +trait Handler { + def handle(ctx: RequestContext): RouteResult +} + +/** + * A route handler with one additional argument. + */ +trait Handler1[T1] { + def handle(ctx: RequestContext, t1: T1): RouteResult +} + +/** + * A route handler with two additional arguments. + */ +trait Handler2[T1, T2] { + def handle(ctx: RequestContext, t1: T1, t2: T2): RouteResult +} + +/** + * A route handler with three additional arguments. + */ +trait Handler3[T1, T2, T3] { + def handle(ctx: RequestContext, t1: T1, t2: T2, t3: T3): RouteResult +} + +/** + * A route handler with four additional arguments. + */ +trait Handler4[T1, T2, T3, T4] { + def handle(ctx: RequestContext, t1: T1, t2: T2, t3: T3, t4: T4): RouteResult +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/HttpApp.scala b/akka-http/src/main/scala/akka/http/javadsl/server/HttpApp.scala new file mode 100644 index 0000000000..17a273dbf3 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/HttpApp.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import scala.concurrent.Future +import akka.actor.ActorSystem +import akka.http.scaladsl.Http.ServerBinding + +/** + * A convenience class to derive from to get everything from HttpService and Directives into scope. + * Implement the [[HttpApp.createRoute]] method to provide the Route and then call [[HttpApp.bindRoute]] + * to start the server on the specified interface. + */ +abstract class HttpApp + extends AllDirectives + with HttpServiceBase { + protected def createRoute(): Route + + /** + * Starts an HTTP server on the given interface and port. Creates the route by calling the + * user-implemented [[createRoute]] method and uses the route to handle requests of the server. + */ + def bindRoute(interface: String, port: Int, system: ActorSystem): Future[ServerBinding] = + bindRoute(interface, port, createRoute(), system) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/HttpBasicAuthenticator.scala b/akka-http/src/main/scala/akka/http/javadsl/server/HttpBasicAuthenticator.scala new file mode 100644 index 0000000000..c8daff1aee --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/HttpBasicAuthenticator.scala @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import akka.http.impl.server.{ ExtractionImplBase, ExtractionImpl, RouteStructure } +import akka.http.scaladsl.util.FastFuture + +import scala.annotation.varargs +import scala.concurrent.Future +import scala.reflect +import reflect.ClassTag + +/** + * Represents existing or missing HTTP Basic authentication credentials. + */ +trait BasicUserCredentials { + /** + * Were credentials provided in the request? + */ + def available: Boolean + + /** + * The username as sent in the request. + */ + def userName: String + /** + * Verifies the given secret against the one sent in the request. + */ + def verifySecret(secret: String): Boolean +} + +/** + * Implement this class to provide an HTTP Basic authentication check. The [[authenticate]] method needs to be implemented + * to check if the supplied or missing credentials are authenticated and provide a domain level object representing + * the user as a [[RequestVal]]. + */ +abstract class HttpBasicAuthenticator[T](val realm: String) extends AbstractDirective with ExtractionImplBase[T] with RequestVal[T] { + protected[http] implicit def classTag: ClassTag[T] = reflect.classTag[AnyRef].asInstanceOf[ClassTag[T]] + def authenticate(credentials: BasicUserCredentials): Future[Option[T]] + + /** + * Creates a return value for use in [[authenticate]] that successfully authenticates the requests and provides + * the given user. + */ + def authenticateAs(user: T): Future[Option[T]] = FastFuture.successful(Some(user)) + + /** + * Refuses access for this user. + */ + def refuseAccess(): Future[Option[T]] = FastFuture.successful(None) + + /** + * INTERNAL API + */ + protected[http] final def createRoute(first: Route, others: Array[Route]): Route = + RouteStructure.BasicAuthentication(this, (first +: others).toVector) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/HttpService.scala b/akka-http/src/main/scala/akka/http/javadsl/server/HttpService.scala new file mode 100644 index 0000000000..a520f4aea8 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/HttpService.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import scala.concurrent.Future +import akka.actor.ActorSystem +import akka.http.scaladsl.{ server, Http } +import akka.http.scaladsl.Http.ServerBinding +import akka.http.impl.server.RouteImplementation +import akka.stream.{ ActorFlowMaterializer, FlowMaterializer } +import akka.stream.scaladsl.{ Keep, Sink } + +trait HttpServiceBase { + /** + * Starts a server on the given interface and port and uses the route to handle incoming requests. + */ + def bindRoute(interface: String, port: Int, route: Route, system: ActorSystem): Future[ServerBinding] = { + implicit val sys = system + implicit val mat = ActorFlowMaterializer() + handleConnectionsWithRoute(interface, port, route, system, mat) + } + + /** + * Starts a server on the given interface and port and uses the route to handle incoming requests. + */ + def bindRoute(interface: String, port: Int, route: Route, system: ActorSystem, flowMaterializer: FlowMaterializer): Future[ServerBinding] = + handleConnectionsWithRoute(interface, port, route, system, flowMaterializer) + + /** + * Uses the route to handle incoming connections and requests for the ServerBinding. + */ + def handleConnectionsWithRoute(interface: String, port: Int, route: Route, system: ActorSystem, flowMaterializer: FlowMaterializer): Future[ServerBinding] = { + implicit val sys = system + implicit val mat = flowMaterializer + + import system.dispatcher + val r: server.Route = RouteImplementation(route) + Http(system).bind(interface, port).toMat(Sink.foreach(_.handleWith(r)))(Keep.left).run()(flowMaterializer) + } +} + +/** + * Provides the entrypoints to create an Http server from a route. + */ +object HttpService extends HttpServiceBase diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Marshaller.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Marshaller.scala new file mode 100644 index 0000000000..b6322cf601 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Marshaller.scala @@ -0,0 +1,11 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +/** + * A marker trait for a marshaller that converts a value of type [[T]] to an + * HttpResponse. + */ +trait Marshaller[T] \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Marshallers.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Marshallers.scala new file mode 100644 index 0000000000..11c3930ecd --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Marshallers.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import akka.http.scaladsl.marshalling.ToResponseMarshaller +import akka.http.impl.server.MarshallerImpl + +/** + * A collection of predefined marshallers. + */ +object Marshallers { + def STRING: Marshaller[String] = MarshallerImpl(implicit ctx ⇒ implicitly[ToResponseMarshaller[String]]) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Parameter.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Parameter.scala new file mode 100644 index 0000000000..05c25f0e02 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Parameter.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import java.{ lang ⇒ jl } + +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag + +import akka.http.scaladsl.server.Directive1 +import akka.http.scaladsl.server.directives.ParameterDirectives.ParamMagnet +import akka.http.scaladsl.common.ToNameReceptacleEnhancements +import akka.http.impl.server.ParameterImpl + +/** + * A RequestVal representing a query parameter of type T. + */ +trait Parameter[T] extends RequestVal[T] + +/** + * A collection of predefined parameters. + * FIXME: add tests, see #16437 + */ +object Parameters { + import ToNameReceptacleEnhancements._ + + /** + * A string query parameter. + */ + def string(name: String): Parameter[String] = + fromScalaParam(implicit ec ⇒ ParamMagnet(name)) + + /** + * An integer query parameter. + */ + def integer(name: String): Parameter[jl.Integer] = + fromScalaParam[jl.Integer](implicit ec ⇒ + ParamMagnet(name.as[Int]).asInstanceOf[ParamMagnet { type Out = Directive1[jl.Integer] }]) + + private def fromScalaParam[T: ClassTag](underlying: ExecutionContext ⇒ ParamMagnet { type Out = Directive1[T] }): Parameter[T] = + new ParameterImpl[T](underlying) +} + diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/PathMatchers.scala b/akka-http/src/main/scala/akka/http/javadsl/server/PathMatchers.scala new file mode 100644 index 0000000000..465b236343 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/PathMatchers.scala @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import java.{ lang ⇒ jl, util ⇒ ju } +import scala.reflect.ClassTag +import scala.collection.JavaConverters._ +import akka.http.impl.server.PathMatcherImpl +import akka.http.scaladsl.server.{ PathMatchers ⇒ ScalaPathMatchers, PathMatcher0, PathMatcher1 } + +/** + * A PathMatcher is used to match the (yet unmatched) URI path of incoming requests. + * It is also a RequestVal that allows to access dynamic parts of the part in a + * handler. + * + * Using a PathMatcher with the [[Directives.path]] or [[Directives.pathPrefix]] directives + * "consumes" a part of the path which is recorded in [[RequestContext.unmatchedPath]]. + */ +trait PathMatcher[T] extends RequestVal[T] + +/** + * A collection of predefined path matchers. + */ +object PathMatchers { + val NEUTRAL: PathMatcher[Void] = matcher0(_.Neutral) + val SLASH: PathMatcher[Void] = matcher0(_.Slash) + val END: PathMatcher[Void] = matcher0(_.PathEnd) + + def segment(name: String): PathMatcher[String] = matcher(_ ⇒ name -> name) + + def integerNumber: PathMatcher[jl.Integer] = matcher(_.IntNumber.asInstanceOf[PathMatcher1[jl.Integer]]) + def hexIntegerNumber: PathMatcher[jl.Integer] = matcher(_.HexIntNumber.asInstanceOf[PathMatcher1[jl.Integer]]) + + def longNumber: PathMatcher[jl.Long] = matcher(_.LongNumber.asInstanceOf[PathMatcher1[jl.Long]]) + def hexLongNumber: PathMatcher[jl.Long] = matcher(_.HexLongNumber.asInstanceOf[PathMatcher1[jl.Long]]) + + def uuid: PathMatcher[ju.UUID] = matcher(_.JavaUUID) + + def segment: PathMatcher[String] = matcher(_.Segment) + def segments: PathMatcher[ju.List[String]] = matcher(_.Segments.map(_.asJava)) + def segments(maxNumber: Int): PathMatcher[ju.List[String]] = matcher(_.Segments(maxNumber).map(_.asJava)) + + def rest: PathMatcher[String] = matcher(_.Rest) + + private def matcher[T: ClassTag](scalaMatcher: ScalaPathMatchers.type ⇒ PathMatcher1[T]): PathMatcher[T] = + new PathMatcherImpl[T](scalaMatcher(ScalaPathMatchers)) + private def matcher0(scalaMatcher: ScalaPathMatchers.type ⇒ PathMatcher0): PathMatcher[Void] = + new PathMatcherImpl[Void](scalaMatcher(ScalaPathMatchers).tmap(_ ⇒ Tuple1(null))) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/RequestContext.scala b/akka-http/src/main/scala/akka/http/javadsl/server/RequestContext.scala new file mode 100644 index 0000000000..d5c3cd4410 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/RequestContext.scala @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import akka.http.javadsl.model._ +import akka.util.ByteString + +import scala.concurrent.Future + +/** + * The RequestContext represents the state of the request while it is routed through + * the route structure. + */ +trait RequestContext { + /** + * The incoming request. + */ + def request: HttpRequest + + /** + * The still unmatched path of the request. + */ + def unmatchedPath: String + + /** + * Completes the request with a value of type T and marshals it using the given + * marshaller. + */ + def completeAs[T](marshaller: Marshaller[T], value: T): RouteResult + + /** + * Completes the request with the given response. + */ + def complete(response: HttpResponse): RouteResult + + /** + * Completes the request with the given string as an entity of type `text/plain`. + */ + def complete(text: String): RouteResult + + /** + * Completes the request with the given status code and no entity. + */ + def completeWithStatus(statusCode: StatusCode): RouteResult + + /** + * Completes the request with the given status code and no entity. + */ + def completeWithStatus(statusCode: Int): RouteResult + + /** + * Defers completion of the request + */ + def completeWith(futureResult: Future[RouteResult]): RouteResult + + /** + * Explicitly rejects the request as not found. Other route alternatives + * may still be able provide a response. + */ + def notFound(): RouteResult + + // FIXME: provide proper support for rejections, see #16438 +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/RequestVal.scala b/akka-http/src/main/scala/akka/http/javadsl/server/RequestVal.scala new file mode 100644 index 0000000000..6d07f19c20 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/RequestVal.scala @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +/** + * Represents a value that can be extracted from a request. + */ +trait RequestVal[T] { outer ⇒ + /** + * An accessor for the value given the [[RequestContext]]. + * + * Note, that some RequestVals need to be actively specified in the route structure to + * be extracted at a particular point during routing. One example is a [[PathMatcher]] + * that needs to used with a [[directives.PathDirectives]] to specify which part of the + * path should actually be extracted. Another example is an [[HttpBasicAuthenticator]] + * that needs to be used in the route explicitly to be activated. + */ + def get(ctx: RequestContext): T + + /** + * The runtime type of the extracted value. + */ + def resultClass: Class[T] +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/RequestVals.scala b/akka-http/src/main/scala/akka/http/javadsl/server/RequestVals.scala new file mode 100644 index 0000000000..0520e41a87 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/RequestVals.scala @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import java.{ util ⇒ ju } +import scala.concurrent.Future +import scala.reflect.ClassTag +import akka.http.javadsl.model.HttpMethod +import akka.http.scaladsl.server +import akka.http.scaladsl.server._ +import akka.http.scaladsl.server.directives.{ RouteDirectives, BasicDirectives } +import akka.http.impl.server.{ UnmarshallerImpl, ExtractingStandaloneExtractionImpl, RequestContextImpl, StandaloneExtractionImpl } +import akka.http.scaladsl.util.FastFuture +import akka.http.impl.util.JavaMapping.Implicits._ + +/** + * A collection of predefined [[RequestVals]]. + */ +object RequestVals { + /** + * Creates an extraction that extracts the request body using the supplied Unmarshaller. + */ + def entityAs[T](unmarshaller: Unmarshaller[T]): RequestVal[T] = + new ExtractingStandaloneExtractionImpl[T]()(unmarshaller.classTag) { + def extract(ctx: server.RequestContext): Future[T] = { + val u = unmarshaller.asInstanceOf[UnmarshallerImpl[T]].scalaUnmarshaller(ctx.executionContext, ctx.flowMaterializer) + u(ctx.request)(ctx.executionContext) + } + } + + /** + * Extracts the request method. + */ + def requestMethod: RequestVal[HttpMethod] = + new ExtractingStandaloneExtractionImpl[HttpMethod] { + def extract(ctx: server.RequestContext): Future[HttpMethod] = FastFuture.successful(ctx.request.method.asJava) + } + + /** + * Creates a new [[RequestVal]] given a [[ju.Map]] and a [[RequestVal]] that represents the key. + * The new RequestVal represents the existing value as looked up in the map. If the key doesn't + * exist the request is rejected. + */ + def lookupInMap[T, U](key: RequestVal[T], clazz: Class[U], map: ju.Map[T, U]): RequestVal[U] = + new StandaloneExtractionImpl[U]()(ClassTag(clazz)) { + import BasicDirectives._ + import RouteDirectives._ + + def directive: Directive1[U] = + extract(ctx ⇒ key.get(RequestContextImpl(ctx))).flatMap { + case key if map.containsKey(key) ⇒ provide(map.get(key)) + case _ ⇒ reject() + } + } +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Route.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Route.scala new file mode 100644 index 0000000000..6c68213120 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Route.scala @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +/** + * A marker interface to denote an element that handles a request. + * + * This is an opaque interface that cannot be implemented manually. + * Instead, see the predefined routes in [[Directives]] and use the [[Directives.handleWith]] + * method to create custom routes. + */ +trait Route \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/RouteResult.scala b/akka-http/src/main/scala/akka/http/javadsl/server/RouteResult.scala new file mode 100644 index 0000000000..223489b8cd --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/RouteResult.scala @@ -0,0 +1,11 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +/** + * A marker trait to denote the result of handling a request. Use the methods in [[RequestContext]] + * to create instances of results. + */ +trait RouteResult diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Unmarshaller.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Unmarshaller.scala new file mode 100644 index 0000000000..8e2b5e43a3 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Unmarshaller.scala @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server + +import scala.reflect.ClassTag + +/** + * A marker trait for an unmarshaller that converts an HttpRequest to a value of type T. + */ +trait Unmarshaller[T] { + def classTag: ClassTag[T] +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/BasicDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/BasicDirectives.scala new file mode 100644 index 0000000000..187066ab78 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/BasicDirectives.scala @@ -0,0 +1,180 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.directives + +import java.lang.reflect.Method +import scala.annotation.varargs +import akka.http.javadsl.server._ +import akka.http.impl.server.RouteStructure._ +import akka.http.impl.server._ + +abstract class BasicDirectives { + /** + * Tries the given routes in sequence until the first one matches. + */ + @varargs + def route(route: Route, others: Route*): Route = + RouteAlternatives(route +: others.toVector) + + /** + * A route that completes the request with a static text + */ + def complete(text: String): Route = + new OpaqueRoute() { + def handle(ctx: RequestContext): RouteResult = ctx.complete(text) + } + + /** + * A route that completes the request using the given marshaller and value. + */ + def completeAs[T](marshaller: Marshaller[T], value: T): Route = + new OpaqueRoute() { + def handle(ctx: RequestContext): RouteResult = ctx.completeAs(marshaller, value) + } + + /** + * A route that extracts a value and completes the request with it. + */ + def extractAndComplete[T](marshaller: Marshaller[T], extraction: RequestVal[T]): Route = + handle(extraction)(ctx ⇒ ctx.completeAs(marshaller, extraction.get(ctx))) + + /** + * A directive that makes sure that all the standalone extractions have been + * executed and validated. + */ + @varargs + def extractHere(extractions: RequestVal[_]*): Directive = + Directives.custom(Extract(extractions.map(_.asInstanceOf[StandaloneExtractionImpl[_ <: AnyRef]]), _)) + + /** + * A route that handles the request with the given opaque handler. Specify a set of extractions + * that will be used in the handler to make sure they are available. + */ + @varargs + def handleWith[T1](handler: Handler, extractions: RequestVal[_]*): Route = + handle(extractions: _*)(handler.handle(_)) + + /** + * A route that handles the request given the value of a single [[RequestVal]]. + */ + def handleWith[T1](e1: RequestVal[T1], handler: Handler1[T1]): Route = + handle(e1)(ctx ⇒ handler.handle(ctx, e1.get(ctx))) + + /** + * A route that handles the request given the values of the given [[RequestVal]]s. + */ + def handleWith[T1, T2](e1: RequestVal[T1], e2: RequestVal[T2], handler: Handler2[T1, T2]): Route = + handle(e1, e2)(ctx ⇒ handler.handle(ctx, e1.get(ctx), e2.get(ctx))) + + /** + * A route that handles the request given the values of the given [[RequestVal]]s. + */ + def handleWith[T1, T2, T3]( + e1: RequestVal[T1], e2: RequestVal[T2], e3: RequestVal[T3], handler: Handler3[T1, T2, T3]): Route = + handle(e1, e2, e3)(ctx ⇒ handler.handle(ctx, e1.get(ctx), e2.get(ctx), e3.get(ctx))) + + /** + * A route that handles the request given the values of the given [[RequestVal]]s. + */ + def handleWith[T1, T2, T3, T4]( + e1: RequestVal[T1], e2: RequestVal[T2], e3: RequestVal[T3], e4: RequestVal[T4], handler: Handler4[T1, T2, T3, T4]): Route = + handle(e1, e2, e3, e4)(ctx ⇒ handler.handle(ctx, e1.get(ctx), e2.get(ctx), e3.get(ctx), e4.get(ctx))) + + private[http] def handle(extractions: RequestVal[_]*)(f: RequestContext ⇒ RouteResult): Route = { + val route = + new OpaqueRoute() { + def handle(ctx: RequestContext): RouteResult = f(ctx) + } + val saExtractions = extractions.collect { case sa: StandaloneExtractionImpl[_] ⇒ sa } + if (saExtractions.isEmpty) route + else extractHere(saExtractions: _*).route(route) + } + + /** + * Handles the route by reflectively calling the instance method specified by `instance`, and `methodName`. + * Additionally, the value of all extractions will be passed to the function. + * + * For extraction types `Extraction[T1]`, `Extraction[T2]`, ... the shape of the method must match this pattern: + * + * public static RouteResult methodName(RequestContext ctx, T1 t1, T2 t2, ...) + */ + @varargs + def handleWith(instance: AnyRef, methodName: String, extractions: RequestVal[_]*): Route = + handleWith(instance.getClass, instance, methodName, extractions: _*) + + /** + * Handles the route by reflectively calling the static method specified by `clazz`, and `methodName`. + * Additionally, the value of all extractions will be passed to the function. + * + * For extraction types `Extraction[T1]`, `Extraction[T2]`, ... the shape of the method must match this pattern: + * + * public static RouteResult methodName(RequestContext ctx, T1 t1, T2 t2, ...) + */ + @varargs + def handleWith(clazz: Class[_], methodName: String, extractions: RequestVal[_]*): Route = + handleWith(clazz, null, methodName, extractions: _*) + + /** + * Handles the route by calling the method specified by `clazz`, `instance`, and `methodName`. Additionally, the value + * of all extractions will be passed to the function. + * + * For extraction types `Extraction[T1]`, `Extraction[T2]`, ... the shape of the method must match this pattern: + * + * public static RouteResult methodName(RequestContext ctx, T1 t1, T2 t2, ...) + */ + @varargs + def handleWith(clazz: Class[_], instance: AnyRef, methodName: String, extractions: RequestVal[_]*): Route = { + def chooseOverload(methods: Seq[Method]): (RequestContext, Seq[Any]) ⇒ RouteResult = { + val extractionTypes = extractions.map(_.resultClass).toList + val RequestContextClass = classOf[RequestContext] + + import java.{ lang ⇒ jl } + def paramMatches(expected: Class[_], actual: Class[_]): Boolean = expected match { + case e if e isAssignableFrom actual ⇒ true + case jl.Long.TYPE if actual == classOf[jl.Long] ⇒ true + case jl.Integer.TYPE if actual == classOf[jl.Integer] ⇒ true + case jl.Short.TYPE if actual == classOf[jl.Short] ⇒ true + case jl.Character.TYPE if actual == classOf[jl.Character] ⇒ true + case jl.Byte.TYPE if actual == classOf[jl.Byte] ⇒ true + case jl.Double.TYPE if actual == classOf[jl.Double] ⇒ true + case jl.Float.TYPE if actual == classOf[jl.Float] ⇒ true + case _ ⇒ false + } + def paramsMatch(params: Seq[Class[_]]): Boolean = { + val res = + params.size == extractionTypes.size && + (params, extractionTypes).zipped.forall(paramMatches) + + res + } + def returnTypeMatches(method: Method): Boolean = + method.getReturnType == classOf[RouteResult] + + object ParameterTypes { + def unapply(method: Method): Option[List[Class[_]]] = Some(method.getParameterTypes.toList) + } + + methods.filter(returnTypeMatches).collectFirst { + case method @ ParameterTypes(RequestContextClass :: rest) if paramsMatch(rest) ⇒ { + if (!method.isAccessible) method.setAccessible(true) // FIXME: test what happens if this fails + (ctx: RequestContext, params: Seq[Any]) ⇒ method.invoke(instance, (ctx +: params).toArray.asInstanceOf[Array[AnyRef]]: _*).asInstanceOf[RouteResult] + } + + case method @ ParameterTypes(rest) if paramsMatch(rest) ⇒ { + if (!method.isAccessible) method.setAccessible(true) + (ctx: RequestContext, params: Seq[Any]) ⇒ method.invoke(instance, params.toArray.asInstanceOf[Array[AnyRef]]: _*).asInstanceOf[RouteResult] + } + }.getOrElse(throw new RuntimeException("No suitable method found")) + } + def lookupMethod() = { + val candidateMethods = clazz.getMethods.filter(_.getName == methodName) + chooseOverload(candidateMethods) + } + + val method = lookupMethod() + + handle(extractions: _*)(ctx ⇒ method(ctx, extractions.map(_.get(ctx)))) + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/CacheConditionDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CacheConditionDirectives.scala new file mode 100644 index 0000000000..32e2b7006c --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CacheConditionDirectives.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server +package directives + +import akka.http.javadsl.model.DateTime +import akka.http.javadsl.model.headers.EntityTag +import akka.http.impl.server.RouteStructure + +import scala.annotation.varargs + +abstract class CacheConditionDirectives extends BasicDirectives { + /** + * Wraps its inner route with support for Conditional Requests as defined + * by tools.ietf.org/html/draft-ietf-httpbis-p4-conditional-26 + * + * In particular the algorithm defined by tools.ietf.org/html/draft-ietf-httpbis-p4-conditional-26#section-6 + * is implemented by this directive. + */ + @varargs + def conditional(entityTag: EntityTag, lastModified: DateTime, innerRoutes: Route*): Route = + RouteStructure.Conditional(entityTag, lastModified, innerRoutes.toVector) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/CodingDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CodingDirectives.scala new file mode 100644 index 0000000000..191ff9054c --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CodingDirectives.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.directives + +import scala.annotation.varargs +import akka.http.impl.server.RouteStructure +import akka.http.javadsl.server.{ Coder, Directive, Directives, Route } + +abstract class CodingDirectives extends CacheConditionDirectives { + /** + * Wraps the inner routes with encoding support. The response will be encoded + * using one of the predefined coders, `Gzip`, `Deflate`, or `NoCoding` depending on + * a potential [[akka.http.javadsl.model.headers.AcceptEncoding]] header from the client. + */ + @varargs def encodeResponse(innerRoutes: Route*): Route = + // FIXME: make sure this list stays synchronized with the Scala one + RouteStructure.EncodeResponse(List(Coder.NoCoding, Coder.Gzip, Coder.Deflate), innerRoutes.toVector) + + /** + * A directive that Wraps its inner routes with encoding support. + * The response will be encoded using one of the given coders with the precedence given + * by the order of the coders in this call. + * + * In any case, a potential [[akka.http.javadsl.model.headers.AcceptEncoding]] header from the client + * will be respected (or otherwise, if no matching . + */ + @varargs def encodeResponse(coders: Coder*): Directive = + Directives.custom(RouteStructure.EncodeResponse(coders.toList, _)) +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/FileAndResourceDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/FileAndResourceDirectives.scala new file mode 100644 index 0000000000..a2d28e7dc7 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/FileAndResourceDirectives.scala @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.directives + +import java.io.File +import akka.http.javadsl.model.{ MediaType, ContentType } +import akka.http.javadsl.server.Route +import akka.http.scaladsl.server +import akka.http.impl.server.RouteStructure._ + +/** + * Implement this interface to provide a custom mapping from a file name to a [[ContentType]]. + */ +trait ContentTypeResolver { + def resolve(fileName: String): ContentType +} + +/** + * A resolver that assumes the given constant [[ContentType]] for all files. + */ +case class StaticContentTypeResolver(contentType: ContentType) extends ContentTypeResolver { + def resolve(fileName: String): ContentType = contentType +} + +/** + * Allows to customize one of the predefined routes of [[FileAndResourceRoute]] to respond + * with a particular content type. + * + * The default behavior is to determine the content type by file extension. + */ +trait FileAndResourceRoute extends Route { + /** + * Returns a variant of this route that responds with the given constant [[ContentType]]. + */ + def withContentType(contentType: ContentType): Route + + /** + * Returns a variant of this route that responds with the given constant [[MediaType]]. + */ + def withContentType(mediaType: MediaType): Route + + /** + * Returns a variant of this route that uses the specified [[ContentTypeResolver]] to determine + * which [[ContentType]] to respond with by file name. + */ + def resolveContentTypeWith(resolver: ContentTypeResolver): Route +} + +object FileAndResourceRoute { + /** + * INTERNAL API + */ + private[http] def apply(f: ContentTypeResolver ⇒ Route): FileAndResourceRoute = + new FileAndResourceRouteWithDefaultResolver(f) with FileAndResourceRoute { + def withContentType(contentType: ContentType): Route = resolveContentTypeWith(StaticContentTypeResolver(contentType)) + def withContentType(mediaType: MediaType): Route = withContentType(mediaType.toContentType) + + def resolveContentTypeWith(resolver: ContentTypeResolver): Route = f(resolver) + } + + /** + * INTERNAL API + */ + private[http] def forFixedName(fileName: String)(f: ContentType ⇒ Route): FileAndResourceRoute = + new FileAndResourceRouteWithDefaultResolver(resolver ⇒ f(resolver.resolve(fileName))) with FileAndResourceRoute { + def withContentType(contentType: ContentType): Route = resolveContentTypeWith(StaticContentTypeResolver(contentType)) + def withContentType(mediaType: MediaType): Route = withContentType(mediaType.toContentType) + + def resolveContentTypeWith(resolver: ContentTypeResolver): Route = f(resolver.resolve(fileName)) + } +} + +abstract class FileAndResourceDirectives extends CodingDirectives { + /** + * Completes GET requests with the content of the given resource loaded from the default ClassLoader. + * If the resource cannot be found or read the Route rejects the request. + */ + def getFromResource(path: String): Route = + getFromResource(path, defaultClassLoader) + + /** + * Completes GET requests with the content of the given resource loaded from the given ClassLoader. + * If the resource cannot be found or read the Route rejects the request. + */ + def getFromResource(path: String, classLoader: ClassLoader): Route = + FileAndResourceRoute.forFixedName(path)(GetFromResource(path, _, classLoader)) + + /** + * Completes GET requests with the content from the resource identified by the given + * directoryPath and the unmatched path. + */ + def getFromResourceDirectory(directoryPath: String): FileAndResourceRoute = + getFromResourceDirectory(directoryPath, defaultClassLoader) + + /** + * Completes GET requests with the content from the resource identified by the given + * directoryPath and the unmatched path from the given ClassLoader. + */ + def getFromResourceDirectory(directoryPath: String, classLoader: ClassLoader): FileAndResourceRoute = + FileAndResourceRoute(GetFromResourceDirectory(directoryPath, classLoader, _)) + + /** + * Completes GET requests with the content of the given file. + */ + def getFromFile(file: File): FileAndResourceRoute = FileAndResourceRoute.forFixedName(file.getPath)(GetFromFile(file, _)) + + /** + * Completes GET requests with the content of the file at the path. + */ + def getFromFile(path: String): FileAndResourceRoute = getFromFile(new File(path)) + + /** + * Completes GET requests with the content from the file identified by the given + * directory and the unmatched path of the request. + */ + def getFromDirectory(directory: File): FileAndResourceRoute = FileAndResourceRoute(GetFromDirectory(directory, browseable = false, _)) + + /** + * Completes GET requests with the content from the file identified by the given + * directoryPath and the unmatched path of the request. + */ + def getFromDirectory(directoryPath: String): FileAndResourceRoute = getFromDirectory(new File(directoryPath)) + + /** + * Same as [[getFromDirectory]] but generates a listing of files if the path is a directory. + */ + def getFromBrowseableDirectory(directory: File): FileAndResourceRoute = FileAndResourceRoute(GetFromDirectory(directory, browseable = true, _)) + + /** + * Same as [[getFromDirectory]] but generates a listing of files if the path is a directory. + */ + def getFromBrowseableDirectory(directoryPath: String): FileAndResourceRoute = FileAndResourceRoute(GetFromDirectory(new File(directoryPath), browseable = true, _)) + + protected def defaultClassLoader: ClassLoader = server.directives.FileAndResourceDirectives.defaultClassLoader +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/MethodDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/MethodDirectives.scala new file mode 100644 index 0000000000..4f94060ec2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/MethodDirectives.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server.directives + +import akka.http.javadsl.model.{ HttpMethods, HttpMethod } +import akka.http.javadsl.server.Route +import akka.http.impl.server.RouteStructure + +import scala.annotation.varargs + +abstract class MethodDirectives extends FileAndResourceDirectives { + /** Handles the inner routes if the incoming request is a GET request, rejects the request otherwise */ + @varargs + def get(innerRoutes: Route*): Route = method(HttpMethods.GET, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a POST request, rejects the request otherwise */ + @varargs + def post(innerRoutes: Route*): Route = method(HttpMethods.POST, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a PUT request, rejects the request otherwise */ + @varargs + def put(innerRoutes: Route*): Route = method(HttpMethods.PUT, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a DELETE request, rejects the request otherwise */ + @varargs + def delete(innerRoutes: Route*): Route = method(HttpMethods.DELETE, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a HEAD request, rejects the request otherwise */ + @varargs + def head(innerRoutes: Route*): Route = method(HttpMethods.HEAD, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a OPTIONS request, rejects the request otherwise */ + @varargs + def options(innerRoutes: Route*): Route = method(HttpMethods.OPTIONS, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a PATCH request, rejects the request otherwise */ + @varargs + def patch(innerRoutes: Route*): Route = method(HttpMethods.PATCH, innerRoutes: _*) + + /** Handles the inner routes if the incoming request is a request with the given method, rejects the request otherwise */ + @varargs + def method(method: HttpMethod, innerRoutes: Route*): Route = RouteStructure.MethodFilter(method, innerRoutes.toVector) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/PathDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/PathDirectives.scala new file mode 100644 index 0000000000..231b9c42ab --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/PathDirectives.scala @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.javadsl.server +package directives + +import akka.http.impl.server.RouteStructure + +import scala.annotation.varargs +import scala.collection.immutable + +abstract class PathDirectives extends MethodDirectives { + /** + * Tries to consumes the complete unmatched path given a number of PathMatchers. Between each + * of the matchers a `/` will be matched automatically. + * + * A matcher can either be a matcher of type `PathMatcher`, or a literal string. + */ + @varargs + def path(matchers: AnyRef*): Directive = + forMatchers(joinWithSlash(convertMatchers(matchers)) :+ PathMatchers.END) + + @varargs + def pathPrefix(matchers: AnyRef*): Directive = + forMatchers(joinWithSlash(convertMatchers(matchers))) + + def pathSingleSlash: Directive = forMatchers(List(PathMatchers.SLASH, PathMatchers.END)) + + @varargs + def rawPathPrefix(matchers: AnyRef*): Directive = + forMatchers(convertMatchers(matchers)) + + private def forMatchers(matchers: immutable.Seq[PathMatcher[_]]): Directive = + Directives.custom(RouteStructure.RawPathPrefix(matchers, _)) + + private def joinWithSlash(matchers: immutable.Seq[PathMatcher[_]]): immutable.Seq[PathMatcher[_]] = { + def join(result: immutable.Seq[PathMatcher[_]], next: PathMatcher[_]): immutable.Seq[PathMatcher[_]] = + result :+ PathMatchers.SLASH :+ next + + matchers.foldLeft(immutable.Seq.empty[PathMatcher[_]])(join) + } + + private def convertMatchers(matchers: Seq[AnyRef]): immutable.Seq[PathMatcher[_]] = { + def parse(matcher: AnyRef): PathMatcher[_] = matcher match { + case p: PathMatcher[_] ⇒ p + case name: String ⇒ PathMatchers.segment(name) + } + + matchers.map(parse).toVector + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/client/RequestBuilding.scala b/akka-http/src/main/scala/akka/http/scaladsl/client/RequestBuilding.scala new file mode 100644 index 0000000000..8f41713983 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/client/RequestBuilding.scala @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.client + +import scala.collection.immutable +import scala.concurrent.{ Await, ExecutionContext } +import scala.concurrent.duration._ +import scala.reflect.ClassTag +import akka.util.Timeout +import akka.event.{ Logging, LoggingAdapter } +import akka.http.scaladsl.marshalling._ +import akka.http.scaladsl.model._ +import headers.HttpCredentials +import HttpMethods._ + +trait RequestBuilding extends TransformerPipelineSupport { + type RequestTransformer = HttpRequest ⇒ HttpRequest + + class RequestBuilder(val method: HttpMethod) { + def apply(): HttpRequest = + apply("/") + + def apply(uri: String): HttpRequest = + apply(uri, HttpEntity.Empty) + + def apply[T](uri: String, content: T)(implicit m: ToEntityMarshaller[T], ec: ExecutionContext): HttpRequest = + apply(uri, Some(content)) + + def apply[T](uri: String, content: Option[T])(implicit m: ToEntityMarshaller[T], ec: ExecutionContext): HttpRequest = + apply(Uri(uri), content) + + def apply(uri: String, entity: RequestEntity): HttpRequest = + apply(Uri(uri), entity) + + def apply(uri: Uri): HttpRequest = + apply(uri, HttpEntity.Empty) + + def apply[T](uri: Uri, content: T)(implicit m: ToEntityMarshaller[T], ec: ExecutionContext): HttpRequest = + apply(uri, Some(content)) + + def apply[T](uri: Uri, content: Option[T])(implicit m: ToEntityMarshaller[T], timeout: Timeout = Timeout(1.second), ec: ExecutionContext): HttpRequest = + content match { + case None ⇒ apply(uri, HttpEntity.Empty) + case Some(value) ⇒ + val entity = Await.result(Marshal(value).to[RequestEntity], timeout.duration) + apply(uri, entity) + } + + def apply(uri: Uri, entity: RequestEntity): HttpRequest = + HttpRequest(method, uri, Nil, entity) + } + + val Get = new RequestBuilder(GET) + val Post = new RequestBuilder(POST) + val Put = new RequestBuilder(PUT) + val Patch = new RequestBuilder(PATCH) + val Delete = new RequestBuilder(DELETE) + val Options = new RequestBuilder(OPTIONS) + val Head = new RequestBuilder(HEAD) + + // TODO: reactivate after HTTP message encoding has been ported + //def encode(encoder: Encoder): RequestTransformer = encoder.encode(_, flow) + + def addHeader(header: HttpHeader): RequestTransformer = _.mapHeaders(header +: _) + + def addHeader(headerName: String, headerValue: String): RequestTransformer = + HttpHeader.parse(headerName, headerValue) match { + case HttpHeader.ParsingResult.Ok(h, Nil) ⇒ addHeader(h) + case result ⇒ throw new IllegalArgumentException(result.errors.head.formatPretty) + } + + def addHeaders(first: HttpHeader, more: HttpHeader*): RequestTransformer = _.mapHeaders(_ ++ (first +: more)) + + def mapHeaders(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): RequestTransformer = _.mapHeaders(f) + + def removeHeader(headerName: String): RequestTransformer = + _ mapHeaders (_ filterNot (_.name equalsIgnoreCase headerName)) + + def removeHeader[T <: HttpHeader: ClassTag]: RequestTransformer = + removeHeader(implicitly[ClassTag[T]].runtimeClass) + + def removeHeader(clazz: Class[_]): RequestTransformer = + _ mapHeaders (_ filterNot clazz.isInstance) + + def removeHeaders(names: String*): RequestTransformer = + _ mapHeaders (_ filterNot (header ⇒ names exists (_ equalsIgnoreCase header.name))) + + def addCredentials(credentials: HttpCredentials) = addHeader(headers.Authorization(credentials)) + + def logRequest(log: LoggingAdapter, level: Logging.LogLevel = Logging.DebugLevel) = logValue[HttpRequest](log, level) + + def logRequest(logFun: HttpRequest ⇒ Unit) = logValue[HttpRequest](logFun) + + implicit def header2AddHeader(header: HttpHeader): RequestTransformer = addHeader(header) +} + +object RequestBuilding extends RequestBuilding \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/client/TransformerPipelineSupport.scala b/akka-http/src/main/scala/akka/http/scaladsl/client/TransformerPipelineSupport.scala new file mode 100644 index 0000000000..e43dcab6bc --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/client/TransformerPipelineSupport.scala @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.client + +import scala.concurrent.{ Future, ExecutionContext } +import akka.event.{ Logging, LoggingAdapter } + +trait TransformerPipelineSupport { + + def logValue[T](log: LoggingAdapter, level: Logging.LogLevel = Logging.DebugLevel): T ⇒ T = + logValue { value ⇒ log.log(level, value.toString) } + + def logValue[T](logFun: T ⇒ Unit): T ⇒ T = { response ⇒ + logFun(response) + response + } + + implicit class WithTransformation[A](value: A) { + def ~>[B](f: A ⇒ B): B = f(value) + } + + implicit class WithTransformerConcatenation[A, B](f: A ⇒ B) extends (A ⇒ B) { + def apply(input: A) = f(input) + def ~>[AA, BB, R](g: AA ⇒ BB)(implicit aux: TransformerAux[A, B, AA, BB, R]) = + new WithTransformerConcatenation[A, R](aux(f, g)) + } +} + +object TransformerPipelineSupport extends TransformerPipelineSupport + +trait TransformerAux[A, B, AA, BB, R] { + def apply(f: A ⇒ B, g: AA ⇒ BB): A ⇒ R +} + +object TransformerAux { + implicit def aux1[A, B, C] = new TransformerAux[A, B, B, C, C] { + def apply(f: A ⇒ B, g: B ⇒ C): A ⇒ C = f andThen g + } + implicit def aux2[A, B, C](implicit ec: ExecutionContext) = + new TransformerAux[A, Future[B], B, C, Future[C]] { + def apply(f: A ⇒ Future[B], g: B ⇒ C): A ⇒ Future[C] = f(_).map(g) + } + implicit def aux3[A, B, C](implicit ec: ExecutionContext) = + new TransformerAux[A, Future[B], B, Future[C], Future[C]] { + def apply(f: A ⇒ Future[B], g: B ⇒ Future[C]): A ⇒ Future[C] = f(_).flatMap(g) + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Coder.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Coder.scala new file mode 100644 index 0000000000..9186cde443 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Coder.scala @@ -0,0 +1,8 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +/** Marker trait for A combined Encoder and Decoder */ +trait Coder extends Encoder with Decoder diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/DataMapper.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/DataMapper.scala new file mode 100644 index 0000000000..71e7370c62 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/DataMapper.scala @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.http.scaladsl.model.{ HttpRequest, HttpResponse, ResponseEntity, RequestEntity } +import akka.util.ByteString +import akka.stream.scaladsl.Flow + +/** An abstraction to transform data bytes of HttpMessages or HttpEntities */ +sealed trait DataMapper[T] { + def transformDataBytes(t: T, transformer: Flow[ByteString, ByteString, _]): T +} +object DataMapper { + implicit val mapRequestEntity: DataMapper[RequestEntity] = + new DataMapper[RequestEntity] { + def transformDataBytes(t: RequestEntity, transformer: Flow[ByteString, ByteString, _]): RequestEntity = + t.transformDataBytes(transformer) + } + implicit val mapResponseEntity: DataMapper[ResponseEntity] = + new DataMapper[ResponseEntity] { + def transformDataBytes(t: ResponseEntity, transformer: Flow[ByteString, ByteString, _]): ResponseEntity = + t.transformDataBytes(transformer) + } + + implicit val mapRequest: DataMapper[HttpRequest] = mapMessage(mapRequestEntity)((m, f) ⇒ m.withEntity(f(m.entity))) + implicit val mapResponse: DataMapper[HttpResponse] = mapMessage(mapResponseEntity)((m, f) ⇒ m.withEntity(f(m.entity))) + + def mapMessage[T, E](entityMapper: DataMapper[E])(mapEntity: (T, E ⇒ E) ⇒ T): DataMapper[T] = + new DataMapper[T] { + def transformDataBytes(t: T, transformer: Flow[ByteString, ByteString, _]): T = + mapEntity(t, entityMapper.transformDataBytes(_, transformer)) + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala new file mode 100644 index 0000000000..4787fbb80b --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Decoder.scala @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.http.scaladsl.model._ +import akka.stream.FlowMaterializer +import akka.stream.stage.Stage +import akka.util.ByteString +import headers.HttpEncoding +import akka.stream.scaladsl.{ Sink, Source, Flow } + +import scala.concurrent.Future + +trait Decoder { + def encoding: HttpEncoding + + def decode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = + if (message.headers exists Encoder.isContentEncodingHeader) + decodeData(message).withHeaders(message.headers filterNot Encoder.isContentEncodingHeader) + else message.self + + def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = mapper.transformDataBytes(t, decoderFlow) + + def maxBytesPerChunk: Int + def withMaxBytesPerChunk(maxBytesPerChunk: Int): Decoder + + def decoderFlow: Flow[ByteString, ByteString, Unit] + def decode(input: ByteString)(implicit mat: FlowMaterializer): Future[ByteString] = + Source.single(input).via(decoderFlow).runWith(Sink.head) +} +object Decoder { + val MaxBytesPerChunkDefault: Int = 65536 +} + +/** A decoder that is implemented in terms of a [[Stage]] */ +trait StreamDecoder extends Decoder { outer ⇒ + protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] + + def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault + def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder = + new StreamDecoder { + def encoding: HttpEncoding = outer.encoding + override def maxBytesPerChunk: Int = newMaxBytesPerChunk + + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + outer.newDecompressorStage(maxBytesPerChunk) + } + + def decoderFlow: Flow[ByteString, ByteString, Unit] = + Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk)) + +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala new file mode 100644 index 0000000000..adf7325915 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Deflate.scala @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import java.util.zip.{ Inflater, Deflater } +import akka.stream.stage._ +import akka.util.{ ByteStringBuilder, ByteString } + +import scala.annotation.tailrec +import akka.http.impl.util._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.HttpEncodings + +class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder { + val encoding = HttpEncodings.deflate + def newCompressor = new DeflateCompressor + def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new DeflateDecompressor(maxBytesPerChunk) +} +object Deflate extends Deflate(Encoder.DefaultFilter) + +class DeflateCompressor extends Compressor { + protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false) + + override final def compressAndFlush(input: ByteString): ByteString = { + val buffer = newTempBuffer(input.size) + + compressWithBuffer(input, buffer) ++ flushWithBuffer(buffer) + } + override final def compressAndFinish(input: ByteString): ByteString = { + val buffer = newTempBuffer(input.size) + + compressWithBuffer(input, buffer) ++ finishWithBuffer(buffer) + } + override final def compress(input: ByteString): ByteString = compressWithBuffer(input, newTempBuffer()) + override final def flush(): ByteString = flushWithBuffer(newTempBuffer()) + override final def finish(): ByteString = finishWithBuffer(newTempBuffer()) + + protected def compressWithBuffer(input: ByteString, buffer: Array[Byte]): ByteString = { + assert(deflater.needsInput()) + deflater.setInput(input.toArray) + drain(buffer) + } + protected def flushWithBuffer(buffer: Array[Byte]): ByteString = { + // trick the deflater into flushing: switch compression level + // FIXME: use proper APIs and SYNC_FLUSH when Java 6 support is dropped + deflater.deflate(EmptyByteArray, 0, 0) + deflater.setLevel(Deflater.NO_COMPRESSION) + val res1 = drain(buffer) + deflater.setLevel(Deflater.BEST_COMPRESSION) + val res2 = drain(buffer) + res1 ++ res2 + } + protected def finishWithBuffer(buffer: Array[Byte]): ByteString = { + deflater.finish() + val res = drain(buffer) + deflater.end() + res + } + + @tailrec + protected final def drain(buffer: Array[Byte], result: ByteStringBuilder = new ByteStringBuilder()): ByteString = { + val len = deflater.deflate(buffer) + if (len > 0) { + result ++= ByteString.fromArray(buffer, 0, len) + drain(buffer, result) + } else { + assert(deflater.needsInput()) + result.result() + } + } + + private def newTempBuffer(size: Int = 65536): Array[Byte] = + // The default size is somewhat arbitrary, we'd like to guess a better value but Deflater/zlib + // is buffering in an unpredictable manner. + // `compress` will only return any data if the buffered compressed data has some size in + // the region of 10000-50000 bytes. + // `flush` and `finish` will return any size depending on the previous input. + // This value will hopefully provide a good compromise between memory churn and + // excessive fragmentation of ByteStrings. + new Array[Byte](size) +} + +class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { + protected def createInflater() = new Inflater() + + def initial: State = StartInflate + def afterInflate: State = StartInflate + + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {} + protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.finish() +} + +abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] { + protected def createInflater(): Inflater + val inflater = createInflater() + + protected def afterInflate: State + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit + + /** Start inflating */ + case object StartInflate extends IntermediateState { + def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = { + require(inflater.needsInput()) + inflater.setInput(data.toArray) + + becomeWithRemaining(Inflate()(data), ByteString.empty, ctx) + } + } + + /** Inflate */ + case class Inflate()(data: ByteString) extends IntermediateState { + override def onPull(ctx: Context[ByteString]): SyncDirective = { + val buffer = new Array[Byte](maxBytesPerChunk) + val read = inflater.inflate(buffer) + if (read > 0) { + afterBytesRead(buffer, 0, read) + ctx.push(ByteString.fromArray(buffer, 0, read)) + } else { + val remaining = data.takeRight(inflater.getRemaining) + val next = + if (inflater.finished()) afterInflate + else StartInflate + + becomeWithRemaining(next, remaining, ctx) + } + } + def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = + throw new IllegalStateException("Don't expect a new Element") + } + + def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = { + become(next) + if (remaining.isEmpty) current.onPull(ctx) + else current.onPush(remaining, ctx) + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala new file mode 100644 index 0000000000..c10908b0af --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Encoder.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.http.scaladsl.model._ +import akka.http.impl.util.StreamUtils +import akka.stream.stage.Stage +import akka.util.ByteString +import headers._ +import akka.stream.scaladsl.Flow + +trait Encoder { + def encoding: HttpEncoding + + def messageFilter: HttpMessage ⇒ Boolean + + def encode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = + if (messageFilter(message) && !message.headers.exists(Encoder.isContentEncodingHeader)) + encodeData(message).withHeaders(`Content-Encoding`(encoding) +: message.headers) + else message.self + + def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T = + mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer)) + + def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input) + + def newCompressor: Compressor + + def newEncodeTransformer(): Stage[ByteString, ByteString] = { + val compressor = newCompressor + + def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes) + def finish(): ByteString = compressor.finish() + + StreamUtils.byteStringTransformer(encodeChunk, finish) + } +} + +object Encoder { + val DefaultFilter: HttpMessage ⇒ Boolean = { + case req: HttpRequest ⇒ isCompressible(req) + case res @ HttpResponse(status, _, _, _) ⇒ isCompressible(res) && status.isSuccess + } + private[coding] def isCompressible(msg: HttpMessage): Boolean = + msg.entity.contentType.mediaType.compressible + + private[coding] val isContentEncodingHeader: HttpHeader ⇒ Boolean = _.isInstanceOf[`Content-Encoding`] +} + +/** A stateful object representing ongoing compression. */ +abstract class Compressor { + /** + * Compresses the given input and returns compressed data. The implementation + * can and will choose to buffer output data to improve compression. Use + * `flush` or `compressAndFlush` to make sure that all input data has been + * compressed and pending output data has been returned. + */ + def compress(input: ByteString): ByteString + + /** + * Flushes any output data and returns the currently remaining compressed data. + */ + def flush(): ByteString + + /** + * Closes this compressed stream and return the remaining compressed data. After + * calling this method, this Compressor cannot be used any further. + */ + def finish(): ByteString + + /** Combines `compress` + `flush` */ + def compressAndFlush(input: ByteString): ByteString + /** Combines `compress` + `finish` */ + def compressAndFinish(input: ByteString): ByteString +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala new file mode 100644 index 0000000000..8176a9bf36 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/Gzip.scala @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.util.ByteString +import akka.stream.stage._ + +import akka.http.impl.util.ByteReader +import java.util.zip.{ Inflater, CRC32, ZipException, Deflater } + +import akka.http.scaladsl.model._ +import headers.HttpEncodings + +class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder { + val encoding = HttpEncodings.gzip + def newCompressor = new GzipCompressor + def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new GzipDecompressor(maxBytesPerChunk) +} + +/** + * An encoder and decoder for the HTTP 'gzip' encoding. + */ +object Gzip extends Gzip(Encoder.DefaultFilter) { + def apply(messageFilter: HttpMessage ⇒ Boolean) = new Gzip(messageFilter) +} + +class GzipCompressor extends DeflateCompressor { + override protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, true) + private val checkSum = new CRC32 // CRC32 of uncompressed data + private var headerSent = false + private var bytesRead = 0L + + override protected def compressWithBuffer(input: ByteString, buffer: Array[Byte]): ByteString = { + updateCrc(input) + header() ++ super.compressWithBuffer(input, buffer) + } + override protected def flushWithBuffer(buffer: Array[Byte]): ByteString = header() ++ super.flushWithBuffer(buffer) + override protected def finishWithBuffer(buffer: Array[Byte]): ByteString = super.finishWithBuffer(buffer) ++ trailer() + + private def updateCrc(input: ByteString): Unit = { + checkSum.update(input.toArray) + bytesRead += input.length + } + private def header(): ByteString = + if (!headerSent) { + headerSent = true + GzipDecompressor.Header + } else ByteString.empty + + private def trailer(): ByteString = { + def int32(i: Int): ByteString = ByteString(i, i >> 8, i >> 16, i >> 24) + val crc = checkSum.getValue.toInt + val tot = bytesRead.toInt // truncated to 32bit as specified in https://tools.ietf.org/html/rfc1952#section-2 + val trailer = int32(crc) ++ int32(tot) + + trailer + } +} + +class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { + protected def createInflater(): Inflater = new Inflater(true) + + def initial: State = Initial + + /** No bytes were received yet */ + case object Initial extends State { + def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = + if (data.isEmpty) ctx.pull() + else becomeWithRemaining(ReadHeaders, data, ctx) + + override def onPull(ctx: Context[ByteString]): SyncDirective = + if (ctx.isFinishing) { + ctx.finish() + } else super.onPull(ctx) + } + + var crc32: CRC32 = new CRC32 + protected def afterInflate: State = ReadTrailer + + /** Reading the header bytes */ + case object ReadHeaders extends ByteReadingState { + def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = { + import reader._ + + if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header + if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method + val flags = readByte() + skip(6) // skip MTIME, XFL and OS fields + if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields + if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name + if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment + if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header") + + inflater.reset() + crc32.reset() + becomeWithRemaining(StartInflate, remainingData, ctx) + } + } + + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = + crc32.update(buffer, offset, length) + + /** Reading the trailer */ + case object ReadTrailer extends ByteReadingState { + def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = { + import reader._ + + if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") + if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") + + becomeWithRemaining(Initial, remainingData, ctx) + } + } + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination() + + private def crc16(data: ByteString) = { + val crc = new CRC32 + crc.update(data.toArray) + crc.getValue.toInt & 0xFFFF + } + + override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream")) + + private def fail(msg: String) = throw new ZipException(msg) +} + +/** INTERNAL API */ +private[http] object GzipDecompressor { + // RFC 1952: http://tools.ietf.org/html/rfc1952 section 2.2 + val Header = ByteString( + 0x1F, // ID1 + 0x8B, // ID2 + 8, // CM = Deflate + 0, // FLG + 0, // MTIME 1 + 0, // MTIME 2 + 0, // MTIME 3 + 0, // MTIME 4 + 0, // XFL + 0 // OS + ) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala b/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala new file mode 100644 index 0000000000..ba59b93fd4 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/coding/NoCoding.scala @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.coding + +import akka.http.scaladsl.model._ +import akka.http.impl.util.StreamUtils +import akka.stream.stage.Stage +import akka.util.ByteString +import headers.HttpEncodings + +/** + * An encoder and decoder for the HTTP 'identity' encoding. + */ +object NoCoding extends Coder with StreamDecoder { + val encoding = HttpEncodings.identity + + override def encode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = message.self + override def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T = t + override def decode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = message.self + override def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = t + + val messageFilter: HttpMessage ⇒ Boolean = _ ⇒ false + + def newCompressor = NoCodingCompressor + + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + () ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk) +} + +object NoCodingCompressor extends Compressor { + def compress(input: ByteString): ByteString = input + def flush() = ByteString.empty + def finish() = ByteString.empty + + def compressAndFlush(input: ByteString): ByteString = input + def compressAndFinish(input: ByteString): ByteString = input +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/common/NameReceptacle.scala b/akka-http/src/main/scala/akka/http/scaladsl/common/NameReceptacle.scala new file mode 100644 index 0000000000..42262999c7 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/common/NameReceptacle.scala @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.common + +import akka.http.scaladsl.unmarshalling.{ FromStringUnmarshaller ⇒ FSU } + +private[http] trait ToNameReceptacleEnhancements { + implicit def symbol2NR(symbol: Symbol) = new NameReceptacle[String](symbol.name) + implicit def string2NR(string: String) = new NameReceptacle[String](string) +} +object ToNameReceptacleEnhancements extends ToNameReceptacleEnhancements + +class NameReceptacle[T](val name: String) { + def as[B] = new NameReceptacle[B](name) + def as[B](unmarshaller: FSU[B]) = new NameUnmarshallerReceptacle(name, unmarshaller) + def ? = new NameOptionReceptacle[T](name) + def ?[B](default: B) = new NameDefaultReceptacle(name, default) + def ![B](requiredValue: B) = new RequiredValueReceptacle(name, requiredValue) +} + +class NameUnmarshallerReceptacle[T](val name: String, val um: FSU[T]) { + def ? = new NameOptionUnmarshallerReceptacle[T](name, um) + def ?(default: T) = new NameDefaultUnmarshallerReceptacle(name, default, um) + def !(requiredValue: T) = new RequiredValueUnmarshallerReceptacle(name, requiredValue, um) +} + +class NameOptionReceptacle[T](val name: String) + +class NameDefaultReceptacle[T](val name: String, val default: T) + +class RequiredValueReceptacle[T](val name: String, val requiredValue: T) + +class NameOptionUnmarshallerReceptacle[T](val name: String, val um: FSU[T]) + +class NameDefaultUnmarshallerReceptacle[T](val name: String, val default: T, val um: FSU[T]) + +class RequiredValueUnmarshallerReceptacle[T](val name: String, val requiredValue: T, val um: FSU[T]) \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/common/StrictForm.scala b/akka-http/src/main/scala/akka/http/scaladsl/common/StrictForm.scala new file mode 100644 index 0000000000..26bbb789be --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/common/StrictForm.scala @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.common + +import scala.annotation.implicitNotFound +import scala.collection.immutable +import scala.concurrent.{ ExecutionContext, Future } +import scala.concurrent.duration._ +import akka.stream.FlowMaterializer +import akka.http.scaladsl.unmarshalling._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture +import FastFuture._ + +/** + * Read-only abstraction on top of `application/x-www-form-urlencoded` and multipart form data, + * allowing joint unmarshalling access to either kind, **if** you supply both, a [[FromStringUnmarshaller]] + * as well as a [[FromEntityUnmarshaller]] for the target type `T`. + * Note: In order to allow for random access to the field values streamed multipart form data are strictified! + * Don't use this abstraction on potentially unbounded forms (e.g. large file uploads). + * + * If you only need to consume one type of form (`application/x-www-form-urlencoded` *or* multipart) then + * simply unmarshal directly to the respective form abstraction ([[FormData]] or [[Multipart.FormData]]) + * rather than going through [[StrictForm]]. + * + * Simple usage example: + * {{{ + * val strictFormFuture = Unmarshal(entity).to[StrictForm] + * val fooFieldUnmarshalled: Future[T] = + * strictFormFuture flatMap { form => + * Unmarshal(form field "foo").to[T] + * } + * }}} + */ +sealed abstract class StrictForm { + def fields: immutable.Seq[(String, StrictForm.Field)] + def field(name: String): Option[StrictForm.Field] = fields collectFirst { case (`name`, field) ⇒ field } +} + +object StrictForm { + sealed trait Field + object Field { + private[StrictForm] final case class FromString(value: String) extends Field + private[StrictForm] final case class FromPart(value: Multipart.FormData.BodyPart.Strict) extends Field + + implicit def unmarshaller[T](implicit um: FieldUnmarshaller[T]): FromStrictFormFieldUnmarshaller[T] = + Unmarshaller(implicit ec ⇒ { + case FromString(value) ⇒ um.unmarshalString(value) + case FromPart(value) ⇒ um.unmarshalPart(value) + }) + + def unmarshallerFromFSU[T](fsu: FromStringUnmarshaller[T]): FromStrictFormFieldUnmarshaller[T] = + Unmarshaller(implicit ec ⇒ { + case FromString(value) ⇒ fsu(value) + case FromPart(value) ⇒ fsu(value.entity.data.decodeString(value.entity.contentType.charset.nioCharset.name)) + }) + + @implicitNotFound("In order to unmarshal a `StrictForm.Field` to type `${T}` you need to supply a " + + "`FromStringUnmarshaller[${T}]` and/or a `FromEntityUnmarshaller[${T}]`") + sealed trait FieldUnmarshaller[T] { + def unmarshalString(value: String)(implicit ec: ExecutionContext): Future[T] + def unmarshalPart(value: Multipart.FormData.BodyPart.Strict)(implicit ec: ExecutionContext): Future[T] + } + object FieldUnmarshaller extends LowPrioImplicits { + implicit def fromBoth[T](implicit fsu: FromStringUnmarshaller[T], feu: FromEntityUnmarshaller[T]) = + new FieldUnmarshaller[T] { + def unmarshalString(value: String)(implicit ec: ExecutionContext) = fsu(value) + def unmarshalPart(value: Multipart.FormData.BodyPart.Strict)(implicit ec: ExecutionContext) = feu(value.entity) + } + } + sealed abstract class LowPrioImplicits { + implicit def fromFSU[T](implicit fsu: FromStringUnmarshaller[T]) = + new FieldUnmarshaller[T] { + def unmarshalString(value: String)(implicit ec: ExecutionContext) = fsu(value) + def unmarshalPart(value: Multipart.FormData.BodyPart.Strict)(implicit ec: ExecutionContext) = + fsu(value.entity.data.decodeString(value.entity.contentType.charset.nioCharset.name)) + } + implicit def fromFEU[T](implicit feu: FromEntityUnmarshaller[T]) = + new FieldUnmarshaller[T] { + def unmarshalString(value: String)(implicit ec: ExecutionContext) = feu(HttpEntity(value)) + def unmarshalPart(value: Multipart.FormData.BodyPart.Strict)(implicit ec: ExecutionContext) = feu(value.entity) + } + } + } + + implicit def unmarshaller(implicit formDataUM: FromEntityUnmarshaller[FormData], + multipartUM: FromEntityUnmarshaller[Multipart.FormData], + fm: FlowMaterializer): FromEntityUnmarshaller[StrictForm] = + Unmarshaller { implicit ec ⇒ + entity ⇒ + + def tryUnmarshalToQueryForm: Future[StrictForm] = + for (formData ← formDataUM(entity).fast) yield { + new StrictForm { + val fields = formData.fields.map { case (name, value) ⇒ name -> Field.FromString(value) }(collection.breakOut) + } + } + + def tryUnmarshalToMultipartForm: Future[StrictForm] = + for { + multiPartFD ← multipartUM(entity).fast + strictMultiPartFD ← multiPartFD.toStrict(10.seconds).fast // TODO: make timeout configurable + } yield { + new StrictForm { + val fields = strictMultiPartFD.strictParts.map { + case x: Multipart.FormData.BodyPart.Strict ⇒ x.name -> Field.FromPart(x) + }(collection.breakOut) + } + } + + tryUnmarshalToQueryForm.fast.recoverWith { + case Unmarshaller.UnsupportedContentTypeException(supported1) ⇒ + tryUnmarshalToMultipartForm.fast.recoverWith { + case Unmarshaller.UnsupportedContentTypeException(supported2) ⇒ + FastFuture.failed(Unmarshaller.UnsupportedContentTypeException(supported1 ++ supported2)) + } + } + } + + /** + * Simple model for strict file content in a multipart form data part. + */ + final case class FileData(filename: Option[String], entity: HttpEntity.Strict) + + object FileData { + implicit val unmarshaller: FromStrictFormFieldUnmarshaller[FileData] = + Unmarshaller strict { + case Field.FromString(_) ⇒ throw Unmarshaller.UnsupportedContentTypeException(MediaTypes.`application/x-www-form-urlencoded`) + case Field.FromPart(part) ⇒ FileData(part.filename, part.entity) + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/EmptyValue.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/EmptyValue.scala new file mode 100644 index 0000000000..394b0575c2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/EmptyValue.scala @@ -0,0 +1,16 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.collection.immutable +import akka.http.scaladsl.model._ + +class EmptyValue[+T] private (val emptyValue: T) + +object EmptyValue { + implicit def emptyEntity = new EmptyValue[UniversalEntity](HttpEntity.Empty) + implicit val emptyHeadersAndEntity = new EmptyValue[(immutable.Seq[HttpHeader], UniversalEntity)](Nil -> HttpEntity.Empty) + implicit val emptyResponse = new EmptyValue[HttpResponse](HttpResponse(entity = emptyEntity.emptyValue)) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/GenericMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/GenericMarshallers.scala new file mode 100644 index 0000000000..2c2b70b21f --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/GenericMarshallers.scala @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.Future +import scala.util.{ Try, Failure, Success } +import akka.http.scaladsl.util.FastFuture +import FastFuture._ + +trait GenericMarshallers extends LowPriorityToResponseMarshallerImplicits { + + implicit def throwableMarshaller[T]: Marshaller[Throwable, T] = Marshaller(_ ⇒ FastFuture.failed) + + implicit def optionMarshaller[A, B](implicit m: Marshaller[A, B], empty: EmptyValue[B]): Marshaller[Option[A], B] = + Marshaller { implicit ec ⇒ + { + case Some(value) ⇒ m(value) + case None ⇒ FastFuture.successful(Marshalling.Opaque(() ⇒ empty.emptyValue) :: Nil) + } + } + + implicit def eitherMarshaller[A1, A2, B](implicit m1: Marshaller[A1, B], m2: Marshaller[A2, B]): Marshaller[Either[A1, A2], B] = + Marshaller { implicit ec ⇒ + { + case Left(a1) ⇒ m1(a1) + case Right(a2) ⇒ m2(a2) + } + } + + implicit def futureMarshaller[A, B](implicit m: Marshaller[A, B]): Marshaller[Future[A], B] = + Marshaller(implicit ec ⇒ _.fast.flatMap(m(_))) + + implicit def tryMarshaller[A, B](implicit m: Marshaller[A, B]): Marshaller[Try[A], B] = + Marshaller { implicit ec ⇒ + { + case Success(value) ⇒ m(value) + case Failure(error) ⇒ FastFuture.failed(error) + } + } +} + +object GenericMarshallers extends GenericMarshallers \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshal.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshal.scala new file mode 100644 index 0000000000..b246797a38 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshal.scala @@ -0,0 +1,75 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.{ ExecutionContext, Future } +import akka.http.scaladsl.model.HttpCharsets._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture._ + +object Marshal { + def apply[T](value: T): Marshal[T] = new Marshal(value) + + case class UnacceptableResponseContentTypeException(supported: Set[ContentType]) extends RuntimeException + + private class MarshallingWeight(val weight: Float, val marshal: () ⇒ HttpResponse) +} + +class Marshal[A](val value: A) { + /** + * Marshals `value` using the first available [[Marshalling]] for `A` and `B` provided by the given [[Marshaller]]. + * If the marshalling is flexible with regard to the used charset `UTF-8` is chosen. + */ + def to[B](implicit m: Marshaller[A, B], ec: ExecutionContext): Future[B] = + m(value).fast.map { + _.head match { + case Marshalling.WithFixedCharset(_, _, marshal) ⇒ marshal() + case Marshalling.WithOpenCharset(_, marshal) ⇒ marshal(HttpCharsets.`UTF-8`) + case Marshalling.Opaque(marshal) ⇒ marshal() + } + } + + /** + * Marshals `value` to an `HttpResponse` for the given `HttpRequest` with full content-negotiation. + */ + def toResponseFor(request: HttpRequest)(implicit m: ToResponseMarshaller[A], ec: ExecutionContext): Future[HttpResponse] = { + import akka.http.scaladsl.marshalling.Marshal._ + val mediaRanges = request.acceptedMediaRanges // cache for performance + val charsetRanges = request.acceptedCharsetRanges // cache for performance + def qValueMT(mediaType: MediaType) = request.qValueForMediaType(mediaType, mediaRanges) + def qValueCS(charset: HttpCharset) = request.qValueForCharset(charset, charsetRanges) + + m(value).fast.map { marshallings ⇒ + val defaultMarshallingWeight = new MarshallingWeight(0f, { () ⇒ + val supportedContentTypes = marshallings collect { + case Marshalling.WithFixedCharset(mt, cs, _) ⇒ ContentType(mt, cs) + case Marshalling.WithOpenCharset(mt, _) ⇒ ContentType(mt) + } + throw UnacceptableResponseContentTypeException(supportedContentTypes.toSet) + }) + def choose(acc: MarshallingWeight, mt: MediaType, cs: HttpCharset, marshal: () ⇒ HttpResponse) = { + val weight = math.min(qValueMT(mt), qValueCS(cs)) + if (weight > acc.weight) new MarshallingWeight(weight, marshal) else acc + } + val best = marshallings.foldLeft(defaultMarshallingWeight) { + case (acc, Marshalling.WithFixedCharset(mt, cs, marshal)) ⇒ + choose(acc, mt, cs, marshal) + case (acc, Marshalling.WithOpenCharset(mt, marshal)) ⇒ + def withCharset(cs: HttpCharset) = choose(acc, mt, cs, () ⇒ marshal(cs)) + // logic for choosing the charset adapted from http://tools.ietf.org/html/rfc7231#section-5.3.3 + if (qValueCS(`UTF-8`) == 1f) withCharset(`UTF-8`) // prefer UTF-8 if fully accepted + else charsetRanges match { + // pick the charset which the highest q-value (head of charsetRanges) if it isn't explicitly rejected + case (HttpCharsetRange.One(cs, qValue)) :: _ if qValue > 0f ⇒ withCharset(cs) + case _ ⇒ acc + } + + case (acc, Marshalling.Opaque(marshal)) ⇒ + if (acc.weight == 0f) new MarshallingWeight(Float.MinPositiveValue, marshal) else acc + } + best.marshal() + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshaller.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshaller.scala new file mode 100644 index 0000000000..b0c71ef260 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/Marshaller.scala @@ -0,0 +1,153 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.{ Future, ExecutionContext } +import scala.util.control.NonFatal +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ + +sealed abstract class Marshaller[-A, +B] { + + def apply(value: A)(implicit ec: ExecutionContext): Future[List[Marshalling[B]]] + + def map[C](f: B ⇒ C): Marshaller[A, C] = + Marshaller(implicit ec ⇒ value ⇒ this(value).fast map (_ map (_ map f))) + + /** + * Reuses this Marshaller's logic to produce a new Marshaller from another type `C` which overrides + * the produced [[ContentType]] with another one. + * Depending on whether the given [[ContentType]] has a defined charset or not and whether the underlying + * marshaller marshals with a fixed charset it can happen, that the wrapping becomes illegal. + * For example, a marshaller producing content encoded with UTF-16 cannot be wrapped with a [[ContentType]] + * that has a defined charset of UTF-8, since akka-http will never recode entities. + * If the wrapping is illegal the [[Future]] produced by the resulting marshaller will contain a [[RuntimeException]]. + */ + def wrap[C, D >: B](contentType: ContentType)(f: C ⇒ A)(implicit mto: MediaTypeOverrider[D]): Marshaller[C, D] = + wrapWithEC[C, D](contentType)(_ ⇒ f) + + /** + * Reuses this Marshaller's logic to produce a new Marshaller from another type `C` which overrides + * the produced [[ContentType]] with another one. + * Depending on whether the given [[ContentType]] has a defined charset or not and whether the underlying + * marshaller marshals with a fixed charset it can happen, that the wrapping becomes illegal. + * For example, a marshaller producing content encoded with UTF-16 cannot be wrapped with a [[ContentType]] + * that has a defined charset of UTF-8, since akka-http will never recode entities. + * If the wrapping is illegal the [[Future]] produced by the resulting marshaller will contain a [[RuntimeException]]. + */ + def wrapWithEC[C, D >: B](contentType: ContentType)(f: ExecutionContext ⇒ C ⇒ A)(implicit mto: MediaTypeOverrider[D]): Marshaller[C, D] = + Marshaller { implicit ec ⇒ + value ⇒ + import Marshalling._ + this(f(ec)(value)).fast map { + _ map { + case WithFixedCharset(_, cs, marshal) if contentType.hasOpenCharset || contentType.charset == cs ⇒ + WithFixedCharset(contentType.mediaType, cs, () ⇒ mto(marshal(), contentType.mediaType)) + case WithOpenCharset(_, marshal) if contentType.hasOpenCharset ⇒ + WithOpenCharset(contentType.mediaType, cs ⇒ mto(marshal(cs), contentType.mediaType)) + case WithOpenCharset(_, marshal) ⇒ + WithFixedCharset(contentType.mediaType, contentType.charset, () ⇒ mto(marshal(contentType.charset), contentType.mediaType)) + case Opaque(marshal) if contentType.definedCharset.isEmpty ⇒ Opaque(() ⇒ mto(marshal(), contentType.mediaType)) + case x ⇒ sys.error(s"Illegal marshaller wrapping. Marshalling `$x` cannot be wrapped with ContentType `$contentType`") + } + } + } + + def compose[C](f: C ⇒ A): Marshaller[C, B] = + Marshaller(implicit ec ⇒ c ⇒ apply(f(c))) + + def composeWithEC[C](f: ExecutionContext ⇒ C ⇒ A): Marshaller[C, B] = + Marshaller(implicit ec ⇒ c ⇒ apply(f(ec)(c))) +} + +object Marshaller + extends GenericMarshallers + with PredefinedToEntityMarshallers + with PredefinedToResponseMarshallers + with PredefinedToRequestMarshallers { + + /** + * Creates a [[Marshaller]] from the given function. + */ + def apply[A, B](f: ExecutionContext ⇒ A ⇒ Future[List[Marshalling[B]]]): Marshaller[A, B] = + new Marshaller[A, B] { + def apply(value: A)(implicit ec: ExecutionContext) = + try f(ec)(value) + catch { case NonFatal(e) ⇒ FastFuture.failed(e) } + } + + /** + * Helper for creating a [[Marshaller]] using the given function. + */ + def strict[A, B](f: A ⇒ Marshalling[B]): Marshaller[A, B] = + Marshaller { _ ⇒ a ⇒ FastFuture.successful(f(a) :: Nil) } + + /** + * Helper for creating a "super-marshaller" from a number of "sub-marshallers". + * Content-negotiation determines, which "sub-marshaller" eventually gets to do the job. + */ + def oneOf[A, B](marshallers: Marshaller[A, B]*): Marshaller[A, B] = + Marshaller { implicit ec ⇒ a ⇒ FastFuture.sequence(marshallers.map(_(a))).fast.map(_.flatten.toList) } + + /** + * Helper for creating a "super-marshaller" from a number of values and a function producing "sub-marshallers" + * from these values. Content-negotiation determines, which "sub-marshaller" eventually gets to do the job. + */ + def oneOf[T, A, B](values: T*)(f: T ⇒ Marshaller[A, B]): Marshaller[A, B] = + oneOf(values map f: _*) + + /** + * Helper for creating a synchronous [[Marshaller]] to content with a fixed charset from the given function. + */ + def withFixedCharset[A, B](mediaType: MediaType, charset: HttpCharset)(marshal: A ⇒ B): Marshaller[A, B] = + strict { value ⇒ Marshalling.WithFixedCharset(mediaType, charset, () ⇒ marshal(value)) } + + /** + * Helper for creating a synchronous [[Marshaller]] to content with a negotiable charset from the given function. + */ + def withOpenCharset[A, B](mediaType: MediaType)(marshal: (A, HttpCharset) ⇒ B): Marshaller[A, B] = + strict { value ⇒ Marshalling.WithOpenCharset(mediaType, charset ⇒ marshal(value, charset)) } + + /** + * Helper for creating a synchronous [[Marshaller]] to non-negotiable content from the given function. + */ + def opaque[A, B](marshal: A ⇒ B): Marshaller[A, B] = + strict { value ⇒ Marshalling.Opaque(() ⇒ marshal(value)) } +} + +/** + * Describes one possible option for marshalling a given value. + */ +sealed trait Marshalling[+A] { + def map[B](f: A ⇒ B): Marshalling[B] +} + +object Marshalling { + /** + * A Marshalling to a specific MediaType and charset. + */ + final case class WithFixedCharset[A](mediaType: MediaType, + charset: HttpCharset, + marshal: () ⇒ A) extends Marshalling[A] { + def map[B](f: A ⇒ B): WithFixedCharset[B] = copy(marshal = () ⇒ f(marshal())) + } + + /** + * A Marshalling to a specific MediaType and a potentially flexible charset. + */ + final case class WithOpenCharset[A](mediaType: MediaType, + marshal: HttpCharset ⇒ A) extends Marshalling[A] { + def map[B](f: A ⇒ B): WithOpenCharset[B] = copy(marshal = cs ⇒ f(marshal(cs))) + } + + /** + * A Marshalling to an unknown MediaType and charset. + * Circumvents content negotiation. + */ + final case class Opaque[A](marshal: () ⇒ A) extends Marshalling[A] { + def map[B](f: A ⇒ B): Opaque[B] = copy(marshal = () ⇒ f(marshal())) + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MediaTypeOverrider.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MediaTypeOverrider.scala new file mode 100644 index 0000000000..96b6833d1a --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MediaTypeOverrider.scala @@ -0,0 +1,30 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.collection.immutable +import akka.http.scaladsl.model._ + +sealed trait MediaTypeOverrider[T] { + def apply(value: T, mediaType: MediaType): T +} +object MediaTypeOverrider { + implicit def forEntity[T <: HttpEntity]: MediaTypeOverrider[T] = new MediaTypeOverrider[T] { + def apply(value: T, mediaType: MediaType) = + value.withContentType(value.contentType withMediaType mediaType).asInstanceOf[T] // can't be expressed in types + } + implicit def forHeadersAndEntity[T <: HttpEntity] = new MediaTypeOverrider[(immutable.Seq[HttpHeader], T)] { + def apply(value: (immutable.Seq[HttpHeader], T), mediaType: MediaType) = + value._1 -> value._2.withContentType(value._2.contentType withMediaType mediaType).asInstanceOf[T] + } + implicit val forResponse = new MediaTypeOverrider[HttpResponse] { + def apply(value: HttpResponse, mediaType: MediaType) = + value.mapEntity(forEntity(_: ResponseEntity, mediaType)) + } + implicit val forRequest = new MediaTypeOverrider[HttpRequest] { + def apply(value: HttpRequest, mediaType: MediaType) = + value.mapEntity(forEntity(_: RequestEntity, mediaType)) + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala new file mode 100644 index 0000000000..8649f12b92 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala @@ -0,0 +1,45 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.forkjoin.ThreadLocalRandom +import akka.parboiled2.util.Base64 +import akka.event.{ NoLogging, LoggingAdapter } +import akka.stream.scaladsl.FlattenStrategy +import akka.http.impl.engine.rendering.BodyPartRenderer +import akka.http.scaladsl.model._ + +trait MultipartMarshallers { + protected val multipartBoundaryRandom: java.util.Random = ThreadLocalRandom.current() + + /** + * Creates a new random 144-bit number and base64 encodes it (using a custom "safe" alphabet, yielding 24 characters). + */ + def randomBoundary: String = { + val array = new Array[Byte](18) + multipartBoundaryRandom.nextBytes(array) + Base64.custom.encodeToString(array, false) + } + + implicit def multipartMarshaller[T <: Multipart](implicit log: LoggingAdapter = NoLogging): ToEntityMarshaller[T] = + Marshaller strict { value ⇒ + val boundary = randomBoundary + val contentType = ContentType(value.mediaType withBoundary boundary) + Marshalling.WithOpenCharset(contentType.mediaType, { charset ⇒ + value match { + case x: Multipart.Strict ⇒ + val data = BodyPartRenderer.strict(x.strictParts, boundary, charset.nioCharset, partHeadersSizeHint = 128, log) + HttpEntity(contentType, data) + case _ ⇒ + val chunks = value.parts + .transform(() ⇒ BodyPartRenderer.streamed(boundary, charset.nioCharset, partHeadersSizeHint = 128, log)) + .flatten(FlattenStrategy.concat) + HttpEntity.Chunked(contentType, chunks) + } + }) + } +} + +object MultipartMarshallers extends MultipartMarshallers diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToEntityMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToEntityMarshallers.scala new file mode 100644 index 0000000000..da7cce221e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToEntityMarshallers.scala @@ -0,0 +1,67 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import java.nio.CharBuffer +import akka.http.impl.model.parser.CharacterClasses +import akka.http.scaladsl.model.MediaTypes._ +import akka.http.scaladsl.model._ +import akka.http.impl.util.StringRendering +import akka.util.ByteString + +trait PredefinedToEntityMarshallers extends MultipartMarshallers { + + implicit val ByteArrayMarshaller: ToEntityMarshaller[Array[Byte]] = byteArrayMarshaller(`application/octet-stream`) + def byteArrayMarshaller(mediaType: MediaType, charset: HttpCharset): ToEntityMarshaller[Array[Byte]] = { + val ct = ContentType(mediaType, charset) + Marshaller.withFixedCharset(ct.mediaType, ct.definedCharset.get) { bytes ⇒ HttpEntity(ct, bytes) } + } + def byteArrayMarshaller(mediaType: MediaType): ToEntityMarshaller[Array[Byte]] = { + val ct = ContentType(mediaType) + // since we don't want to recode we simply ignore the charset determined by content negotiation here + Marshaller.withOpenCharset(ct.mediaType) { (bytes, _) ⇒ HttpEntity(ct, bytes) } + } + + implicit val ByteStringMarshaller: ToEntityMarshaller[ByteString] = byteStringMarshaller(`application/octet-stream`) + def byteStringMarshaller(mediaType: MediaType, charset: HttpCharset): ToEntityMarshaller[ByteString] = { + val ct = ContentType(mediaType, charset) + Marshaller.withFixedCharset(ct.mediaType, ct.definedCharset.get) { bytes ⇒ HttpEntity(ct, bytes) } + } + def byteStringMarshaller(mediaType: MediaType): ToEntityMarshaller[ByteString] = { + val ct = ContentType(mediaType) + // since we don't want to recode we simply ignore the charset determined by content negotiation here + Marshaller.withOpenCharset(ct.mediaType) { (bytes, _) ⇒ HttpEntity(ct, bytes) } + } + + implicit val CharArrayMarshaller: ToEntityMarshaller[Array[Char]] = charArrayMarshaller(`text/plain`) + def charArrayMarshaller(mediaType: MediaType): ToEntityMarshaller[Array[Char]] = + Marshaller.withOpenCharset(mediaType) { (value, charset) ⇒ + if (value.length > 0) { + val charBuffer = CharBuffer.wrap(value) + val byteBuffer = charset.nioCharset.encode(charBuffer) + val array = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(array) + HttpEntity(ContentType(mediaType, charset), array) + } else HttpEntity.Empty + } + + implicit val StringMarshaller: ToEntityMarshaller[String] = stringMarshaller(`text/plain`) + def stringMarshaller(mediaType: MediaType): ToEntityMarshaller[String] = + Marshaller.withOpenCharset(mediaType) { (s, cs) ⇒ HttpEntity(ContentType(mediaType, cs), s) } + + implicit val FormDataMarshaller: ToEntityMarshaller[FormData] = + Marshaller.withOpenCharset(`application/x-www-form-urlencoded`) { (formData, charset) ⇒ + val query = Uri.Query(formData.fields: _*) + val string = UriRendering.renderQuery(new StringRendering, query, charset.nioCharset, CharacterClasses.unreserved).get + HttpEntity(ContentType(`application/x-www-form-urlencoded`, charset), string) + } + + implicit val HttpEntityMarshaller: ToEntityMarshaller[MessageEntity] = Marshaller strict { value ⇒ + Marshalling.WithFixedCharset(value.contentType.mediaType, value.contentType.charset, () ⇒ value) + } +} + +object PredefinedToEntityMarshallers extends PredefinedToEntityMarshallers + diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToRequestMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToRequestMarshallers.scala new file mode 100644 index 0000000000..6a5df4534c --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToRequestMarshallers.scala @@ -0,0 +1,29 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.collection.immutable +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture._ + +trait PredefinedToRequestMarshallers { + private type TRM[T] = ToRequestMarshaller[T] // brevity alias + + implicit val fromRequest: TRM[HttpRequest] = Marshaller.opaque(identity) + + implicit def fromUri: TRM[Uri] = + Marshaller strict { uri ⇒ Marshalling.Opaque(() ⇒ HttpRequest(uri = uri)) } + + implicit def fromMethodAndUriAndValue[S, T](implicit mt: ToEntityMarshaller[T]): TRM[(HttpMethod, Uri, T)] = + fromMethodAndUriAndHeadersAndValue[T] compose { case (m, u, v) ⇒ (m, u, Nil, v) } + + implicit def fromMethodAndUriAndHeadersAndValue[T](implicit mt: ToEntityMarshaller[T]): TRM[(HttpMethod, Uri, immutable.Seq[HttpHeader], T)] = + Marshaller(implicit ec ⇒ { + case (m, u, h, v) ⇒ mt(v).fast map (_ map (_ map (HttpRequest(m, u, h, _)))) + }) +} + +object PredefinedToRequestMarshallers extends PredefinedToRequestMarshallers + diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToResponseMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToResponseMarshallers.scala new file mode 100644 index 0000000000..8e344e77bb --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToResponseMarshallers.scala @@ -0,0 +1,49 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.collection.immutable +import akka.http.scaladsl.util.FastFuture._ +import akka.http.scaladsl.model.MediaTypes._ +import akka.http.scaladsl.model._ + +trait PredefinedToResponseMarshallers extends LowPriorityToResponseMarshallerImplicits { + + private type TRM[T] = ToResponseMarshaller[T] // brevity alias + + def fromToEntityMarshaller[T](status: StatusCode = StatusCodes.OK, + headers: immutable.Seq[HttpHeader] = Nil)( + implicit m: ToEntityMarshaller[T]): ToResponseMarshaller[T] = + fromStatusCodeAndHeadersAndValue compose (t ⇒ (status, headers, t)) + + implicit val fromResponse: TRM[HttpResponse] = Marshaller.opaque(identity) + + implicit val fromStatusCode: TRM[StatusCode] = + Marshaller.withOpenCharset(`text/plain`) { (status, charset) ⇒ + HttpResponse(status, entity = HttpEntity(ContentType(`text/plain`, charset), status.defaultMessage)) + } + + implicit def fromStatusCodeAndValue[S, T](implicit sConv: S ⇒ StatusCode, mt: ToEntityMarshaller[T]): TRM[(S, T)] = + fromStatusCodeAndHeadersAndValue[T] compose { case (status, value) ⇒ (sConv(status), Nil, value) } + + implicit def fromStatusCodeConvertibleAndHeadersAndT[S, T](implicit sConv: S ⇒ StatusCode, + mt: ToEntityMarshaller[T]): TRM[(S, immutable.Seq[HttpHeader], T)] = + fromStatusCodeAndHeadersAndValue[T] compose { case (status, headers, value) ⇒ (sConv(status), headers, value) } + + implicit def fromStatusCodeAndHeadersAndValue[T](implicit mt: ToEntityMarshaller[T]): TRM[(StatusCode, immutable.Seq[HttpHeader], T)] = + Marshaller(implicit ec ⇒ { + case (status, headers, value) ⇒ mt(value).fast map (_ map (_ map (HttpResponse(status, headers, _)))) + }) +} + +trait LowPriorityToResponseMarshallerImplicits { + implicit def liftMarshallerConversion[T](m: ToEntityMarshaller[T]): ToResponseMarshaller[T] = + liftMarshaller(m) + implicit def liftMarshaller[T](implicit m: ToEntityMarshaller[T]): ToResponseMarshaller[T] = + PredefinedToResponseMarshallers.fromToEntityMarshaller() +} + +object PredefinedToResponseMarshallers extends PredefinedToResponseMarshallers + diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/ToResponseMarshallable.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/ToResponseMarshallable.scala new file mode 100644 index 0000000000..40b4540bef --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/ToResponseMarshallable.scala @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.marshalling + +import scala.concurrent.{ Future, ExecutionContext } +import akka.http.scaladsl.model._ + +/** Something that can later be marshalled into a response */ +trait ToResponseMarshallable { + type T + def value: T + implicit def marshaller: ToResponseMarshaller[T] + + def apply(request: HttpRequest)(implicit ec: ExecutionContext): Future[HttpResponse] = + Marshal(value).toResponseFor(request) +} + +object ToResponseMarshallable { + implicit def apply[A](_value: A)(implicit _marshaller: ToResponseMarshaller[A]): ToResponseMarshallable = + new ToResponseMarshallable { + type T = A + def value: T = _value + def marshaller: ToResponseMarshaller[T] = _marshaller + } + + implicit val marshaller: ToResponseMarshaller[ToResponseMarshallable] = + Marshaller { implicit ec ⇒ marshallable ⇒ marshallable.marshaller(marshallable.value) } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/package.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/package.scala new file mode 100644 index 0000000000..ff2d855736 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/package.scala @@ -0,0 +1,15 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl + +import scala.collection.immutable +import akka.http.scaladsl.model._ + +package object marshalling { + type ToEntityMarshaller[T] = Marshaller[T, MessageEntity] + type ToHeadersAndEntityMarshaller[T] = Marshaller[T, (immutable.Seq[HttpHeader], MessageEntity)] + type ToResponseMarshaller[T] = Marshaller[T, HttpResponse] + type ToRequestMarshaller[T] = Marshaller[T, HttpRequest] +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Directive.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Directive.scala new file mode 100644 index 0000000000..f0627d65d1 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Directive.scala @@ -0,0 +1,161 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.collection.immutable +import akka.http.scaladsl.server.directives.RouteDirectives +import akka.http.scaladsl.server.util._ +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ + +/** + * A directive that provides a tuple of values of type `L` to create an inner route. + */ +abstract class Directive[L](implicit val ev: Tuple[L]) { + + /** + * Calls the inner route with a tuple of extracted values of type `L`. + * + * `tapply` is short for "tuple-apply". Usually, you will use the regular `apply` method instead, + * which is added by an implicit conversion (see `Directive.addDirectiveApply`). + */ + def tapply(f: L ⇒ Route): Route + + /** + * Joins two directives into one which runs the second directive if the first one rejects. + */ + def |[R >: L](that: Directive[R]): Directive[R] = + recover(rejections ⇒ directives.BasicDirectives.mapRejections(rejections ++ _) & that)(that.ev) + + /** + * Joins two directives into one which extracts the concatenation of its base directive extractions. + * NOTE: Extraction joining is an O(N) operation with N being the number of extractions on the right-side. + */ + def &(magnet: ConjunctionMagnet[L]): magnet.Out = magnet(this) + + /** + * Converts this directive into one which, instead of a tuple of type ``L``, creates an + * instance of type ``A`` (which is usually a case class). + */ + def as[A](constructor: ConstructFromTuple[L, A]): Directive1[A] = tmap(constructor) + + /** + * Maps over this directive using the given function, which can produce either a tuple or any other value + * (which will then we wrapped into a [[Tuple1]]). + */ + def tmap[R](f: L ⇒ R)(implicit tupler: Tupler[R]): Directive[tupler.Out] = + Directive[tupler.Out] { inner ⇒ tapply { values ⇒ inner(tupler(f(values))) } }(tupler.OutIsTuple) + + /** + * Flatmaps this directive using the given function. + */ + def tflatMap[R: Tuple](f: L ⇒ Directive[R]): Directive[R] = + Directive[R] { inner ⇒ tapply { values ⇒ f(values) tapply inner } } + + /** + * Creates a new [[Directive0]], which passes if the given predicate matches the current + * extractions or rejects with the given rejections. + */ + def trequire(predicate: L ⇒ Boolean, rejections: Rejection*): Directive0 = + tfilter(predicate, rejections: _*).tflatMap(_ ⇒ Directive.Empty) + + /** + * Creates a new directive of the same type, which passes if the given predicate matches the current + * extractions or rejects with the given rejections. + */ + def tfilter(predicate: L ⇒ Boolean, rejections: Rejection*): Directive[L] = + Directive[L] { inner ⇒ tapply { values ⇒ ctx ⇒ if (predicate(values)) inner(values)(ctx) else ctx.reject(rejections: _*) } } + + /** + * Creates a new directive that is able to recover from rejections that were produced by `this` Directive + * **before the inner route was applied**. + */ + def recover[R >: L: Tuple](recovery: immutable.Seq[Rejection] ⇒ Directive[R]): Directive[R] = + Directive[R] { inner ⇒ + ctx ⇒ + import ctx.executionContext + @volatile var rejectedFromInnerRoute = false + tapply({ list ⇒ c ⇒ rejectedFromInnerRoute = true; inner(list)(c) })(ctx).fast.flatMap { + case RouteResult.Rejected(rejections) if !rejectedFromInnerRoute ⇒ recovery(rejections).tapply(inner)(ctx) + case x ⇒ FastFuture.successful(x) + } + } + + /** + * Variant of `recover` that only recovers from rejections handled by the given PartialFunction. + */ + def recoverPF[R >: L: Tuple](recovery: PartialFunction[immutable.Seq[Rejection], Directive[R]]): Directive[R] = + recover { rejections ⇒ recovery.applyOrElse(rejections, (rejs: Seq[Rejection]) ⇒ RouteDirectives.reject(rejs: _*)) } +} + +object Directive { + + /** + * Constructs a directive from a function literal. + */ + def apply[T: Tuple](f: (T ⇒ Route) ⇒ Route): Directive[T] = + new Directive[T] { def tapply(inner: T ⇒ Route) = f(inner) } + + /** + * A Directive that always passes the request on to its inner route (i.e. does nothing). + */ + val Empty: Directive0 = Directive(_()) + + /** + * Adds `apply` to all Directives with 1 or more extractions, + * which allows specifying an n-ary function to receive the extractions instead of a Function1[TupleX, Route]. + */ + implicit def addDirectiveApply[L](directive: Directive[L])(implicit hac: ApplyConverter[L]): hac.In ⇒ Route = + f ⇒ directive.tapply(hac(f)) + + /** + * Adds `apply` to Directive0. Note: The `apply` parameter is call-by-name to ensure consistent execution behavior + * with the directives producing extractions. + */ + implicit def addByNameNullaryApply(directive: Directive0): (⇒ Route) ⇒ Route = + r ⇒ directive.tapply(_ ⇒ r) + + implicit class SingleValueModifiers[T](underlying: Directive1[T]) extends AnyRef { + def map[R](f: T ⇒ R)(implicit tupler: Tupler[R]): Directive[tupler.Out] = + underlying.tmap { case Tuple1(value) ⇒ f(value) } + + def flatMap[R: Tuple](f: T ⇒ Directive[R]): Directive[R] = + underlying.tflatMap { case Tuple1(value) ⇒ f(value) } + + def require(predicate: T ⇒ Boolean, rejections: Rejection*): Directive0 = + underlying.filter(predicate, rejections: _*).tflatMap(_ ⇒ Empty) + + def filter(predicate: T ⇒ Boolean, rejections: Rejection*): Directive1[T] = + underlying.tfilter({ case Tuple1(value) ⇒ predicate(value) }, rejections: _*) + } +} + +trait ConjunctionMagnet[L] { + type Out + def apply(underlying: Directive[L]): Out +} + +object ConjunctionMagnet { + implicit def fromDirective[L, R](other: Directive[R])(implicit join: TupleOps.Join[L, R]): ConjunctionMagnet[L] { type Out = Directive[join.Out] } = + new ConjunctionMagnet[L] { + type Out = Directive[join.Out] + def apply(underlying: Directive[L]) = + Directive[join.Out] { inner ⇒ + underlying.tapply { prefix ⇒ other.tapply { suffix ⇒ inner(join(prefix, suffix)) } } + }(Tuple.yes) // we know that join will only ever produce tuples + } + + implicit def fromStandardRoute[L](route: StandardRoute) = + new ConjunctionMagnet[L] { + type Out = StandardRoute + def apply(underlying: Directive[L]) = StandardRoute(underlying.tapply(_ ⇒ route)) + } + + implicit def fromRouteGenerator[T, R <: Route](generator: T ⇒ R) = + new ConjunctionMagnet[Unit] { + type Out = RouteGenerator[T] + def apply(underlying: Directive0) = value ⇒ underlying.tapply(_ ⇒ generator(value)) + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala new file mode 100644 index 0000000000..ece3426bbc --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import directives._ + +trait Directives extends RouteConcatenation + with BasicDirectives + with CacheConditionDirectives + with CookieDirectives + with DebuggingDirectives + with CodingDirectives + with ExecutionDirectives + with FileAndResourceDirectives + with FormFieldDirectives + with FutureDirectives + with HeaderDirectives + with HostDirectives + with MarshallingDirectives + with MethodDirectives + with MiscDirectives + with ParameterDirectives + with PathDirectives + with RangeDirectives + with RespondWithDirectives + with RouteDirectives + with SchemeDirectives + with SecurityDirectives + with WebsocketDirectives + +object Directives extends Directives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/ExceptionHandler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/ExceptionHandler.scala new file mode 100644 index 0000000000..950b87edc5 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/ExceptionHandler.scala @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.util.control.NonFatal +import akka.http.scaladsl.model._ +import StatusCodes._ + +trait ExceptionHandler extends ExceptionHandler.PF { + + /** + * Creates a new [[ExceptionHandler]] which uses the given one as fallback for this one. + */ + def withFallback(that: ExceptionHandler): ExceptionHandler + + /** + * "Seals" this handler by attaching a default handler as fallback if necessary. + */ + def seal(settings: RoutingSettings): ExceptionHandler +} + +object ExceptionHandler { + type PF = PartialFunction[Throwable, Route] + + implicit def apply(pf: PF): ExceptionHandler = apply(knownToBeSealed = false)(pf) + + private def apply(knownToBeSealed: Boolean)(pf: PF): ExceptionHandler = + new ExceptionHandler { + def isDefinedAt(error: Throwable) = pf.isDefinedAt(error) + def apply(error: Throwable) = pf(error) + def withFallback(that: ExceptionHandler): ExceptionHandler = + if (!knownToBeSealed) ExceptionHandler(knownToBeSealed = false)(this orElse that) else this + def seal(settings: RoutingSettings): ExceptionHandler = + if (!knownToBeSealed) ExceptionHandler(knownToBeSealed = true)(this orElse default(settings)) else this + } + + def default(settings: RoutingSettings): ExceptionHandler = + apply(knownToBeSealed = true) { + case IllegalRequestException(info, status) ⇒ ctx ⇒ { + ctx.log.warning("Illegal request {}\n\t{}\n\tCompleting with '{}' response", + ctx.request, info.formatPretty, status) + ctx.complete(status, info.format(settings.verboseErrorMessages)) + } + case NonFatal(e) ⇒ ctx ⇒ { + ctx.log.error(e, "Error during processing of request {}", ctx.request) + ctx.complete(InternalServerError) + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/PathMatcher.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/PathMatcher.scala new file mode 100644 index 0000000000..ddf579c41a --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/PathMatcher.scala @@ -0,0 +1,475 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import java.util.UUID +import scala.util.matching.Regex +import scala.annotation.tailrec +import akka.http.scaladsl.server.util.Tuple +import akka.http.scaladsl.server.util.TupleOps._ +import akka.http.scaladsl.common.NameOptionReceptacle +import akka.http.scaladsl.model.Uri.Path +import akka.http.impl.util._ + +/** + * A PathMatcher tries to match a prefix of a given string and returns either a PathMatcher.Matched instance + * if matched, otherwise PathMatcher.Unmatched. + */ +abstract class PathMatcher[L](implicit val ev: Tuple[L]) extends (Path ⇒ PathMatcher.Matching[L]) { self ⇒ + import PathMatcher._ + + def / : PathMatcher[L] = this ~ PathMatchers.Slash + + def /[R](other: PathMatcher[R])(implicit join: Join[L, R]): PathMatcher[join.Out] = + this ~ PathMatchers.Slash ~ other + + def |[R >: L: Tuple](other: PathMatcher[_ <: R]): PathMatcher[R] = + new PathMatcher[R] { + def apply(path: Path) = self(path) orElse other(path) + } + + def ~[R](other: PathMatcher[R])(implicit join: Join[L, R]): PathMatcher[join.Out] = { + implicit def joinProducesTuple = Tuple.yes[join.Out] + transform(_.andThen((restL, valuesL) ⇒ other(restL).map(join(valuesL, _)))) + } + + def unary_!(): PathMatcher0 = + new PathMatcher[Unit] { + def apply(path: Path) = if (self(path) eq Unmatched) Matched(path, ()) else Unmatched + } + + def transform[R: Tuple](f: Matching[L] ⇒ Matching[R]): PathMatcher[R] = + new PathMatcher[R] { def apply(path: Path) = f(self(path)) } + + def tmap[R: Tuple](f: L ⇒ R): PathMatcher[R] = transform(_.map(f)) + + def tflatMap[R: Tuple](f: L ⇒ Option[R]): PathMatcher[R] = transform(_.flatMap(f)) + + /** + * Same as ``repeat(min = count, max = count)``. + */ + def repeat(count: Int)(implicit lift: PathMatcher.Lift[L, List]): PathMatcher[lift.Out] = + repeat(min = count, max = count) + + /** + * Same as ``repeat(min = count, max = count, separator = separator)``. + */ + def repeat(count: Int, separator: PathMatcher0)(implicit lift: PathMatcher.Lift[L, List]): PathMatcher[lift.Out] = + repeat(min = count, max = count, separator = separator) + + /** + * Turns this ``PathMatcher`` into one that matches a number of times (with the given separator) + * and potentially extracts a ``List`` of the underlying matcher's extractions. + * If less than ``min`` applications of the underlying matcher have succeeded the produced matcher fails, + * otherwise it matches up to the given ``max`` number of applications. + * Note that it won't fail even if more than ``max`` applications could succeed! + * The "surplus" path elements will simply be left unmatched. + * + * The result type depends on the type of the underlying matcher: + * + * + * + * + * + * + *
If a ``matcher`` is of typethen ``matcher.repeat(...)`` is of type
``PathMatcher0````PathMatcher0``
``PathMatcher1[T]````PathMatcher1[List[T]``
``PathMatcher[L :Tuple]````PathMatcher[List[L]]``
+ */ + def repeat(min: Int, max: Int, separator: PathMatcher0 = PathMatchers.Neutral)(implicit lift: PathMatcher.Lift[L, List]): PathMatcher[lift.Out] = + new PathMatcher[lift.Out]()(lift.OutIsTuple) { + require(min >= 0, "`min` must be >= 0") + require(max >= min, "`max` must be >= `min`") + def apply(path: Path) = rec(path, 1) + def rec(path: Path, count: Int): Matching[lift.Out] = { + def done = if (count >= min) Matched(path, lift()) else Unmatched + if (count <= max) { + self(path) match { + case Matched(remaining, extractions) ⇒ + def done1 = if (count >= min) Matched(remaining, lift(extractions)) else Unmatched + separator(remaining) match { + case Matched(remaining2, _) ⇒ rec(remaining2, count + 1) match { + case Matched(`remaining2`, _) ⇒ done1 // we made no progress, so "go back" to before the separator + case Matched(rest, result) ⇒ Matched(rest, lift(extractions, result)) + case Unmatched ⇒ Unmatched + } + case Unmatched ⇒ done1 + } + case Unmatched ⇒ done + } + } else done + } + } +} + +object PathMatcher extends ImplicitPathMatcherConstruction { + sealed abstract class Matching[+L: Tuple] { + def map[R: Tuple](f: L ⇒ R): Matching[R] + def flatMap[R: Tuple](f: L ⇒ Option[R]): Matching[R] + def andThen[R: Tuple](f: (Path, L) ⇒ Matching[R]): Matching[R] + def orElse[R >: L](other: ⇒ Matching[R]): Matching[R] + } + case class Matched[L: Tuple](pathRest: Path, extractions: L) extends Matching[L] { + def map[R: Tuple](f: L ⇒ R) = Matched(pathRest, f(extractions)) + def flatMap[R: Tuple](f: L ⇒ Option[R]) = f(extractions) match { + case Some(valuesR) ⇒ Matched(pathRest, valuesR) + case None ⇒ Unmatched + } + def andThen[R: Tuple](f: (Path, L) ⇒ Matching[R]) = f(pathRest, extractions) + def orElse[R >: L](other: ⇒ Matching[R]) = this + } + object Matched { val Empty = Matched(Path.Empty, ()) } + case object Unmatched extends Matching[Nothing] { + def map[R: Tuple](f: Nothing ⇒ R) = this + def flatMap[R: Tuple](f: Nothing ⇒ Option[R]) = this + def andThen[R: Tuple](f: (Path, Nothing) ⇒ Matching[R]) = this + def orElse[R](other: ⇒ Matching[R]) = other + } + + /** + * Creates a PathMatcher that always matches, consumes nothing and extracts the given Tuple of values. + */ + def provide[L: Tuple](extractions: L): PathMatcher[L] = + new PathMatcher[L] { + def apply(path: Path) = Matched(path, extractions)(ev) + } + + /** + * Creates a PathMatcher that matches and consumes the given path prefix and extracts the given list of extractions. + * If the given prefix is empty the returned PathMatcher matches always and consumes nothing. + */ + def apply[L: Tuple](prefix: Path, extractions: L): PathMatcher[L] = + if (prefix.isEmpty) provide(extractions) + else new PathMatcher[L] { + def apply(path: Path) = + if (path startsWith prefix) Matched(path dropChars prefix.charCount, extractions)(ev) + else Unmatched + } + + def apply[L](magnet: PathMatcher[L]): PathMatcher[L] = magnet + + implicit class PathMatcher1Ops[T](matcher: PathMatcher1[T]) { + def map[R](f: T ⇒ R): PathMatcher1[R] = matcher.tmap { case Tuple1(e) ⇒ Tuple1(f(e)) } + def flatMap[R](f: T ⇒ Option[R]): PathMatcher1[R] = + matcher.tflatMap { case Tuple1(e) ⇒ f(e).map(x ⇒ Tuple1(x)) } + } + + implicit class EnhancedPathMatcher[L](underlying: PathMatcher[L]) { + def ?(implicit lift: PathMatcher.Lift[L, Option]): PathMatcher[lift.Out] = + new PathMatcher[lift.Out]()(lift.OutIsTuple) { + def apply(path: Path) = underlying(path) match { + case Matched(rest, extractions) ⇒ Matched(rest, lift(extractions)) + case Unmatched ⇒ Matched(path, lift()) + } + } + } + + sealed trait Lift[L, M[+_]] { + type Out + def OutIsTuple: Tuple[Out] + def apply(): Out + def apply(value: L): Out + def apply(value: L, more: Out): Out + } + object Lift extends LowLevelLiftImplicits { + trait MOps[M[+_]] { + def apply(): M[Nothing] + def apply[T](value: T): M[T] + def apply[T](value: T, more: M[T]): M[T] + } + object MOps { + implicit object OptionMOps extends MOps[Option] { + def apply(): Option[Nothing] = None + def apply[T](value: T): Option[T] = Some(value) + def apply[T](value: T, more: Option[T]): Option[T] = Some(value) + } + implicit object ListMOps extends MOps[List] { + def apply(): List[Nothing] = Nil + def apply[T](value: T): List[T] = value :: Nil + def apply[T](value: T, more: List[T]): List[T] = value :: more + } + } + implicit def liftUnit[M[+_]] = new Lift[Unit, M] { + type Out = Unit + def OutIsTuple = implicitly[Tuple[Out]] + def apply() = () + def apply(value: Unit) = value + def apply(value: Unit, more: Out) = value + } + implicit def liftSingleElement[A, M[+_]](implicit mops: MOps[M]) = new Lift[Tuple1[A], M] { + type Out = Tuple1[M[A]] + def OutIsTuple = implicitly[Tuple[Out]] + def apply() = Tuple1(mops()) + def apply(value: Tuple1[A]) = Tuple1(mops(value._1)) + def apply(value: Tuple1[A], more: Out) = Tuple1(mops(value._1, more._1)) + } + + } + + trait LowLevelLiftImplicits { + import Lift._ + implicit def default[T, M[+_]](implicit mops: MOps[M]) = new Lift[T, M] { + type Out = Tuple1[M[T]] + def OutIsTuple = implicitly[Tuple[Out]] + def apply() = Tuple1(mops()) + def apply(value: T) = Tuple1(mops(value)) + def apply(value: T, more: Out) = Tuple1(mops(value, more._1)) + } + } +} + +trait ImplicitPathMatcherConstruction { + import PathMatcher._ + + /** + * Creates a PathMatcher that consumes (a prefix of) the first path segment + * (if the path begins with a segment) and extracts a given value. + */ + implicit def stringExtractionPair2PathMatcher[T](tuple: (String, T)): PathMatcher1[T] = + PathMatcher(tuple._1 :: Path.Empty, Tuple1(tuple._2)) + + /** + * Creates a PathMatcher that consumes (a prefix of) the first path segment + * (if the path begins with a segment). + */ + implicit def segmentStringToPathMatcher(segment: String): PathMatcher0 = + PathMatcher(segment :: Path.Empty, ()) + + implicit def stringNameOptionReceptacle2PathMatcher(nr: NameOptionReceptacle[String]): PathMatcher0 = + PathMatcher(nr.name).? + + /** + * Creates a PathMatcher that consumes (a prefix of) the first path segment + * if the path begins with a segment (a prefix of) which matches the given regex. + * Extracts either the complete match (if the regex doesn't contain a capture group) or + * the capture group (if the regex contains exactly one). + * If the regex contains more than one capture group the method throws an IllegalArgumentException. + */ + implicit def regex2PathMatcher(regex: Regex): PathMatcher1[String] = regex.groupCount match { + case 0 ⇒ new PathMatcher1[String] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) ⇒ regex findPrefixOf segment match { + case Some(m) ⇒ Matched(segment.substring(m.length) :: tail, Tuple1(m)) + case None ⇒ Unmatched + } + case _ ⇒ Unmatched + } + } + case 1 ⇒ new PathMatcher1[String] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) ⇒ regex findPrefixMatchOf segment match { + case Some(m) ⇒ Matched(segment.substring(m.end) :: tail, Tuple1(m.group(1))) + case None ⇒ Unmatched + } + case _ ⇒ Unmatched + } + } + case _ ⇒ throw new IllegalArgumentException("Path regex '" + regex.pattern.pattern + + "' must not contain more than one capturing group") + } + /** + * Creates a PathMatcher from the given Map of path segments (prefixes) to extracted values. + * If the unmatched path starts with a segment having one of the maps keys as a prefix + * the matcher consumes this path segment (prefix) and extracts the corresponding map value. + */ + implicit def valueMap2PathMatcher[T](valueMap: Map[String, T]): PathMatcher1[T] = + if (valueMap.isEmpty) PathMatchers.nothingMatcher + else valueMap.map { case (prefix, value) ⇒ stringExtractionPair2PathMatcher(prefix, value) }.reduceLeft(_ | _) +} + +trait PathMatchers { + import PathMatcher._ + + /** + * Converts a path string containing slashes into a PathMatcher that interprets slashes as + * path segment separators. + */ + def separateOnSlashes(string: String): PathMatcher0 = { + @tailrec def split(ix: Int = 0, matcher: PathMatcher0 = null): PathMatcher0 = { + val nextIx = string.indexOf('/', ix) + def append(m: PathMatcher0) = if (matcher eq null) m else matcher / m + if (nextIx < 0) append(string.substring(ix)) + else split(nextIx + 1, append(string.substring(ix, nextIx))) + } + split() + } + + /** + * A PathMatcher that matches a single slash character ('/'). + */ + object Slash extends PathMatcher0 { + def apply(path: Path) = path match { + case Path.Slash(tail) ⇒ Matched(tail, ()) + case _ ⇒ Unmatched + } + } + + /** + * A PathMatcher that matches the very end of the requests URI path. + */ + object PathEnd extends PathMatcher0 { + def apply(path: Path) = path match { + case Path.Empty ⇒ Matched.Empty + case _ ⇒ Unmatched + } + } + + /** + * A PathMatcher that matches and extracts the complete remaining, + * unmatched part of the request's URI path as an (encoded!) String. + * If you need access to the remaining unencoded elements of the path + * use the `RestPath` matcher! + */ + object Rest extends PathMatcher1[String] { + def apply(path: Path) = Matched(Path.Empty, Tuple1(path.toString)) + } + + /** + * A PathMatcher that matches and extracts the complete remaining, + * unmatched part of the request's URI path. + */ + object RestPath extends PathMatcher1[Path] { + def apply(path: Path) = Matched(Path.Empty, Tuple1(path)) + } + + /** + * A PathMatcher that efficiently matches a number of digits and extracts their (non-negative) Int value. + * The matcher will not match 0 digits or a sequence of digits that would represent an Int value larger + * than Int.MaxValue. + */ + object IntNumber extends NumberMatcher[Int](Int.MaxValue, 10) { + def fromChar(c: Char) = fromDecimalChar(c) + } + + /** + * A PathMatcher that efficiently matches a number of digits and extracts their (non-negative) Long value. + * The matcher will not match 0 digits or a sequence of digits that would represent an Long value larger + * than Long.MaxValue. + */ + object LongNumber extends NumberMatcher[Long](Long.MaxValue, 10) { + def fromChar(c: Char) = fromDecimalChar(c) + } + + /** + * A PathMatcher that efficiently matches a number of hex-digits and extracts their (non-negative) Int value. + * The matcher will not match 0 digits or a sequence of digits that would represent an Int value larger + * than Int.MaxValue. + */ + object HexIntNumber extends NumberMatcher[Int](Int.MaxValue, 16) { + def fromChar(c: Char) = fromHexChar(c) + } + + /** + * A PathMatcher that efficiently matches a number of hex-digits and extracts their (non-negative) Long value. + * The matcher will not match 0 digits or a sequence of digits that would represent an Long value larger + * than Long.MaxValue. + */ + object HexLongNumber extends NumberMatcher[Long](Long.MaxValue, 16) { + def fromChar(c: Char) = fromHexChar(c) + } + + // common implementation of Number matchers + abstract class NumberMatcher[@specialized(Int, Long) T](max: T, base: T)(implicit x: Integral[T]) + extends PathMatcher1[T] { + + import x._ // import implicit conversions for numeric operators + val minusOne = x.zero - x.one + val maxDivBase = max / base + + def apply(path: Path) = path match { + case Path.Segment(segment, tail) ⇒ + @tailrec def digits(ix: Int = 0, value: T = minusOne): Matching[Tuple1[T]] = { + val a = if (ix < segment.length) fromChar(segment charAt ix) else minusOne + if (a == minusOne) { + if (value == minusOne) Unmatched + else Matched(if (ix < segment.length) segment.substring(ix) :: tail else tail, Tuple1(value)) + } else { + if (value == minusOne) digits(ix + 1, a) + else if (value <= maxDivBase && value * base <= max - a) // protect from overflow + digits(ix + 1, value * base + a) + else Unmatched + } + } + digits() + + case _ ⇒ Unmatched + } + + def fromChar(c: Char): T + + def fromDecimalChar(c: Char): T = if ('0' <= c && c <= '9') x.fromInt(c - '0') else minusOne + + def fromHexChar(c: Char): T = + if ('0' <= c && c <= '9') x.fromInt(c - '0') else { + val cn = c | 0x20 // normalize to lowercase + if ('a' <= cn && cn <= 'f') x.fromInt(cn - 'a' + 10) else minusOne + } + } + + /** + * A PathMatcher that matches and extracts a Double value. The matched string representation is the pure decimal, + * optionally signed form of a double value, i.e. without exponent. + */ + val DoubleNumber: PathMatcher1[Double] = + PathMatcher("""[+-]?\d*\.?\d*""".r) flatMap { string ⇒ + try Some(java.lang.Double.parseDouble(string)) + catch { case _: NumberFormatException ⇒ None } + } + + /** + * A PathMatcher that matches and extracts a java.util.UUID instance. + */ + val JavaUUID: PathMatcher1[UUID] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string ⇒ + try Some(UUID.fromString(string)) + catch { case _: IllegalArgumentException ⇒ None } + } + + /** + * A PathMatcher that always matches, doesn't consume anything and extracts nothing. + * Serves mainly as a neutral element in PathMatcher composition. + */ + val Neutral: PathMatcher0 = PathMatcher.provide(()) + + /** + * A PathMatcher that matches if the unmatched path starts with a path segment. + * If so the path segment is extracted as a String. + */ + object Segment extends PathMatcher1[String] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) ⇒ Matched(tail, Tuple1(segment)) + case _ ⇒ Unmatched + } + } + + /** + * A PathMatcher that matches up to 128 remaining segments as a List[String]. + * This can also be no segments resulting in the empty list. + * If the path has a trailing slash this slash will *not* be matched. + */ + val Segments: PathMatcher1[List[String]] = Segments(min = 0, max = 128) + + /** + * A PathMatcher that matches the given number of path segments (separated by slashes) as a List[String]. + * If there are more than ``count`` segments present the remaining ones will be left unmatched. + * If the path has a trailing slash this slash will *not* be matched. + */ + def Segments(count: Int): PathMatcher1[List[String]] = Segment.repeat(count, separator = Slash) + + /** + * A PathMatcher that matches between ``min`` and ``max`` (both inclusively) path segments (separated by slashes) + * as a List[String]. If there are more than ``count`` segments present the remaining ones will be left unmatched. + * If the path has a trailing slash this slash will *not* be matched. + */ + def Segments(min: Int, max: Int): PathMatcher1[List[String]] = Segment.repeat(min, max, separator = Slash) + + /** + * A PathMatcher that never matches anything. + */ + def nothingMatcher[L: Tuple]: PathMatcher[L] = + new PathMatcher[L] { + def apply(p: Path) = Unmatched + } +} + +object PathMatchers extends PathMatchers diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala new file mode 100644 index 0000000000..f3e8f46cc3 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala @@ -0,0 +1,204 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.collection.immutable +import akka.http.scaladsl.model._ +import headers._ + +/** + * A rejection encapsulates a specific reason why a Route was not able to handle a request. Rejections are gathered + * up over the course of a Route evaluation and finally converted to [[spray.http.HttpResponse]]s by the + * `handleRejections` directive, if there was no way for the request to be completed. + */ +trait Rejection + +/** + * Rejection created by method filters. + * Signals that the request was rejected because the HTTP method is unsupported. + */ +case class MethodRejection(supported: HttpMethod) extends Rejection + +/** + * Rejection created by scheme filters. + * Signals that the request was rejected because the Uri scheme is unsupported. + */ +case class SchemeRejection(supported: String) extends Rejection + +/** + * Rejection created by parameter filters. + * Signals that the request was rejected because a query parameter was not found. + */ +case class MissingQueryParamRejection(parameterName: String) extends Rejection + +/** + * Rejection created by parameter filters. + * Signals that the request was rejected because a query parameter could not be interpreted. + */ +case class MalformedQueryParamRejection(parameterName: String, errorMsg: String, + cause: Option[Throwable] = None) extends Rejection + +/** + * Rejection created by form field filters. + * Signals that the request was rejected because a form field was not found. + */ +case class MissingFormFieldRejection(fieldName: String) extends Rejection + +/** + * Rejection created by form field filters. + * Signals that the request was rejected because a form field could not be interpreted. + */ +case class MalformedFormFieldRejection(fieldName: String, errorMsg: String, + cause: Option[Throwable] = None) extends Rejection + +/** + * Rejection created by header directives. + * Signals that the request was rejected because a required header could not be found. + */ +case class MissingHeaderRejection(headerName: String) extends Rejection + +/** + * Rejection created by header directives. + * Signals that the request was rejected because a header value is malformed. + */ +case class MalformedHeaderRejection(headerName: String, errorMsg: String, + cause: Option[Throwable] = None) extends Rejection + +/** + * Rejection created by unmarshallers. + * Signals that the request was rejected because the requests content-type is unsupported. + */ +case class UnsupportedRequestContentTypeRejection(supported: Set[ContentTypeRange]) extends Rejection + +/** + * Rejection created by decoding filters. + * Signals that the request was rejected because the requests content encoding is unsupported. + */ +case class UnsupportedRequestEncodingRejection(supported: HttpEncoding) extends Rejection + +/** + * Rejection created by range directives. + * Signals that the request was rejected because the requests contains only unsatisfiable ByteRanges. + * The actualEntityLength gives the client a hint to create satisfiable ByteRanges. + */ +case class UnsatisfiableRangeRejection(unsatisfiableRanges: Seq[ByteRange], actualEntityLength: Long) extends Rejection + +/** + * Rejection created by range directives. + * Signals that the request contains too many ranges. An irregular high number of ranges + * indicates a broken client or a denial of service attack. + */ +case class TooManyRangesRejection(maxRanges: Int) extends Rejection + +/** + * Rejection created by unmarshallers. + * Signals that the request was rejected because unmarshalling failed with an error that wasn't + * an `IllegalArgumentException`. Usually that means that the request content was not of the expected format. + * Note that semantic issues with the request content (e.g. because some parameter was out of range) + * will usually trigger a `ValidationRejection` instead. + */ +case class MalformedRequestContentRejection(message: String, cause: Option[Throwable] = None) extends Rejection + +/** + * Rejection created by unmarshallers. + * Signals that the request was rejected because an message body entity was expected but not supplied. + */ +case object RequestEntityExpectedRejection extends Rejection + +/** + * Rejection created by marshallers. + * Signals that the request was rejected because the service is not capable of producing a response entity whose + * content type is accepted by the client + */ +case class UnacceptedResponseContentTypeRejection(supported: Set[ContentType]) extends Rejection + +/** + * Rejection created by encoding filters. + * Signals that the request was rejected because the service is not capable of producing a response entity whose + * content encoding is accepted by the client + */ +case class UnacceptedResponseEncodingRejection(supported: Set[HttpEncoding]) extends Rejection +object UnacceptedResponseEncodingRejection { + def apply(supported: HttpEncoding): UnacceptedResponseEncodingRejection = UnacceptedResponseEncodingRejection(Set(supported)) +} + +/** + * Rejection created by an [[akka.http.scaladsl.server.authentication.HttpAuthenticator]]. + * Signals that the request was rejected because the user could not be authenticated. The reason for the rejection is + * specified in the cause. + */ +case class AuthenticationFailedRejection(cause: AuthenticationFailedRejection.Cause, + challenge: HttpChallenge) extends Rejection + +object AuthenticationFailedRejection { + /** + * Signals the cause of the failed authentication. + */ + sealed trait Cause + + /** + * Signals the cause of the rejecting was that the user could not be authenticated, because the `WWW-Authenticate` + * header was not supplied. + */ + case object CredentialsMissing extends Cause + + /** + * Signals the cause of the rejecting was that the user could not be authenticated, because the supplied credentials + * are invalid. + */ + case object CredentialsRejected extends Cause +} + +/** + * Rejection created by the 'authorize' directive. + * Signals that the request was rejected because the user is not authorized. + */ +case object AuthorizationFailedRejection extends Rejection + +/** + * Rejection created by the `cookie` directive. + * Signals that the request was rejected because a cookie was not found. + */ +case class MissingCookieRejection(cookieName: String) extends Rejection + +/** + * Rejection created when a websocket request was expected but none was found. + */ +case object ExpectedWebsocketRequestRejection extends Rejection + +/** + * Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions` + * thrown by domain model constructors (e.g. via `require`). + * It signals that an expected value was semantically invalid. + */ +case class ValidationRejection(message: String, cause: Option[Throwable] = None) extends Rejection + +/** + * A special Rejection that serves as a container for a transformation function on rejections. + * It is used by some directives to "cancel" rejections that are added by later directives of a similar type. + * + * Consider this route structure for example: + * + * put { reject(ValidationRejection("no") } ~ get { ... } + * + * If this structure is applied to a PUT request the list of rejections coming back contains three elements: + * + * 1. A ValidationRejection + * 2. A MethodRejection + * 3. A TransformationRejection holding a function filtering out the MethodRejection + * + * so that in the end the RejectionHandler will only see one rejection (the ValidationRejection), because the + * MethodRejection added by the ``get`` directive is cancelled by the ``put`` directive (since the HTTP method + * did indeed match eventually). + */ +case class TransformationRejection(transform: immutable.Seq[Rejection] ⇒ immutable.Seq[Rejection]) extends Rejection + +/** + * A Throwable wrapping a Rejection. + * Can be used for marshalling `Future[T]` or `Try[T]` instances, whose failure side is supposed to trigger a route + * rejection rather than an Exception that is handled by the nearest ExceptionHandler. + * (Custom marshallers can of course use it as well.) + */ +case class RejectionError(rejection: Rejection) extends RuntimeException diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala new file mode 100644 index 0000000000..2019949b7b --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.annotation.tailrec +import scala.reflect.ClassTag +import scala.collection.immutable +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model._ +import StatusCodes._ +import AuthenticationFailedRejection._ + +trait RejectionHandler extends (immutable.Seq[Rejection] ⇒ Option[Route]) { self ⇒ + import RejectionHandler._ + + /** + * Creates a new [[RejectionHandler]] which uses the given one as fallback for this one. + */ + def withFallback(that: RejectionHandler): RejectionHandler = + (this, that) match { + case (a: BuiltRejectionHandler, _) if a.isDefault ⇒ this // the default handler already handles everything + case (a: BuiltRejectionHandler, b: BuiltRejectionHandler) ⇒ + new BuiltRejectionHandler(a.cases ++ b.cases, a.notFound orElse b.notFound, b.isDefault) + case _ ⇒ new RejectionHandler { + def apply(rejections: immutable.Seq[Rejection]): Option[Route] = + self(rejections) orElse that(rejections) + } + } + + /** + * "Seals" this handler by attaching a default handler as fallback if necessary. + */ + def seal: RejectionHandler = + this match { + case x: BuiltRejectionHandler if x.isDefault ⇒ x + case _ ⇒ withFallback(default) + } +} + +object RejectionHandler { + + /** + * Creates a new [[RejectionHandler]] builder. + */ + def newBuilder(): Builder = new Builder(isDefault = false) + + final class Builder private[RejectionHandler] (isDefault: Boolean) { + private[this] val cases = new immutable.VectorBuilder[Handler] + private[this] var notFound: Option[Route] = None + + /** + * Handles a single [[Rejection]] with the given partial function. + */ + def handle(pf: PartialFunction[Rejection, Route]): this.type = { + cases += CaseHandler(pf) + this + } + + /** + * Handles several Rejections of the same type at the same time. + * The seq passed to the given function is guaranteed to be non-empty. + */ + def handleAll[T <: Rejection: ClassTag](f: immutable.Seq[T] ⇒ Route): this.type = { + val runtimeClass = implicitly[ClassTag[T]].runtimeClass + cases += TypeHandler[T](runtimeClass, f) + this + } + + /** + * Handles the special "not found" case using the given [[Route]]. + */ + def handleNotFound(route: Route): this.type = { + notFound = Some(route) + this + } + + def result(): RejectionHandler = + new BuiltRejectionHandler(cases.result(), notFound, isDefault) + } + + private sealed abstract class Handler + private final case class CaseHandler(pf: PartialFunction[Rejection, Route]) extends Handler + private final case class TypeHandler[T <: Rejection]( + runtimeClass: Class[_], f: immutable.Seq[T] ⇒ Route) extends Handler with PartialFunction[Rejection, T] { + def isDefinedAt(rejection: Rejection) = runtimeClass isInstance rejection + def apply(rejection: Rejection) = rejection.asInstanceOf[T] + } + + private class BuiltRejectionHandler(val cases: Vector[Handler], + val notFound: Option[Route], + val isDefault: Boolean) extends RejectionHandler { + def apply(rejections: immutable.Seq[Rejection]): Option[Route] = + if (rejections.nonEmpty) { + @tailrec def rec(ix: Int): Option[Route] = + if (ix < cases.length) { + cases(ix) match { + case CaseHandler(pf) ⇒ + val route = rejections collectFirst pf + if (route.isEmpty) rec(ix + 1) else route + case x @ TypeHandler(_, f) ⇒ + val rejs = rejections collect x + if (rejs.isEmpty) rec(ix + 1) else Some(f(rejs)) + } + } else None + rec(0) + } else notFound + } + + import Directives._ + + /** + * Creates a new default [[RejectionHandler]] instance. + */ + def default = + newBuilder() + .handleAll[SchemeRejection] { rejections ⇒ + val schemes = rejections.map(_.supported).mkString(", ") + complete(BadRequest, "Uri scheme not allowed, supported schemes: " + schemes) + } + .handleAll[MethodRejection] { rejections ⇒ + val (methods, names) = rejections.map(r ⇒ r.supported -> r.supported.name).unzip + complete(MethodNotAllowed, List(Allow(methods)), "HTTP method not allowed, supported methods: " + names.mkString(", ")) + } + .handle { + case AuthorizationFailedRejection ⇒ + complete(Forbidden, "The supplied authentication is not authorized to access this resource") + } + .handle { + case MalformedFormFieldRejection(name, msg, _) ⇒ + complete(BadRequest, "The form field '" + name + "' was malformed:\n" + msg) + } + .handle { + case MalformedHeaderRejection(headerName, msg, _) ⇒ + complete(BadRequest, s"The value of HTTP header '$headerName' was malformed:\n" + msg) + } + .handle { + case MalformedQueryParamRejection(name, msg, _) ⇒ + complete(BadRequest, "The query parameter '" + name + "' was malformed:\n" + msg) + } + .handle { + case MalformedRequestContentRejection(msg, _) ⇒ + complete(BadRequest, "The request content was malformed:\n" + msg) + } + .handle { + case MissingCookieRejection(cookieName) ⇒ + complete(BadRequest, "Request is missing required cookie '" + cookieName + '\'') + } + .handle { + case MissingFormFieldRejection(fieldName) ⇒ + complete(BadRequest, "Request is missing required form field '" + fieldName + '\'') + } + .handle { + case MissingHeaderRejection(headerName) ⇒ + complete(BadRequest, "Request is missing required HTTP header '" + headerName + '\'') + } + .handle { + case MissingQueryParamRejection(paramName) ⇒ + complete(NotFound, "Request is missing required query parameter '" + paramName + '\'') + } + .handle { + case RequestEntityExpectedRejection ⇒ + complete(BadRequest, "Request entity expected but not supplied") + } + .handle { + case TooManyRangesRejection(_) ⇒ + complete(RequestedRangeNotSatisfiable, "Request contains too many ranges.") + } + .handle { + case UnsatisfiableRangeRejection(unsatisfiableRanges, actualEntityLength) ⇒ + complete(RequestedRangeNotSatisfiable, List(`Content-Range`(ContentRange.Unsatisfiable(actualEntityLength))), + unsatisfiableRanges.mkString("None of the following requested Ranges were satisfiable:\n", "\n", "")) + } + .handleAll[AuthenticationFailedRejection] { rejections ⇒ + val rejectionMessage = rejections.head.cause match { + case CredentialsMissing ⇒ "The resource requires authentication, which was not supplied with the request" + case CredentialsRejected ⇒ "The supplied authentication is invalid" + } + // Multiple challenges per WWW-Authenticate header are allowed per spec, + // however, it seems many browsers will ignore all challenges but the first. + // Therefore, multiple WWW-Authenticate headers are rendered, instead. + // + // See https://code.google.com/p/chromium/issues/detail?id=103220 + // and https://bugzilla.mozilla.org/show_bug.cgi?id=669675 + val authenticateHeaders = rejections.map(r ⇒ `WWW-Authenticate`(r.challenge)) + complete(Unauthorized, authenticateHeaders, rejectionMessage) + } + .handleAll[UnacceptedResponseContentTypeRejection] { rejections ⇒ + val supported = rejections.flatMap(_.supported) + complete(NotAcceptable, "Resource representation is only available with these Content-Types:\n" + + supported.map(_.value).mkString("\n")) + } + .handleAll[UnacceptedResponseEncodingRejection] { rejections ⇒ + val supported = rejections.flatMap(_.supported) + complete(NotAcceptable, "Resource representation is only available with these Content-Encodings:\n" + + supported.map(_.value).mkString("\n")) + } + .handleAll[UnsupportedRequestContentTypeRejection] { rejections ⇒ + val supported = rejections.flatMap(_.supported).mkString(" or ") + complete(UnsupportedMediaType, "The request's Content-Type is not supported. Expected:\n" + supported) + } + .handleAll[UnsupportedRequestEncodingRejection] { rejections ⇒ + val supported = rejections.map(_.supported.value).mkString(" or ") + complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported) + } + .handle { case ExpectedWebsocketRequestRejection ⇒ complete(BadRequest, "Expected Websocket Upgrade request") } + .handle { case ValidationRejection(msg, _) ⇒ complete(BadRequest, msg) } + .handle { case x ⇒ sys.error("Unhandled rejection: " + x) } + .handleNotFound { complete(NotFound, "The requested resource could not be found.") } + .result() + + /** + * Filters out all TransformationRejections from the given sequence and applies them (in order) to the + * remaining rejections. + */ + def applyTransformations(rejections: immutable.Seq[Rejection]): immutable.Seq[Rejection] = { + val (transformations, rest) = rejections.partition(_.isInstanceOf[TransformationRejection]) + (rest.distinct /: transformations.asInstanceOf[Seq[TransformationRejection]]) { + case (remaining, transformation) ⇒ transformation.transform(remaining) + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContext.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContext.scala new file mode 100644 index 0000000000..2f77d5dd54 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContext.scala @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.concurrent.{ Future, ExecutionContext } +import akka.stream.FlowMaterializer +import akka.event.LoggingAdapter +import akka.http.scaladsl.marshalling.ToResponseMarshallable +import akka.http.scaladsl.model._ + +/** + * Immutable object encapsulating the context of an [[akka.http.scaladsl.model.HttpRequest]] + * as it flows through a akka-http Route structure. + */ +trait RequestContext { + + /** The request this context represents. Modelled as a ``val`` so as to enable an ``import ctx.request._``. */ + val request: HttpRequest + + /** The unmatched path of this context. Modelled as a ``val`` so as to enable an ``import ctx.unmatchedPath._``. */ + val unmatchedPath: Uri.Path + + /** + * The default ExecutionContext to be used for scheduling asynchronous logic related to this request. + */ + implicit def executionContext: ExecutionContext + + /** + * The default FlowMaterializer. + */ + implicit def flowMaterializer: FlowMaterializer + + /** + * The default LoggingAdapter to be used for logging messages related to this request. + */ + def log: LoggingAdapter + + /** + * The default RoutingSettings to be used for configuring directives. + */ + def settings: RoutingSettings + + /** + * Returns a copy of this context with the given fields updated. + */ + def reconfigure( + executionContext: ExecutionContext = executionContext, + flowMaterializer: FlowMaterializer = flowMaterializer, + log: LoggingAdapter = log, + settings: RoutingSettings = settings): RequestContext + + /** + * Completes the request with the given ToResponseMarshallable. + */ + def complete(obj: ToResponseMarshallable): Future[RouteResult] + + /** + * Rejects the request with the given rejections. + */ + def reject(rejections: Rejection*): Future[RouteResult] + + /** + * Bubbles the given error up the response chain where it is dealt with by the closest `handleExceptions` + * directive and its ``ExceptionHandler``, unless the error is a ``RejectionError``. In this case the + * wrapped rejection is unpacked and "executed". + */ + def fail(error: Throwable): Future[RouteResult] + + /** + * Returns a copy of this context with the new HttpRequest. + */ + def withRequest(req: HttpRequest): RequestContext + + /** + * Returns a copy of this context with the new HttpRequest. + */ + def withExecutionContext(ec: ExecutionContext): RequestContext + + /** + * Returns a copy of this context with the new HttpRequest. + */ + def withFlowMaterializer(materializer: FlowMaterializer): RequestContext + + /** + * Returns a copy of this context with the new LoggingAdapter. + */ + def withLog(log: LoggingAdapter): RequestContext + + /** + * Returns a copy of this context with the new RoutingSettings. + */ + def withSettings(settings: RoutingSettings): RequestContext + + /** + * Returns a copy of this context with the HttpRequest transformed by the given function. + */ + def mapRequest(f: HttpRequest ⇒ HttpRequest): RequestContext + + /** + * Returns a copy of this context with the unmatched path updated to the given one. + */ + def withUnmatchedPath(path: Uri.Path): RequestContext + + /** + * Returns a copy of this context with the unmatchedPath transformed by the given function. + */ + def mapUnmatchedPath(f: Uri.Path ⇒ Uri.Path): RequestContext + + /** + * Removes a potentially existing Accept header from the request headers. + */ + def withAcceptAll: RequestContext +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContextImpl.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContextImpl.scala new file mode 100644 index 0000000000..44027d57b2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RequestContextImpl.scala @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.concurrent.{ Future, ExecutionContext } +import akka.stream.FlowMaterializer +import akka.event.LoggingAdapter +import akka.http.scaladsl.marshalling.{ Marshal, ToResponseMarshallable } +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ + +/** + * INTERNAL API + */ +private[http] class RequestContextImpl( + val request: HttpRequest, + val unmatchedPath: Uri.Path, + val executionContext: ExecutionContext, + val flowMaterializer: FlowMaterializer, + val log: LoggingAdapter, + val settings: RoutingSettings) extends RequestContext { + + def this(request: HttpRequest, log: LoggingAdapter, settings: RoutingSettings)(implicit ec: ExecutionContext, materializer: FlowMaterializer) = + this(request, request.uri.path, ec, materializer, log, settings) + + def reconfigure(executionContext: ExecutionContext, flowMaterializer: FlowMaterializer, log: LoggingAdapter, settings: RoutingSettings): RequestContext = + copy(executionContext = executionContext, flowMaterializer = flowMaterializer, log = log, settings = settings) + + override def complete(trm: ToResponseMarshallable): Future[RouteResult] = + trm(request)(executionContext) + .fast.map(res ⇒ RouteResult.Complete(res))(executionContext) + .fast.recover { + case Marshal.UnacceptableResponseContentTypeException(supported) ⇒ + RouteResult.Rejected(UnacceptedResponseContentTypeRejection(supported) :: Nil) + case RejectionError(rej) ⇒ RouteResult.Rejected(rej :: Nil) + }(executionContext) + + override def reject(rejections: Rejection*): Future[RouteResult] = + FastFuture.successful(RouteResult.Rejected(rejections.toList)) + + override def fail(error: Throwable): Future[RouteResult] = + FastFuture.failed(error) + + override def withRequest(request: HttpRequest): RequestContext = + if (request != this.request) copy(request = request) else this + + override def withExecutionContext(executionContext: ExecutionContext): RequestContext = + if (executionContext != this.executionContext) copy(executionContext = executionContext) else this + + override def withFlowMaterializer(flowMaterializer: FlowMaterializer): RequestContext = + if (flowMaterializer != this.flowMaterializer) copy(flowMaterializer = flowMaterializer) else this + + override def withLog(log: LoggingAdapter): RequestContext = + if (log != this.log) copy(log = log) else this + + override def withSettings(settings: RoutingSettings): RequestContext = + if (settings != this.settings) copy(settings = settings) else this + + override def mapRequest(f: HttpRequest ⇒ HttpRequest): RequestContext = + copy(request = f(request)) + + override def withUnmatchedPath(path: Uri.Path): RequestContext = + if (path != unmatchedPath) copy(unmatchedPath = path) else this + + override def mapUnmatchedPath(f: Uri.Path ⇒ Uri.Path): RequestContext = + copy(unmatchedPath = f(unmatchedPath)) + + override def withAcceptAll: RequestContext = request.header[headers.Accept] match { + case Some(accept @ headers.Accept(ranges)) if !accept.acceptsAll ⇒ + mapRequest(_.mapHeaders(_.map { + case `accept` ⇒ + val acceptAll = + if (ranges.exists(_.isWildcard)) ranges.map(r ⇒ if (r.isWildcard) MediaRanges.`*/*;q=MIN` else r) + else ranges :+ MediaRanges.`*/*;q=MIN` + accept.copy(mediaRanges = acceptAll) + case x ⇒ x + })) + case _ ⇒ this + } + + private def copy(request: HttpRequest = request, + unmatchedPath: Uri.Path = unmatchedPath, + executionContext: ExecutionContext = executionContext, + flowMaterializer: FlowMaterializer = flowMaterializer, + log: LoggingAdapter = log, + settings: RoutingSettings = settings) = + new RequestContextImpl(request, unmatchedPath, executionContext, flowMaterializer, log, settings) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Route.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Route.scala new file mode 100644 index 0000000000..a2642ec9f4 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Route.scala @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.concurrent.Future +import akka.stream.scaladsl.Flow +import akka.http.scaladsl.model.{ HttpRequest, HttpResponse } +import akka.http.scaladsl.util.FastFuture._ + +object Route { + + /** + * Helper for constructing a Route from a function literal. + */ + def apply(f: Route): Route = f + + /** + * "Seals" a route by wrapping it with exception handling and rejection conversion. + */ + def seal(route: Route)(implicit setup: RoutingSetup): Route = { + import directives.ExecutionDirectives._ + import setup._ + handleExceptions(exceptionHandler.seal(setup.settings)) { + handleRejections(rejectionHandler.seal) { + route + } + } + } + + /** + * Turns a `Route` into a server flow. + */ + def handlerFlow(route: Route)(implicit setup: RoutingSetup): Flow[HttpRequest, HttpResponse, Unit] = + Flow[HttpRequest].mapAsync(1)(asyncHandler(route)) + + /** + * Turns a `Route` into an async handler function. + */ + def asyncHandler(route: Route)(implicit setup: RoutingSetup): HttpRequest ⇒ Future[HttpResponse] = { + import setup._ + val sealedRoute = seal(route) + request ⇒ + sealedRoute(new RequestContextImpl(request, routingLog.requestLog(request), setup.settings)).fast.map { + case RouteResult.Complete(response) ⇒ response + case RouteResult.Rejected(rejected) ⇒ throw new IllegalStateException(s"Unhandled rejections '$rejected', unsealed RejectionHandler?!") + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RouteConcatenation.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RouteConcatenation.scala new file mode 100644 index 0000000000..f8ece7deed --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RouteConcatenation.scala @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ + +trait RouteConcatenation { + + implicit def enhanceRouteWithConcatenation(route: Route) = + new RouteConcatenation.RouteWithConcatenation(route: Route) +} + +object RouteConcatenation extends RouteConcatenation { + + class RouteWithConcatenation(route: Route) { + /** + * Returns a Route that chains two Routes. If the first Route rejects the request the second route is given a + * chance to act upon the request. + */ + def ~(other: Route): Route = { ctx ⇒ + import ctx.executionContext + route(ctx).fast.flatMap { + case x: RouteResult.Complete ⇒ FastFuture.successful(x) + case RouteResult.Rejected(outerRejections) ⇒ + other(ctx).fast.map { + case x: RouteResult.Complete ⇒ x + case RouteResult.Rejected(innerRejections) ⇒ RouteResult.Rejected(outerRejections ++ innerRejections) + } + } + } + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RouteResult.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RouteResult.scala new file mode 100644 index 0000000000..832e9345f2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RouteResult.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.collection.immutable +import akka.stream.scaladsl.Flow +import akka.http.scaladsl.model.{ HttpRequest, HttpResponse } + +/** + * The result of handling a request. + * + * As a user you typically don't create RouteResult instances directly. + * Instead, use the methods on the [[RequestContext]] to achieve the desired effect. + */ +sealed trait RouteResult + +object RouteResult { + final case class Complete(response: HttpResponse) extends RouteResult + final case class Rejected(rejections: immutable.Seq[Rejection]) extends RouteResult + + implicit def route2HandlerFlow(route: Route)(implicit setup: RoutingSetup): Flow[HttpRequest, HttpResponse, Unit] = + Route.handlerFlow(route) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSettings.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSettings.scala new file mode 100644 index 0000000000..7fbd156524 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSettings.scala @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import com.typesafe.config.Config +import akka.actor.ActorRefFactory +import akka.http.impl.util._ + +case class RoutingSettings( + verboseErrorMessages: Boolean, + fileGetConditional: Boolean, + renderVanityFooter: Boolean, + rangeCountLimit: Int, + rangeCoalescingThreshold: Long, + decodeMaxBytesPerChunk: Int, + fileIODispatcher: String) + +object RoutingSettings extends SettingsCompanion[RoutingSettings]("akka.http.routing") { + def fromSubConfig(c: Config) = apply( + c getBoolean "verbose-error-messages", + c getBoolean "file-get-conditional", + c getBoolean "render-vanity-footer", + c getInt "range-count-limit", + c getBytes "range-coalescing-threshold", + c getIntBytes "decode-max-bytes-per-chunk", + c getString "file-io-dispatcher") + + implicit def default(implicit refFactory: ActorRefFactory) = + apply(actorSystem) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSetup.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSetup.scala new file mode 100644 index 0000000000..ba931ce550 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RoutingSetup.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import scala.concurrent.ExecutionContext +import akka.event.LoggingAdapter +import akka.actor.{ ActorSystem, ActorContext } +import akka.stream.FlowMaterializer +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.HttpRequest + +/** + * Provides a ``RoutingSetup`` for a given connection. + */ +trait RoutingSetupProvider { + def apply(connection: Http.IncomingConnection): RoutingSetup +} +object RoutingSetupProvider { + def apply(f: Http.IncomingConnection ⇒ RoutingSetup): RoutingSetupProvider = + new RoutingSetupProvider { + def apply(connection: Http.IncomingConnection) = f(connection) + } + + implicit def default(implicit setup: RoutingSetup) = RoutingSetupProvider(_ ⇒ setup) +} + +/** + * Provides all dependencies required for route execution. + */ +class RoutingSetup( + val settings: RoutingSettings, + val exceptionHandler: ExceptionHandler, + val rejectionHandler: RejectionHandler, + val executionContext: ExecutionContext, + val flowMaterializer: FlowMaterializer, + val routingLog: RoutingLog) { + + // enable `import setup._` to properly bring implicits in scope + implicit def executor: ExecutionContext = executionContext + implicit def materializer: FlowMaterializer = flowMaterializer +} + +object RoutingSetup { + implicit def apply(implicit routingSettings: RoutingSettings, + exceptionHandler: ExceptionHandler = null, + rejectionHandler: RejectionHandler = null, + executionContext: ExecutionContext = null, + flowMaterializer: FlowMaterializer, + routingLog: RoutingLog): RoutingSetup = + new RoutingSetup( + routingSettings, + if (exceptionHandler ne null) exceptionHandler else ExceptionHandler.default(routingSettings), + if (rejectionHandler ne null) rejectionHandler else RejectionHandler.default, + if (executionContext ne null) executionContext else flowMaterializer.executionContext, + flowMaterializer, + routingLog) +} + +trait RoutingLog { + def log: LoggingAdapter + def requestLog(request: HttpRequest): LoggingAdapter +} + +object RoutingLog extends LowerPriorityRoutingLogImplicits { + def apply(defaultLog: LoggingAdapter): RoutingLog = + new RoutingLog { + def log = defaultLog + def requestLog(request: HttpRequest) = defaultLog + } + + implicit def fromActorContext(implicit ac: ActorContext): RoutingLog = RoutingLog(ac.system.log) +} +sealed abstract class LowerPriorityRoutingLogImplicits { + implicit def fromActorSystem(implicit system: ActorSystem): RoutingLog = RoutingLog(system.log) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/StandardRoute.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/StandardRoute.scala new file mode 100644 index 0000000000..800b468bc5 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/StandardRoute.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server + +import akka.http.scaladsl.server.util.Tuple + +/** + * A Route that can be implicitly converted into a Directive (fitting any signature). + */ +abstract class StandardRoute extends Route { + def toDirective[L: Tuple]: Directive[L] = StandardRoute.toDirective(this) +} + +object StandardRoute { + def apply(route: Route): StandardRoute = route match { + case x: StandardRoute ⇒ x + case x ⇒ new StandardRoute { def apply(ctx: RequestContext) = x(ctx) } + } + + /** + * Converts the StandardRoute into a directive that never passes the request to its inner route + * (and always returns its underlying route). + */ + implicit def toDirective[L: Tuple](route: StandardRoute) = Directive[L] { _ ⇒ route } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/BasicDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/BasicDirectives.scala new file mode 100644 index 0000000000..ec8c58597b --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/BasicDirectives.scala @@ -0,0 +1,199 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.{ Future, ExecutionContext } +import scala.collection.immutable +import akka.event.LoggingAdapter +import akka.stream.FlowMaterializer +import akka.http.scaladsl.server.util.Tuple +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture._ + +trait BasicDirectives { + + def mapInnerRoute(f: Route ⇒ Route): Directive0 = + Directive { inner ⇒ f(inner(())) } + + def mapRequestContext(f: RequestContext ⇒ RequestContext): Directive0 = + mapInnerRoute { inner ⇒ ctx ⇒ inner(f(ctx)) } + + def mapRequest(f: HttpRequest ⇒ HttpRequest): Directive0 = + mapRequestContext(_ mapRequest f) + + def mapRouteResultFuture(f: Future[RouteResult] ⇒ Future[RouteResult]): Directive0 = + Directive { inner ⇒ ctx ⇒ f(inner(())(ctx)) } + + def mapRouteResult(f: RouteResult ⇒ RouteResult): Directive0 = + Directive { inner ⇒ ctx ⇒ inner(())(ctx).fast.map(f)(ctx.executionContext) } + + def mapRouteResultWith(f: RouteResult ⇒ Future[RouteResult]): Directive0 = + Directive { inner ⇒ ctx ⇒ inner(())(ctx).fast.flatMap(f)(ctx.executionContext) } + + def mapRouteResultPF(f: PartialFunction[RouteResult, RouteResult]): Directive0 = + mapRouteResult(f.applyOrElse(_, akka.http.impl.util.identityFunc[RouteResult])) + + def mapRouteResultWithPF(f: PartialFunction[RouteResult, Future[RouteResult]]): Directive0 = + mapRouteResultWith(f.applyOrElse(_, FastFuture.successful[RouteResult])) + + def recoverRejections(f: immutable.Seq[Rejection] ⇒ RouteResult): Directive0 = + mapRouteResultPF { case RouteResult.Rejected(rejections) ⇒ f(rejections) } + + def recoverRejectionsWith(f: immutable.Seq[Rejection] ⇒ Future[RouteResult]): Directive0 = + mapRouteResultWithPF { case RouteResult.Rejected(rejections) ⇒ f(rejections) } + + def mapRejections(f: immutable.Seq[Rejection] ⇒ immutable.Seq[Rejection]): Directive0 = + recoverRejections(rejections ⇒ RouteResult.Rejected(f(rejections))) + + def mapResponse(f: HttpResponse ⇒ HttpResponse): Directive0 = + mapRouteResultPF { case RouteResult.Complete(response) ⇒ RouteResult.Complete(f(response)) } + + def mapResponseEntity(f: ResponseEntity ⇒ ResponseEntity): Directive0 = + mapResponse(_ mapEntity f) + + def mapResponseHeaders(f: immutable.Seq[HttpHeader] ⇒ immutable.Seq[HttpHeader]): Directive0 = + mapResponse(_ mapHeaders f) + + /** + * A Directive0 that always passes the request on to its inner route + * (i.e. does nothing with the request or the response). + */ + def pass: Directive0 = Directive.Empty + + /** + * Injects the given value into a directive. + */ + def provide[T](value: T): Directive1[T] = tprovide(Tuple1(value)) + + /** + * Injects the given values into a directive. + */ + def tprovide[L: Tuple](values: L): Directive[L] = + Directive { _(values) } + + /** + * Extracts a single value using the given function. + */ + def extract[T](f: RequestContext ⇒ T): Directive1[T] = + textract(ctx ⇒ Tuple1(f(ctx))) + + /** + * Extracts a number of values using the given function. + */ + def textract[L: Tuple](f: RequestContext ⇒ L): Directive[L] = + Directive { inner ⇒ ctx ⇒ inner(f(ctx))(ctx) } + + /** + * Adds a TransformationRejection cancelling all rejections equal to the given one + * to the list of rejections potentially coming back from the inner route. + */ + def cancelRejection(rejection: Rejection): Directive0 = + cancelRejections(_ == rejection) + + /** + * Adds a TransformationRejection cancelling all rejections of one of the given classes + * to the list of rejections potentially coming back from the inner route. + */ + def cancelRejections(classes: Class[_]*): Directive0 = + cancelRejections(r ⇒ classes.exists(_ isInstance r)) + + /** + * Adds a TransformationRejection cancelling all rejections for which the given filter function returns true + * to the list of rejections potentially coming back from the inner route. + */ + def cancelRejections(cancelFilter: Rejection ⇒ Boolean): Directive0 = + mapRejections(_ :+ TransformationRejection(_ filterNot cancelFilter)) + + /** + * Transforms the unmatchedPath of the RequestContext using the given function. + */ + def mapUnmatchedPath(f: Uri.Path ⇒ Uri.Path): Directive0 = + mapRequestContext(_ mapUnmatchedPath f) + + /** + * Extracts the unmatched path from the RequestContext. + */ + def extractUnmatchedPath: Directive1[Uri.Path] = BasicDirectives._extractUnmatchedPath + + /** + * Extracts the complete request. + */ + def extractRequest: Directive1[HttpRequest] = BasicDirectives._extractRequest + + /** + * Extracts the complete request URI. + */ + def extractUri: Directive1[Uri] = BasicDirectives._extractUri + + /** + * Runs its inner route with the given alternative [[ExecutionContext]]. + */ + def withExecutionContext(ec: ExecutionContext): Directive0 = + mapRequestContext(_ withExecutionContext ec) + + /** + * Extracts the [[ExecutionContext]] from the [[RequestContext]]. + */ + def extractExecutionContext: Directive1[ExecutionContext] = BasicDirectives._extractExecutionContext + + /** + * Runs its inner route with the given alternative [[FlowMaterializer]]. + */ + def withFlowMaterializer(materializer: FlowMaterializer): Directive0 = + mapRequestContext(_ withFlowMaterializer materializer) + + /** + * Extracts the [[FlowMaterializer]] from the [[RequestContext]]. + */ + def extractFlowMaterializer: Directive1[FlowMaterializer] = BasicDirectives._extractFlowMaterializer + + /** + * Runs its inner route with the given alternative [[LoggingAdapter]]. + */ + def withLog(log: LoggingAdapter): Directive0 = + mapRequestContext(_ withLog log) + + /** + * Extracts the [[LoggingAdapter]] from the [[RequestContext]]. + */ + def extractLog: Directive1[LoggingAdapter] = + BasicDirectives._extractLog + + /** + * Runs its inner route with the given alternative [[RoutingSettings]]. + */ + def withSettings(settings: RoutingSettings): Directive0 = + mapRequestContext(_ withSettings settings) + + /** + * Runs the inner route with settings mapped by the given function. + */ + def mapSettings(f: RoutingSettings ⇒ RoutingSettings): Directive0 = + mapRequestContext(ctx ⇒ ctx.withSettings(f(ctx.settings))) + + /** + * Extracts the [[RoutingSettings]] from the [[RequestContext]]. + */ + def extractSettings: Directive1[RoutingSettings] = + BasicDirectives._extractSettings + + /** + * Extracts the [[RequestContext]] itself. + */ + def extractRequestContext: Directive1[RequestContext] = BasicDirectives._extractRequestContext +} + +object BasicDirectives extends BasicDirectives { + private val _extractUnmatchedPath: Directive1[Uri.Path] = extract(_.unmatchedPath) + private val _extractRequest: Directive1[HttpRequest] = extract(_.request) + private val _extractUri: Directive1[Uri] = extract(_.request.uri) + private val _extractExecutionContext: Directive1[ExecutionContext] = extract(_.executionContext) + private val _extractFlowMaterializer: Directive1[FlowMaterializer] = extract(_.flowMaterializer) + private val _extractLog: Directive1[LoggingAdapter] = extract(_.log) + private val _extractSettings: Directive1[RoutingSettings] = extract(_.settings) + private val _extractRequestContext: Directive1[RequestContext] = extract(akka.http.impl.util.identityFunc) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CacheConditionDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CacheConditionDirectives.scala new file mode 100644 index 0000000000..e9330b2df3 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CacheConditionDirectives.scala @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.DateTime +import headers._ +import HttpMethods._ +import StatusCodes._ +import EntityTag._ + +trait CacheConditionDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Wraps its inner route with support for Conditional Requests as defined + * by http://tools.ietf.org/html/rfc7232 + * + * In particular the algorithm defined by http://tools.ietf.org/html/rfc7232#section-6 + * is implemented by this directive. + * + * Note: if you want to combine this directive with `withRangeSupport(...)` you need to put + * it on the *outside* of the `withRangeSupport(...)` directive, i.e. `withRangeSupport(...)` + * must be on a deeper level in your route structure in order to function correctly. + */ + def conditional(eTag: EntityTag): Directive0 = conditional(Some(eTag), None) + + /** + * Wraps its inner route with support for Conditional Requests as defined + * by http://tools.ietf.org/html/rfc7232 + * + * In particular the algorithm defined by http://tools.ietf.org/html/rfc7232#section-6 + * is implemented by this directive. + * + * Note: if you want to combine this directive with `withRangeSupport(...)` you need to put + * it on the *outside* of the `withRangeSupport(...)` directive, i.e. `withRangeSupport(...)` + * must be on a deeper level in your route structure in order to function correctly. + */ + def conditional(lastModified: DateTime): Directive0 = conditional(None, Some(lastModified)) + + /** + * Wraps its inner route with support for Conditional Requests as defined + * by http://tools.ietf.org/html/rfc7232 + * + * In particular the algorithm defined by http://tools.ietf.org/html/rfc7232#section-6 + * is implemented by this directive. + * + * Note: if you want to combine this directive with `withRangeSupport(...)` you need to put + * it on the *outside* of the `withRangeSupport(...)` directive, i.e. `withRangeSupport(...)` + * must be on a deeper level in your route structure in order to function correctly. + */ + def conditional(eTag: EntityTag, lastModified: DateTime): Directive0 = conditional(Some(eTag), Some(lastModified)) + + /** + * Wraps its inner route with support for Conditional Requests as defined + * by http://tools.ietf.org/html/rfc7232 + * + * In particular the algorithm defined by http://tools.ietf.org/html/rfc7232#section-6 + * is implemented by this directive. + * + * Note: if you want to combine this directive with `withRangeSupport(...)` you need to put + * it on the *outside* of the `withRangeSupport(...)` directive, i.e. `withRangeSupport(...)` + * must be on a deeper level in your route structure in order to function correctly. + */ + def conditional(eTag: Option[EntityTag], lastModified: Option[DateTime]): Directive0 = { + def addResponseHeaders: Directive0 = + mapResponse(_.withDefaultHeaders(eTag.map(ETag(_)).toList ++ lastModified.map(`Last-Modified`(_)).toList)) + + // TODO: also handle Cache-Control and Vary + def complete304(): Route = addResponseHeaders(complete(HttpResponse(NotModified))) + def complete412(): Route = _.complete(PreconditionFailed) + + extractRequest.flatMap { request ⇒ + import request._ + mapInnerRoute { route ⇒ + def innerRouteWithRangeHeaderFilteredOut: Route = + (mapRequest(_.mapHeaders(_.filterNot(_.isInstanceOf[Range]))) & + addResponseHeaders)(route) + + def isGetOrHead = method == HEAD || method == GET + def unmodified(ifModifiedSince: DateTime) = + lastModified.get <= ifModifiedSince && ifModifiedSince.clicks < System.currentTimeMillis() + + def step1(): Route = + header[`If-Match`] match { + case Some(`If-Match`(im)) if eTag.isDefined ⇒ + if (matchesRange(eTag.get, im, weakComparison = false)) step3() else complete412() + case None ⇒ step2() + } + def step2(): Route = + header[`If-Unmodified-Since`] match { + case Some(`If-Unmodified-Since`(ius)) if lastModified.isDefined && !unmodified(ius) ⇒ complete412() + case _ ⇒ step3() + } + def step3(): Route = + header[`If-None-Match`] match { + case Some(`If-None-Match`(inm)) if eTag.isDefined ⇒ + if (!matchesRange(eTag.get, inm, weakComparison = true)) step5() + else if (isGetOrHead) complete304() else complete412() + case None ⇒ step4() + } + def step4(): Route = + if (isGetOrHead) { + header[`If-Modified-Since`] match { + case Some(`If-Modified-Since`(ims)) if lastModified.isDefined && unmodified(ims) ⇒ complete304() + case _ ⇒ step5() + } + } else step5() + def step5(): Route = + if (method == GET && header[Range].isDefined) + header[`If-Range`] match { + case Some(`If-Range`(Left(tag))) if eTag.isDefined && !matches(eTag.get, tag, weakComparison = false) ⇒ + innerRouteWithRangeHeaderFilteredOut + case Some(`If-Range`(Right(ims))) if lastModified.isDefined && !unmodified(ims) ⇒ + innerRouteWithRangeHeaderFilteredOut + case _ ⇒ step6() + } + else step6() + def step6(): Route = addResponseHeaders(route) + + step1() + } + } + } +} + +object CacheConditionDirectives extends CacheConditionDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala new file mode 100644 index 0000000000..e7a937fc2e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CodingDirectives.scala @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.annotation.tailrec +import scala.collection.immutable +import scala.util.control.NonFatal +import akka.http.scaladsl.model.headers.{ HttpEncodings, `Accept-Encoding`, HttpEncoding, HttpEncodingRange } +import akka.http.scaladsl.model._ +import akka.http.scaladsl.coding._ +import akka.http.impl.util._ + +trait CodingDirectives { + import BasicDirectives._ + import MiscDirectives._ + import RouteDirectives._ + import CodingDirectives._ + + // encoding + + /** + * Rejects the request with an UnacceptedResponseEncodingRejection + * if the given encoding is not accepted for the response. + */ + def responseEncodingAccepted(encoding: HttpEncoding): Directive0 = + extract(_.request.isEncodingAccepted(encoding)) + .flatMap(if (_) pass else reject(UnacceptedResponseEncodingRejection(Set(encoding)))) + + /** + * Encodes the response with the encoding that is requested by the client with the `Accept- + * Encoding` header. The response encoding is determined by the rules specified in + * http://tools.ietf.org/html/rfc7231#section-5.3.4. + * + * If the `Accept-Encoding` header is missing or empty or specifies an encoding other than + * identity, gzip or deflate then no encoding is used. + */ + def encodeResponse: Directive0 = + encodeResponseWith(NoCoding, Gzip, Deflate) + + /** + * Encodes the response with the encoding that is requested by the client with the `Accept- + * Encoding` header. The response encoding is determined by the rules specified in + * http://tools.ietf.org/html/rfc7231#section-5.3.4. + * + * If the `Accept-Encoding` header is missing then the response is encoded using the `first` + * encoder. + * + * If the `Accept-Encoding` header is empty and `NoCoding` is part of the encoders then no + * response encoding is used. Otherwise the request is rejected. + */ + def encodeResponseWith(first: Encoder, more: Encoder*): Directive0 = + _encodeResponse(immutable.Seq(first +: more: _*)) + + // decoding + + /** + * Decodes the incoming request using the given Decoder. + * If the request encoding doesn't match the request is rejected with an `UnsupportedRequestEncodingRejection`. + */ + def decodeRequestWith(decoder: Decoder): Directive0 = { + def applyDecoder = + extractSettings flatMap { settings ⇒ + val effectiveDecoder = decoder.withMaxBytesPerChunk(settings.decodeMaxBytesPerChunk) + mapRequest { request ⇒ + effectiveDecoder.decode(request).mapEntity(StreamUtils.mapEntityError { + case NonFatal(e) ⇒ + IllegalRequestException( + StatusCodes.BadRequest, + ErrorInfo("The request's encoding is corrupt", e.getMessage)) + }) + } + } + + requestEntityEmpty | ( + requestEncodedWith(decoder.encoding) & + applyDecoder & + cancelRejections(classOf[UnsupportedRequestEncodingRejection])) + } + + /** + * Rejects the request with an UnsupportedRequestEncodingRejection if its encoding doesn't match the given one. + */ + def requestEncodedWith(encoding: HttpEncoding): Directive0 = + extract(_.request.encoding).flatMap { + case `encoding` ⇒ pass + case _ ⇒ reject(UnsupportedRequestEncodingRejection(encoding)) + } + + /** + * Decodes the incoming request if it is encoded with one of the given + * encoders. If the request encoding doesn't match one of the given encoders + * the request is rejected with an `UnsupportedRequestEncodingRejection`. + * If no decoders are given the default encoders (``Gzip``, ``Deflate``, ``NoCoding``) are used. + */ + def decodeRequestWith(decoders: Decoder*): Directive0 = + theseOrDefault(decoders).map(decodeRequestWith).reduce(_ | _) + + /** + * Decompresses the incoming request if it is ``gzip`` or ``deflate`` compressed. + * Uncompressed requests are passed through untouched. + * If the request encoded with another encoding the request is rejected with an `UnsupportedRequestEncodingRejection`. + */ + def decodeRequest: Directive0 = + decodeRequestWith(DefaultCoders: _*) +} + +object CodingDirectives extends CodingDirectives { + val DefaultCoders: immutable.Seq[Coder] = immutable.Seq(Gzip, Deflate, NoCoding) + + def theseOrDefault[T >: Coder](these: Seq[T]): Seq[T] = if (these.isEmpty) DefaultCoders else these + + import BasicDirectives._ + import HeaderDirectives._ + import RouteDirectives._ + + private def _encodeResponse(encoders: immutable.Seq[Encoder]): Directive0 = + optionalHeaderValueByType(classOf[`Accept-Encoding`]) flatMap { accept ⇒ + val acceptedEncoder = accept match { + case None ⇒ + // use first defined encoder when Accept-Encoding is missing + encoders.headOption + case Some(`Accept-Encoding`(encodings)) ⇒ + // provide fallback to identity + val withIdentity = + if (encodings.exists { + case HttpEncodingRange.One(HttpEncodings.identity, _) ⇒ true + case _ ⇒ false + }) encodings + else encodings :+ HttpEncodings.`identity;q=MIN` + // sort client-accepted encodings by q-Value (and orig. order) and find first matching encoder + @tailrec def find(encodings: List[HttpEncodingRange]): Option[Encoder] = encodings match { + case encoding :: rest ⇒ + encoders.find(e ⇒ encoding.matches(e.encoding)) match { + case None ⇒ find(rest) + case x ⇒ x + } + case _ ⇒ None + } + find(withIdentity.sortBy(e ⇒ (-e.qValue, withIdentity.indexOf(e))).toList) + } + acceptedEncoder match { + case Some(encoder) ⇒ mapResponse(encoder.encode(_)) + case _ ⇒ reject(UnacceptedResponseEncodingRejection(encoders.map(_.encoding).toSet)) + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CookieDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CookieDirectives.scala new file mode 100644 index 0000000000..85e480c2d1 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CookieDirectives.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import headers._ +import akka.http.impl.util._ + +trait CookieDirectives { + import HeaderDirectives._ + import RespondWithDirectives._ + import RouteDirectives._ + + /** + * Extracts an HttpCookie with the given name. If the cookie is not present the + * request is rejected with a respective [[MissingCookieRejection]]. + */ + def cookie(name: String): Directive1[HttpCookie] = + headerValue(findCookie(name)) | reject(MissingCookieRejection(name)) + + /** + * Extracts an HttpCookie with the given name. + * If the cookie is not present a value of `None` is extracted. + */ + def optionalCookie(name: String): Directive1[Option[HttpCookie]] = + optionalHeaderValue(findCookie(name)) + + private def findCookie(name: String): HttpHeader ⇒ Option[HttpCookie] = { + case Cookie(cookies) ⇒ cookies.find(_.name == name) + case _ ⇒ None + } + + /** + * Adds a Set-Cookie header with the given cookies to all responses of its inner route. + */ + def setCookie(first: HttpCookie, more: HttpCookie*): Directive0 = + respondWithHeaders((first :: more.toList).map(`Set-Cookie`(_))) + + /** + * Adds a Set-Cookie header expiring the given cookies to all responses of its inner route. + */ + def deleteCookie(first: HttpCookie, more: HttpCookie*): Directive0 = + respondWithHeaders((first :: more.toList).map { c ⇒ + `Set-Cookie`(c.copy(content = "deleted", expires = Some(DateTime.MinValue))) + }) + + /** + * Adds a Set-Cookie header expiring the given cookie to all responses of its inner route. + */ + def deleteCookie(name: String, domain: String = "", path: String = ""): Directive0 = + deleteCookie(HttpCookie(name, "", domain = domain.toOption, path = path.toOption)) + +} + +object CookieDirectives extends CookieDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/DebuggingDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/DebuggingDirectives.scala new file mode 100644 index 0000000000..a7fe5d6809 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/DebuggingDirectives.scala @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.event.Logging._ +import akka.event.LoggingAdapter +import akka.http.scaladsl.model._ + +trait DebuggingDirectives { + import BasicDirectives._ + + def logRequest(magnet: LoggingMagnet[HttpRequest ⇒ Unit]): Directive0 = + extractRequestContext.flatMap { ctx ⇒ + magnet.f(ctx.log)(ctx.request) + pass + } + + def logResult(magnet: LoggingMagnet[RouteResult ⇒ Unit]): Directive0 = + extractRequestContext.flatMap { ctx ⇒ + mapRouteResult { result ⇒ + magnet.f(ctx.log)(result) + result + } + } + + def logRequestResult(magnet: LoggingMagnet[HttpRequest ⇒ RouteResult ⇒ Unit]): Directive0 = + extractRequestContext.flatMap { ctx ⇒ + val logResult = magnet.f(ctx.log)(ctx.request) + mapRouteResult { result ⇒ + logResult(result) + result + } + } +} + +object DebuggingDirectives extends DebuggingDirectives + +case class LoggingMagnet[T](f: LoggingAdapter ⇒ T) // # logging-magnet + +object LoggingMagnet { + implicit def forMessageFromMarker[T](marker: String) = // # message-magnets + forMessageFromMarkerAndLevel[T](marker -> DebugLevel) + + implicit def forMessageFromMarkerAndLevel[T](markerAndLevel: (String, LogLevel)) = // # message-magnets + forMessageFromFullShow[T] { + val (marker, level) = markerAndLevel + Message ⇒ LogEntry(Message, marker, level) + } + + implicit def forMessageFromShow[T](show: T ⇒ String) = // # message-magnets + forMessageFromFullShow[T](msg ⇒ LogEntry(show(msg), DebugLevel)) + + implicit def forMessageFromFullShow[T](show: T ⇒ LogEntry): LoggingMagnet[T ⇒ Unit] = // # message-magnets + LoggingMagnet(log ⇒ show(_).logTo(log)) + + implicit def forRequestResponseFromMarker(marker: String) = // # request-response-magnets + forRequestResponseFromMarkerAndLevel(marker -> DebugLevel) + + implicit def forRequestResponseFromMarkerAndLevel(markerAndLevel: (String, LogLevel)) = // # request-response-magnets + forRequestResponseFromFullShow { + val (marker, level) = markerAndLevel + request ⇒ response ⇒ Some( + LogEntry("Response for\n Request : " + request + "\n Response: " + response, marker, level)) + } + + implicit def forRequestResponseFromFullShow(show: HttpRequest ⇒ RouteResult ⇒ Option[LogEntry]): LoggingMagnet[HttpRequest ⇒ RouteResult ⇒ Unit] = // # request-response-magnets + LoggingMagnet { log ⇒ + request ⇒ + val showResult = show(request) + result ⇒ showResult(result).foreach(_.logTo(log)) + } +} + +case class LogEntry(obj: Any, level: LogLevel = DebugLevel) { + def logTo(log: LoggingAdapter): Unit = { + log.log(level, obj.toString) + } +} + +object LogEntry { + def apply(obj: Any, marker: String, level: LogLevel): LogEntry = + LogEntry(if (marker.isEmpty) obj else marker + ": " + obj, level) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ExecutionDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ExecutionDirectives.scala new file mode 100644 index 0000000000..ec356af339 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ExecutionDirectives.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.collection.immutable +import scala.concurrent.Future +import scala.util.control.NonFatal +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ + +trait ExecutionDirectives { + import BasicDirectives._ + + /** + * Transforms exceptions thrown during evaluation of its inner route using the given + * [[akka.http.scaladsl.server.ExceptionHandler]]. + */ + def handleExceptions(handler: ExceptionHandler): Directive0 = + Directive { innerRouteBuilder ⇒ + ctx ⇒ + import ctx.executionContext + def handleException: PartialFunction[Throwable, Future[RouteResult]] = + handler andThen (_(ctx.withAcceptAll)) + try innerRouteBuilder(())(ctx).fast.recoverWith(handleException) + catch { + case NonFatal(e) ⇒ handleException.applyOrElse[Throwable, Future[RouteResult]](e, throw _) + } + } + + /** + * Transforms rejections produced by its inner route using the given + * [[akka.http.scaladsl.server.RejectionHandler]]. + */ + def handleRejections(handler: RejectionHandler): Directive0 = + extractRequestContext flatMap { ctx ⇒ + val maxIterations = 8 + // allow for up to `maxIterations` nested rejections from RejectionHandler before bailing out + def handle(rejections: immutable.Seq[Rejection], originalRejections: immutable.Seq[Rejection], iterationsLeft: Int = maxIterations): Future[RouteResult] = + if (iterationsLeft > 0) { + handler(rejections) match { + case Some(route) ⇒ recoverRejectionsWith(handle(_, originalRejections, iterationsLeft - 1))(route)(ctx.withAcceptAll) + case None ⇒ FastFuture.successful(RouteResult.Rejected(rejections)) + } + } else + sys.error(s"Rejection handler still produced new rejections after $maxIterations iterations. " + + s"Is there an infinite handler cycle? Initial rejections: $originalRejections final rejections: $rejections") + + recoverRejectionsWith { rejections ⇒ + val transformed = RejectionHandler.applyTransformations(rejections) + handle(transformed, transformed) + } + } +} + +object ExecutionDirectives extends ExecutionDirectives \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectives.scala new file mode 100644 index 0000000000..cbe4efeae0 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectives.scala @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import java.io.{ File, FileInputStream } +import java.net.URL + +import scala.annotation.tailrec +import akka.actor.ActorSystem +import akka.event.LoggingAdapter +import akka.http.scaladsl.marshalling.{ Marshaller, ToEntityMarshaller } +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.impl.util._ + +trait FileAndResourceDirectives { + import CacheConditionDirectives._ + import MethodDirectives._ + import FileAndResourceDirectives._ + import RouteDirectives._ + import BasicDirectives._ + import RouteConcatenation._ + import RangeDirectives._ + + /** + * Completes GET requests with the content of the given file. The actual I/O operation is + * running detached in a `Future`, so it doesn't block the current thread (but potentially + * some other thread !). If the file cannot be found or read the request is rejected. + */ + def getFromFile(fileName: String)(implicit resolver: ContentTypeResolver): Route = + getFromFile(new File(fileName)) + + /** + * Completes GET requests with the content of the given file. The actual I/O operation is + * running detached in a `Future`, so it doesn't block the current thread (but potentially + * some other thread !). If the file cannot be found or read the request is rejected. + */ + def getFromFile(file: File)(implicit resolver: ContentTypeResolver): Route = + getFromFile(file, resolver(file.getName)) + + /** + * Completes GET requests with the content of the given file. The actual I/O operation is + * running detached in a `Future`, so it doesn't block the current thread (but potentially + * some other thread !). If the file cannot be found or read the request is rejected. + */ + def getFromFile(file: File, contentType: ContentType): Route = + get { + if (file.isFile && file.canRead) + conditionalFor(file.length, file.lastModified) { + withRangeSupport { + extractSettings { settings ⇒ + complete { + HttpEntity.Default(contentType, file.length, + StreamUtils.fromInputStreamSource(new FileInputStream(file), settings.fileIODispatcher)) + } + } + } + } + else reject + } + + private def conditionalFor(length: Long, lastModified: Long): Directive0 = + extractSettings.flatMap(settings ⇒ + if (settings.fileGetConditional) { + val tag = java.lang.Long.toHexString(lastModified ^ java.lang.Long.reverse(length)) + val lastModifiedDateTime = DateTime(math.min(lastModified, System.currentTimeMillis)) + conditional(EntityTag(tag), lastModifiedDateTime) + } else pass) + + /** + * Completes GET requests with the content of the given resource. The actual I/O operation is + * running detached in a `Future`, so it doesn't block the current thread (but potentially + * some other thread !). + * If the resource cannot be found or read the Route rejects the request. + */ + def getFromResource(resourceName: String)(implicit resolver: ContentTypeResolver): Route = + getFromResource(resourceName, resolver(resourceName)) + + /** + * Completes GET requests with the content of the given resource. The actual I/O operation is + * running detached in a `Future`, so it doesn't block the current thread (but potentially + * some other thread !). + * If the resource is a directory or cannot be found or read the Route rejects the request. + */ + def getFromResource(resourceName: String, contentType: ContentType, classLoader: ClassLoader = defaultClassLoader): Route = + if (!resourceName.endsWith("/")) + get { + Option(classLoader.getResource(resourceName)) flatMap ResourceFile.apply match { + case Some(ResourceFile(url, length, lastModified)) ⇒ + conditionalFor(length, lastModified) { + withRangeSupport { + extractSettings { settings ⇒ + complete { + HttpEntity.Default(contentType, length, + StreamUtils.fromInputStreamSource(url.openStream(), settings.fileIODispatcher)) + } + } + } + } + case _ ⇒ reject // not found or directory + } + } + else reject // don't serve the content of resource "directories" + + /** + * Completes GET requests with the content of a file underneath the given directory. + * If the file cannot be read the Route rejects the request. + */ + def getFromDirectory(directoryName: String)(implicit resolver: ContentTypeResolver): Route = { + val base = withTrailingSlash(directoryName) + extractUnmatchedPath { path ⇒ + extractLog { log ⇒ + fileSystemPath(base, path, log) match { + case "" ⇒ reject + case fileName ⇒ getFromFile(fileName) + } + } + } + } + + /** + * Completes GET requests with a unified listing of the contents of all given directories. + * The actual rendering of the directory contents is performed by the in-scope `Marshaller[DirectoryListing]`. + */ + def listDirectoryContents(directories: String*)(implicit renderer: DirectoryRenderer): Route = + get { + extractRequestContext { ctx ⇒ + val path = ctx.unmatchedPath + val fullPath = ctx.request.uri.path.toString + val matchedLength = fullPath.lastIndexOf(path.toString) + require(matchedLength >= 0) + val pathPrefix = fullPath.substring(0, matchedLength) + val pathString = withTrailingSlash(fileSystemPath("/", path, ctx.log, '/')) + val dirs = directories flatMap { dir ⇒ + fileSystemPath(withTrailingSlash(dir), path, ctx.log) match { + case "" ⇒ None + case fileName ⇒ + val file = new File(fileName) + if (file.isDirectory && file.canRead) Some(file) else None + } + } + implicit val marshaller: ToEntityMarshaller[DirectoryListing] = renderer.marshaller(ctx.settings.renderVanityFooter) + + if (dirs.isEmpty) reject + else complete(DirectoryListing(pathPrefix + pathString, isRoot = pathString == "/", dirs.flatMap(_.listFiles))) + } + } + + /** + * Same as `getFromBrowseableDirectories` with only one directory. + */ + def getFromBrowseableDirectory(directory: String)(implicit renderer: DirectoryRenderer, resolver: ContentTypeResolver): Route = + getFromBrowseableDirectories(directory) + + /** + * Serves the content of the given directories as a file system browser, i.e. files are sent and directories + * served as browseable listings. + */ + def getFromBrowseableDirectories(directories: String*)(implicit renderer: DirectoryRenderer, resolver: ContentTypeResolver): Route = { + directories.map(getFromDirectory).reduceLeft(_ ~ _) ~ listDirectoryContents(directories: _*) + } + + /** + * Same as "getFromDirectory" except that the file is not fetched from the file system but rather from a + * "resource directory". + * If the requested resource is itself a directory or cannot be found or read the Route rejects the request. + */ + def getFromResourceDirectory(directoryName: String, classLoader: ClassLoader = defaultClassLoader)(implicit resolver: ContentTypeResolver): Route = { + val base = if (directoryName.isEmpty) "" else withTrailingSlash(directoryName) + + extractUnmatchedPath { path ⇒ + extractLog { log ⇒ + fileSystemPath(base, path, log, separator = '/') match { + case "" ⇒ reject + case resourceName ⇒ getFromResource(resourceName, resolver(resourceName), classLoader) + } + } + } + } + + protected[http] def defaultClassLoader: ClassLoader = classOf[ActorSystem].getClassLoader +} + +object FileAndResourceDirectives extends FileAndResourceDirectives { + private def withTrailingSlash(path: String): String = if (path endsWith "/") path else path + '/' + private def fileSystemPath(base: String, path: Uri.Path, log: LoggingAdapter, separator: Char = File.separatorChar): String = { + import java.lang.StringBuilder + @tailrec def rec(p: Uri.Path, result: StringBuilder = new StringBuilder(base)): String = + p match { + case Uri.Path.Empty ⇒ result.toString + case Uri.Path.Slash(tail) ⇒ rec(tail, result.append(separator)) + case Uri.Path.Segment(head, tail) ⇒ + if (head.indexOf('/') >= 0 || head == "..") { + log.warning("File-system path for base [{}] and Uri.Path [{}] contains suspicious path segment [{}], " + + "GET access was disallowed", base, path, head) + "" + } else rec(tail, result.append(head)) + } + rec(if (path.startsWithSlash) path.tail else path) + } + + object ResourceFile { + def apply(url: URL): Option[ResourceFile] = url.getProtocol match { + case "file" ⇒ + val file = new File(url.toURI) + if (file.isDirectory) None + else Some(ResourceFile(url, file.length(), file.lastModified())) + case "jar" ⇒ + val jarFile = url.getFile + val startIndex = if (jarFile.startsWith("file:")) 5 else 0 + val bangIndex = jarFile.indexOf("!") + val jarFilePath = jarFile.substring(startIndex, bangIndex) + val resourcePath = jarFile.substring(bangIndex + 2) + val jar = new java.util.zip.ZipFile(jarFilePath) + try { + val entry = jar.getEntry(resourcePath) + Option(jar.getInputStream(entry)) map { is ⇒ + is.close() + ResourceFile(url, entry.getSize, entry.getTime) + } + } finally jar.close() + case _ ⇒ None + } + } + case class ResourceFile(url: URL, length: Long, lastModified: Long) + + trait DirectoryRenderer { + def marshaller(renderVanityFooter: Boolean): ToEntityMarshaller[DirectoryListing] + } + trait LowLevelDirectoryRenderer { + implicit def defaultDirectoryRenderer: DirectoryRenderer = + new DirectoryRenderer { + def marshaller(renderVanityFooter: Boolean): ToEntityMarshaller[DirectoryListing] = + DirectoryListing.directoryMarshaller(renderVanityFooter) + } + } + object DirectoryRenderer extends LowLevelDirectoryRenderer { + implicit def liftMarshaller(implicit _marshaller: ToEntityMarshaller[DirectoryListing]): DirectoryRenderer = + new DirectoryRenderer { + def marshaller(renderVanityFooter: Boolean): ToEntityMarshaller[DirectoryListing] = _marshaller + } + } +} + +trait ContentTypeResolver { + def apply(fileName: String): ContentType +} + +object ContentTypeResolver { + + /** + * The default way of resolving a filename to a ContentType is by looking up the file extension in the + * registry of all defined media-types. By default all non-binary file content is assumed to be UTF-8 encoded. + */ + implicit val Default = withDefaultCharset(HttpCharsets.`UTF-8`) + + def withDefaultCharset(charset: HttpCharset): ContentTypeResolver = + new ContentTypeResolver { + def apply(fileName: String) = { + val ext = fileName.lastIndexOf('.') match { + case -1 ⇒ "" + case x ⇒ fileName.substring(x + 1) + } + val mediaType = MediaTypes.forExtension(ext) getOrElse MediaTypes.`application/octet-stream` + ContentType(mediaType) withDefaultCharset charset + } + } + + def apply(f: String ⇒ ContentType): ContentTypeResolver = + new ContentTypeResolver { + def apply(fileName: String): ContentType = f(fileName) + } +} + +case class DirectoryListing(path: String, isRoot: Boolean, files: Seq[File]) + +object DirectoryListing { + + private val html = + """ + |Index of $ + | + |

Index of $

+ |
+ |
+      |$
+ |
$ + |
+ |rendered by Akka Http on $ + |
$ + | + | + |""".stripMarginWithNewline("\n") split '$' + + def directoryMarshaller(renderVanityFooter: Boolean): ToEntityMarshaller[DirectoryListing] = + Marshaller.StringMarshaller.wrapWithEC(MediaTypes.`text/html`) { implicit ec ⇒ + listing ⇒ + val DirectoryListing(path, isRoot, files) = listing + val filesAndNames = files.map(file ⇒ file -> file.getName).sortBy(_._2) + val deduped = filesAndNames.zipWithIndex.flatMap { + case (fan @ (file, name), ix) ⇒ + if (ix == 0 || filesAndNames(ix - 1)._2 != name) Some(fan) else None + } + val (directoryFilesAndNames, fileFilesAndNames) = deduped.partition(_._1.isDirectory) + def maxNameLength(seq: Seq[(File, String)]) = if (seq.isEmpty) 0 else seq.map(_._2.length).max + val maxNameLen = math.max(maxNameLength(directoryFilesAndNames) + 1, maxNameLength(fileFilesAndNames)) + val sb = new java.lang.StringBuilder + sb.append(html(0)).append(path).append(html(1)).append(path).append(html(2)) + if (!isRoot) { + val secondToLastSlash = path.lastIndexOf('/', path.lastIndexOf('/', path.length - 1) - 1) + sb.append("../\n" format path.substring(0, secondToLastSlash)) + } + def lastModified(file: File) = DateTime(file.lastModified).toIsoLikeDateTimeString + def start(name: String) = + sb.append("").append(name).append("") + .append(" " * (maxNameLen - name.length)) + def renderDirectory(file: File, name: String) = + start(name + '/').append(" ").append(lastModified(file)).append('\n') + def renderFile(file: File, name: String) = { + val size = akka.http.impl.util.humanReadableByteCount(file.length, si = true) + start(name).append(" ").append(lastModified(file)) + sb.append(" ".substring(size.length)).append(size).append('\n') + } + for ((file, name) ← directoryFilesAndNames) renderDirectory(file, name) + for ((file, name) ← fileFilesAndNames) renderFile(file, name) + if (isRoot && files.isEmpty) sb.append("(no files)\n") + sb.append(html(3)) + if (renderVanityFooter) sb.append(html(4)).append(DateTime.now.toIsoLikeDateTimeString).append(html(5)) + sb.append(html(6)).toString + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FormFieldDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FormFieldDirectives.scala new file mode 100644 index 0000000000..a29a4b2bef --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FormFieldDirectives.scala @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.Future +import scala.util.{ Failure, Success } +import akka.http.scaladsl.unmarshalling.Unmarshaller.UnsupportedContentTypeException +import akka.http.scaladsl.common._ +import akka.http.impl.util._ +import akka.http.scaladsl.util.FastFuture._ + +trait FormFieldDirectives extends ToNameReceptacleEnhancements { + import FormFieldDirectives._ + + /** + * Rejects the request if the defined form field matcher(s) don't match. + * Otherwise the form field value(s) are extracted and passed to the inner route. + */ + def formField(pdm: FieldMagnet): pdm.Out = pdm() + + /** + * Rejects the request if the defined form field matcher(s) don't match. + * Otherwise the form field value(s) are extracted and passed to the inner route. + */ + def formFields(pdm: FieldMagnet): pdm.Out = pdm() + +} + +object FormFieldDirectives extends FormFieldDirectives { + sealed trait FieldMagnet { + type Out + def apply(): Out + } + object FieldMagnet { + implicit def apply[T](value: T)(implicit fdef: FieldDef[T]) = + new FieldMagnet { + type Out = fdef.Out + def apply() = fdef(value) + } + } + + sealed trait FieldDef[T] { + type Out + def apply(value: T): Out + } + + object FieldDef { + def fieldDef[A, B](f: A ⇒ B) = + new FieldDef[A] { + type Out = B + def apply(value: A) = f(value) + } + + import akka.http.scaladsl.unmarshalling.{ FromStrictFormFieldUnmarshaller ⇒ FSFFU, _ } + import BasicDirectives._ + import RouteDirectives._ + import FutureDirectives._ + type SFU = FromEntityUnmarshaller[StrictForm] + type FSFFOU[T] = Unmarshaller[Option[StrictForm.Field], T] + + //////////////////// "regular" formField extraction //////////////////// + + private def extractField[A, B](f: A ⇒ Directive1[B]) = fieldDef(f) + private def fieldOfForm[T](fieldName: String, fu: Unmarshaller[Option[StrictForm.Field], T])(implicit sfu: SFU): RequestContext ⇒ Future[T] = { ctx ⇒ + import ctx.executionContext + sfu(ctx.request.entity).fast.flatMap(form ⇒ fu(form field fieldName)) + } + private def filter[T](fieldName: String, fu: FSFFOU[T])(implicit sfu: SFU): Directive1[T] = { + extract(fieldOfForm(fieldName, fu)).flatMap { + onComplete(_).flatMap { + case Success(x) ⇒ provide(x) + case Failure(Unmarshaller.NoContentException) ⇒ reject(MissingFormFieldRejection(fieldName)) + case Failure(x: UnsupportedContentTypeException) ⇒ reject(UnsupportedRequestContentTypeRejection(x.supported)) + case Failure(x) ⇒ reject(MalformedFormFieldRejection(fieldName, x.getMessage.nullAsEmpty, Option(x.getCause))) + } + } + } + implicit def forString(implicit sfu: SFU, fu: FSFFU[String]) = + extractField[String, String] { fieldName ⇒ filter(fieldName, fu) } + implicit def forSymbol(implicit sfu: SFU, fu: FSFFU[String]) = + extractField[Symbol, String] { symbol ⇒ filter(symbol.name, fu) } + implicit def forNR[T](implicit sfu: SFU, fu: FSFFU[T]) = + extractField[NameReceptacle[T], T] { nr ⇒ filter(nr.name, fu) } + implicit def forNUR[T](implicit sfu: SFU) = + extractField[NameUnmarshallerReceptacle[T], T] { nr ⇒ filter(nr.name, StrictForm.Field.unmarshallerFromFSU(nr.um)) } + implicit def forNOR[T](implicit sfu: SFU, fu: FSFFOU[T]) = + extractField[NameOptionReceptacle[T], Option[T]] { nr ⇒ filter[Option[T]](nr.name, fu) } + implicit def forNDR[T](implicit sfu: SFU, fu: FSFFOU[T]) = + extractField[NameDefaultReceptacle[T], T] { nr ⇒ filter(nr.name, fu withDefaultValue nr.default) } + implicit def forNOUR[T](implicit sfu: SFU) = + extractField[NameOptionUnmarshallerReceptacle[T], Option[T]] { nr ⇒ filter[Option[T]](nr.name, StrictForm.Field.unmarshallerFromFSU(nr.um): FSFFOU[T]) } + implicit def forNDUR[T](implicit sfu: SFU) = + extractField[NameDefaultUnmarshallerReceptacle[T], T] { nr ⇒ filter(nr.name, (StrictForm.Field.unmarshallerFromFSU(nr.um): FSFFOU[T]) withDefaultValue nr.default) } + + //////////////////// required formField support //////////////////// + + private def requiredFilter[T](fieldName: String, fu: Unmarshaller[Option[StrictForm.Field], T], + requiredValue: Any)(implicit sfu: SFU): Directive0 = + extract(fieldOfForm(fieldName, fu)).flatMap { + onComplete(_).flatMap { + case Success(value) if value == requiredValue ⇒ pass + case _ ⇒ reject + } + } + implicit def forRVR[T](implicit sfu: SFU, fu: FSFFU[T]) = + fieldDef[RequiredValueReceptacle[T], Directive0] { rvr ⇒ requiredFilter(rvr.name, fu, rvr.requiredValue) } + implicit def forRVDR[T](implicit sfu: SFU) = + fieldDef[RequiredValueUnmarshallerReceptacle[T], Directive0] { rvr ⇒ requiredFilter(rvr.name, StrictForm.Field.unmarshallerFromFSU(rvr.um), rvr.requiredValue) } + + //////////////////// tuple support //////////////////// + + import akka.http.scaladsl.server.util.TupleOps._ + import akka.http.scaladsl.server.util.BinaryPolyFunc + + implicit def forTuple[T](implicit fold: FoldLeft[Directive0, T, ConvertParamDefAndConcatenate.type]) = + fieldDef[T, fold.Out](fold(pass, _)) + + object ConvertParamDefAndConcatenate extends BinaryPolyFunc { + implicit def from[P, TA, TB](implicit fdef: FieldDef[P] { type Out = Directive[TB] }, ev: Join[TA, TB]) = + at[Directive[TA], P] { (a, t) ⇒ a & fdef(t) } + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FutureDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FutureDirectives.scala new file mode 100644 index 0000000000..5bd4b10569 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/FutureDirectives.scala @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.Future +import scala.util.{ Failure, Success, Try } +import akka.http.scaladsl.marshalling.ToResponseMarshaller +import akka.http.scaladsl.server.util.Tupler +import akka.http.scaladsl.util.FastFuture._ + +// format: OFF + +trait FutureDirectives { + + /** + * "Unwraps" a ``Future[T]`` and runs its inner route after future + * completion with the future's value as an extraction of type ``Try[T]``. + */ + def onComplete[T](future: ⇒ Future[T]): Directive1[Try[T]] = + Directive { inner ⇒ ctx ⇒ + import ctx.executionContext + future.fast.transformWith(t ⇒ inner(Tuple1(t))(ctx)) + } + + /** + * "Unwraps" a ``Future[T]`` and runs its inner route after future + * completion with the future's value as an extraction of type ``T``. + * If the future fails its failure Throwable is bubbled up to the nearest + * ExceptionHandler. + * If type ``T`` is already a Tuple it is directly expanded into the respective + * number of extractions. + */ + def onSuccess(magnet: OnSuccessMagnet): Directive[magnet.Out] = magnet.directive + + /** + * "Unwraps" a ``Future[T]`` and runs its inner route when the future has failed + * with the future's failure exception as an extraction of type ``Throwable``. + * If the future succeeds the request is completed using the values marshaller + * (This directive therefore requires a marshaller for the futures type to be + * implicitly available.) + */ + def completeOrRecoverWith(magnet: CompleteOrRecoverWithMagnet): Directive1[Throwable] = magnet.directive +} + +object FutureDirectives extends FutureDirectives + +trait OnSuccessMagnet { + type Out + def directive: Directive[Out] +} + +object OnSuccessMagnet { + implicit def apply[T](future: ⇒ Future[T])(implicit tupler: Tupler[T]) = + new OnSuccessMagnet { + type Out = tupler.Out + val directive = Directive[tupler.Out] { inner ⇒ ctx ⇒ + import ctx.executionContext + future.fast.flatMap(t ⇒ inner(tupler(t))(ctx)) + }(tupler.OutIsTuple) + } +} + +trait CompleteOrRecoverWithMagnet { + def directive: Directive1[Throwable] +} + +object CompleteOrRecoverWithMagnet { + implicit def apply[T](future: ⇒ Future[T])(implicit m: ToResponseMarshaller[T]) = + new CompleteOrRecoverWithMagnet { + val directive = Directive[Tuple1[Throwable]] { inner ⇒ ctx ⇒ + import ctx.executionContext + future.fast.transformWith { + case Success(res) ⇒ ctx.complete(res) + case Failure(error) ⇒ inner(Tuple1(error))(ctx) + } + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala new file mode 100644 index 0000000000..5c670dd4d3 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HeaderDirectives.scala @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.util.control.NonFatal +import akka.http.scaladsl.server.util.ClassMagnet +import akka.http.scaladsl.model._ +import akka.http.impl.util._ + +trait HeaderDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Extracts an HTTP header value using the given function. If the function result is undefined for all headers the + * request is rejected with an empty rejection set. If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def headerValue[T](f: HttpHeader ⇒ Option[T]): Directive1[T] = { + val protectedF: HttpHeader ⇒ Option[Either[Rejection, T]] = header ⇒ + try f(header).map(Right.apply) + catch { + case NonFatal(e) ⇒ Some(Left(MalformedHeaderRejection(header.name, e.getMessage.nullAsEmpty, Some(e)))) + } + + extract(_.request.headers.collectFirst(Function.unlift(protectedF))).flatMap { + case Some(Right(a)) ⇒ provide(a) + case Some(Left(rejection)) ⇒ reject(rejection) + case None ⇒ reject + } + } + + /** + * Extracts an HTTP header value using the given partial function. If the function is undefined for all headers the + * request is rejected with an empty rejection set. + */ + def headerValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[T] = headerValue(pf.lift) + + /** + * Extracts the value of the HTTP request header with the given name. + * If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByName(headerName: Symbol): Directive1[String] = headerValueByName(headerName.toString) + + /** + * Extracts the value of the HTTP request header with the given name. + * If no header with a matching name is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByName(headerName: String): Directive1[String] = + headerValue(optionalValue(headerName.toLowerCase)) | reject(MissingHeaderRejection(headerName)) + + /** + * Extracts the HTTP request header of the given type. + * If no header with a matching type is found the request is rejected with a [[spray.routing.MissingHeaderRejection]]. + */ + def headerValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[T] = + headerValuePF(magnet.extractPF) | reject(MissingHeaderRejection(magnet.runtimeClass.getSimpleName)) + + /** + * Extracts an optional HTTP header value using the given function. + * If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def optionalHeaderValue[T](f: HttpHeader ⇒ Option[T]): Directive1[Option[T]] = + headerValue(f).map(Some(_): Option[T]).recoverPF { + case Nil ⇒ provide(None) + } + + /** + * Extracts an optional HTTP header value using the given partial function. + * If the given function throws an exception the request is rejected + * with a [[spray.routing.MalformedHeaderRejection]]. + */ + def optionalHeaderValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[Option[T]] = + optionalHeaderValue(pf.lift) + + /** + * Extracts the value of the optional HTTP request header with the given name. + */ + def optionalHeaderValueByName(headerName: Symbol): Directive1[Option[String]] = + optionalHeaderValueByName(headerName.toString) + + /** + * Extracts the value of the optional HTTP request header with the given name. + */ + def optionalHeaderValueByName(headerName: String): Directive1[Option[String]] = { + val lowerCaseName = headerName.toLowerCase + extract(_.request.headers.collectFirst { + case HttpHeader(`lowerCaseName`, value) ⇒ value + }) + } + + /** + * Extract the header value of the optional HTTP request header with the given type. + */ + def optionalHeaderValueByType[T <: HttpHeader](magnet: ClassMagnet[T]): Directive1[Option[T]] = + optionalHeaderValuePF(magnet.extractPF) + + private def optionalValue(lowerCaseName: String): HttpHeader ⇒ Option[String] = { + case HttpHeader(`lowerCaseName`, value) ⇒ Some(value) + case _ ⇒ None + } +} + +object HeaderDirectives extends HeaderDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HostDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HostDirectives.scala new file mode 100644 index 0000000000..de52962b6d --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/HostDirectives.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.util.matching.Regex +import akka.http.impl.util._ + +trait HostDirectives { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Extracts the hostname part of the Host header value in the request. + */ + def extractHost: Directive1[String] = HostDirectives._extractHost + + /** + * Rejects all requests with a host name different from the given ones. + */ + def host(hostNames: String*): Directive0 = host(hostNames.contains(_)) + + /** + * Rejects all requests for whose host name the given predicate function returns false. + */ + def host(predicate: String ⇒ Boolean): Directive0 = extractHost.require(predicate) + + /** + * Rejects all requests with a host name that doesn't have a prefix matching the given regular expression. + * For all matching requests the prefix string matching the regex is extracted and passed to the inner route. + * If the regex contains a capturing group only the string matched by this group is extracted. + * If the regex contains more than one capturing group an IllegalArgumentException is thrown. + */ + def host(regex: Regex): Directive1[String] = { + def forFunc(regexMatch: String ⇒ Option[String]): Directive1[String] = { + extractHost.flatMap { name ⇒ + regexMatch(name) match { + case Some(matched) ⇒ provide(matched) + case None ⇒ reject + } + } + } + + regex.groupCount match { + case 0 ⇒ forFunc(regex.findPrefixOf(_)) + case 1 ⇒ forFunc(regex.findPrefixMatchOf(_).map(_.group(1))) + case _ ⇒ throw new IllegalArgumentException("Path regex '" + regex.pattern.pattern + + "' must not contain more than one capturing group") + } + } + +} + +object HostDirectives extends HostDirectives { + import BasicDirectives._ + + private val _extractHost: Directive1[String] = + extract(_.request.uri.authority.host.address) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MarshallingDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MarshallingDirectives.scala new file mode 100644 index 0000000000..e24351fa23 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MarshallingDirectives.scala @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.concurrent.Promise +import scala.util.{ Failure, Success } +import akka.http.scaladsl.marshalling.ToResponseMarshaller +import akka.http.scaladsl.unmarshalling.{ Unmarshaller, FromRequestUnmarshaller } +import akka.http.impl.util._ + +trait MarshallingDirectives { + import BasicDirectives._ + import FutureDirectives._ + import RouteDirectives._ + + /** + * Unmarshalls the requests entity to the given type passes it to its inner Route. + * If there is a problem with unmarshalling the request is rejected with the [[Rejection]] + * produced by the unmarshaller. + */ + def entity[T](um: FromRequestUnmarshaller[T]): Directive1[T] = + extractRequestContext.flatMap[Tuple1[T]] { ctx ⇒ + import ctx.executionContext + onComplete(um(ctx.request)) flatMap { + case Success(value) ⇒ provide(value) + case Failure(Unmarshaller.NoContentException) ⇒ reject(RequestEntityExpectedRejection) + case Failure(Unmarshaller.UnsupportedContentTypeException(x)) ⇒ reject(UnsupportedRequestContentTypeRejection(x)) + case Failure(x: IllegalArgumentException) ⇒ reject(ValidationRejection(x.getMessage.nullAsEmpty, Some(x))) + case Failure(x) ⇒ reject(MalformedRequestContentRejection(x.getMessage.nullAsEmpty, Option(x.getCause))) + } + } & cancelRejections(RequestEntityExpectedRejection.getClass, classOf[UnsupportedRequestContentTypeRejection]) + + /** + * Returns the in-scope [[FromRequestUnmarshaller]] for the given type. + */ + def as[T](implicit um: FromRequestUnmarshaller[T]) = um + + /** + * Uses the marshaller for the given type to produce a completion function that is passed to its inner function. + * You can use it do decouple marshaller resolution from request completion. + */ + def completeWith[T](marshaller: ToResponseMarshaller[T])(inner: (T ⇒ Unit) ⇒ Unit): Route = + extractExecutionContext { implicit ec ⇒ + implicit val m = marshaller + complete { + val promise = Promise[T]() + inner(promise.success(_)) + promise.future + } + } + + /** + * Returns the in-scope Marshaller for the given type. + */ + def instanceOf[T](implicit m: ToResponseMarshaller[T]): ToResponseMarshaller[T] = m + + /** + * Completes the request using the given function. The input to the function is produced with the in-scope + * entity unmarshaller and the result value of the function is marshalled with the in-scope marshaller. + */ + def handleWith[A, B](f: A ⇒ B)(implicit um: FromRequestUnmarshaller[A], m: ToResponseMarshaller[B]): Route = + entity(um) { a ⇒ complete(f(a)) } +} + +object MarshallingDirectives extends MarshallingDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MethodDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MethodDirectives.scala new file mode 100644 index 0000000000..08f6f98491 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MethodDirectives.scala @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.{ StatusCodes, HttpMethod } +import akka.http.scaladsl.model.HttpMethods._ + +trait MethodDirectives { + import BasicDirectives._ + import RouteDirectives._ + import ParameterDirectives._ + import MethodDirectives._ + + /** + * A route filter that rejects all non-DELETE requests. + */ + def delete: Directive0 = _delete + + /** + * A route filter that rejects all non-GET requests. + */ + def get: Directive0 = _get + + /** + * A route filter that rejects all non-HEAD requests. + */ + def head: Directive0 = _head + + /** + * A route filter that rejects all non-OPTIONS requests. + */ + def options: Directive0 = _options + + /** + * A route filter that rejects all non-PATCH requests. + */ + def patch: Directive0 = _patch + + /** + * A route filter that rejects all non-POST requests. + */ + def post: Directive0 = _post + + /** + * A route filter that rejects all non-PUT requests. + */ + def put: Directive0 = _put + + /** + * Extracts the request method. + */ + def extractMethod: Directive1[HttpMethod] = _extractMethod + + /** + * Rejects all requests whose HTTP method does not match the given one. + */ + def method(httpMethod: HttpMethod): Directive0 = + extractMethod.flatMap[Unit] { + case `httpMethod` ⇒ pass + case _ ⇒ reject(MethodRejection(httpMethod)) + } & cancelRejections(classOf[MethodRejection]) + + /** + * Changes the HTTP method of the request to the value of the specified query string parameter. If the query string + * parameter is not specified this directive has no effect. If the query string is specified as something that is not + * a HTTP method, then this directive completes the request with a `501 Not Implemented` response. + * + * This directive is useful for: + * - Use in combination with JSONP (JSONP only supports GET) + * - Supporting older browsers that lack support for certain HTTP methods. E.g. IE8 does not support PATCH + */ + def overrideMethodWithParameter(paramName: String): Directive0 = + parameter(paramName?) flatMap { + case Some(method) ⇒ + getForKey(method.toUpperCase) match { + case Some(m) ⇒ mapRequest(_.copy(method = m)) + case _ ⇒ complete(StatusCodes.NotImplemented) + } + case None ⇒ pass + } +} + +object MethodDirectives extends MethodDirectives { + private val _extractMethod: Directive1[HttpMethod] = + BasicDirectives.extract(_.request.method) + + // format: OFF + private val _delete : Directive0 = method(DELETE) + private val _get : Directive0 = method(GET) + private val _head : Directive0 = method(HEAD) + private val _options: Directive0 = method(OPTIONS) + private val _patch : Directive0 = method(PATCH) + private val _post : Directive0 = method(POST) + private val _put : Directive0 = method(PUT) + // format: ON +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MiscDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MiscDirectives.scala new file mode 100644 index 0000000000..42d531e816 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MiscDirectives.scala @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import headers._ + +trait MiscDirectives { + import RouteDirectives._ + + /** + * Returns a Directive which checks the given condition before passing on the [[spray.routing.RequestContext]] to + * its inner Route. If the condition fails the route is rejected with a [[spray.routing.ValidationRejection]]. + */ + def validate(check: ⇒ Boolean, errorMsg: String): Directive0 = + Directive { inner ⇒ if (check) inner() else reject(ValidationRejection(errorMsg)) } + + /** + * Directive extracting the IP of the client from either the X-Forwarded-For, Remote-Address or X-Real-IP header + * (in that order of priority). + */ + def extractClientIP: Directive1[RemoteAddress] = MiscDirectives._extractClientIP + + /** + * Rejects the request if its entity is not empty. + */ + def requestEntityEmpty: Directive0 = MiscDirectives._requestEntityEmpty + + /** + * Rejects empty requests with a RequestEntityExpectedRejection. + * Non-empty requests are passed on unchanged to the inner route. + */ + def requestEntityPresent: Directive0 = MiscDirectives._requestEntityPresent + + /** + * Converts responses with an empty entity into (empty) rejections. + * This way you can, for example, have the marshalling of a ''None'' option be treated as if the request could + * not be matched. + */ + def rejectEmptyResponse: Directive0 = MiscDirectives._rejectEmptyResponse +} + +object MiscDirectives extends MiscDirectives { + import BasicDirectives._ + import HeaderDirectives._ + import RouteDirectives._ + import RouteResult._ + + private val _extractClientIP: Directive1[RemoteAddress] = + headerValuePF { case `X-Forwarded-For`(Seq(address, _*)) ⇒ address } | + headerValuePF { case `Remote-Address`(address) ⇒ address } | + headerValuePF { case h if h.is("x-real-ip") ⇒ RemoteAddress(h.value) } + + private val _requestEntityEmpty: Directive0 = + extract(_.request.entity.isKnownEmpty).flatMap(if (_) pass else reject) + + private val _requestEntityPresent: Directive0 = + extract(_.request.entity.isKnownEmpty).flatMap(if (_) reject else pass) + + private val _rejectEmptyResponse: Directive0 = + mapRouteResult { + case Complete(response) if response.entity.isKnownEmpty ⇒ Rejected(Nil) + case x ⇒ x + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ParameterDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ParameterDirectives.scala new file mode 100644 index 0000000000..677fcf2538 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/ParameterDirectives.scala @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.collection.immutable +import scala.util.{ Failure, Success } +import akka.http.scaladsl.common._ +import akka.http.impl.util._ + +trait ParameterDirectives extends ToNameReceptacleEnhancements { + import ParameterDirectives._ + + /** + * Extracts the requests query parameters as a Map[String, String]. + */ + def parameterMap: Directive1[Map[String, String]] = _parameterMap + + /** + * Extracts the requests query parameters as a Map[String, List[String]]. + */ + def parameterMultiMap: Directive1[Map[String, List[String]]] = _parameterMultiMap + + /** + * Extracts the requests query parameters as a Seq[(String, String)]. + */ + def parameterSeq: Directive1[immutable.Seq[(String, String)]] = _parameterSeq + + /** + * Rejects the request if the defined query parameter matcher(s) don't match. + * Otherwise the parameter value(s) are extracted and passed to the inner route. + */ + def parameter(pdm: ParamMagnet): pdm.Out = pdm() + + /** + * Rejects the request if the defined query parameter matcher(s) don't match. + * Otherwise the parameter value(s) are extracted and passed to the inner route. + */ + def parameters(pdm: ParamMagnet): pdm.Out = pdm() + +} + +object ParameterDirectives extends ParameterDirectives { + import BasicDirectives._ + + private val _parameterMap: Directive1[Map[String, String]] = + extract(_.request.uri.query.toMap) + + private val _parameterMultiMap: Directive1[Map[String, List[String]]] = + extract(_.request.uri.query.toMultiMap) + + private val _parameterSeq: Directive1[immutable.Seq[(String, String)]] = + extract(_.request.uri.query.toSeq) + + sealed trait ParamMagnet { + type Out + def apply(): Out + } + object ParamMagnet { + implicit def apply[T](value: T)(implicit pdef: ParamDef[T]) = + new ParamMagnet { + type Out = pdef.Out + def apply() = pdef(value) + } + } + + sealed trait ParamDef[T] { + type Out + def apply(value: T): Out + } + object ParamDef { + def paramDef[A, B](f: A ⇒ B) = + new ParamDef[A] { + type Out = B + def apply(value: A) = f(value) + } + + import akka.http.scaladsl.unmarshalling.{ FromStringUnmarshaller ⇒ FSU, _ } + import BasicDirectives._ + import RouteDirectives._ + import FutureDirectives._ + type FSOU[T] = Unmarshaller[Option[String], T] + + //////////////////// "regular" parameter extraction ////////////////////// + + private def extractParameter[A, B](f: A ⇒ Directive1[B]) = paramDef(f) + private def filter[T](paramName: String, fsou: FSOU[T]): Directive1[T] = + extractRequestContext flatMap { ctx ⇒ + import ctx.executionContext + onComplete(fsou(ctx.request.uri.query get paramName)) flatMap { + case Success(x) ⇒ provide(x) + case Failure(Unmarshaller.NoContentException) ⇒ reject(MissingQueryParamRejection(paramName)) + case Failure(x) ⇒ reject(MalformedQueryParamRejection(paramName, x.getMessage.nullAsEmpty, Option(x.getCause))) + } + } + implicit def forString(implicit fsu: FSU[String]) = + extractParameter[String, String] { string ⇒ filter(string, fsu) } + implicit def forSymbol(implicit fsu: FSU[String]) = + extractParameter[Symbol, String] { symbol ⇒ filter(symbol.name, fsu) } + implicit def forNR[T](implicit fsu: FSU[T]) = + extractParameter[NameReceptacle[T], T] { nr ⇒ filter(nr.name, fsu) } + implicit def forNUR[T] = + extractParameter[NameUnmarshallerReceptacle[T], T] { nr ⇒ filter(nr.name, nr.um) } + implicit def forNOR[T](implicit fsou: FSOU[T]) = + extractParameter[NameOptionReceptacle[T], Option[T]] { nr ⇒ filter[Option[T]](nr.name, fsou) } + implicit def forNDR[T](implicit fsou: FSOU[T]) = + extractParameter[NameDefaultReceptacle[T], T] { nr ⇒ filter[T](nr.name, fsou withDefaultValue nr.default) } + implicit def forNOUR[T] = + extractParameter[NameOptionUnmarshallerReceptacle[T], Option[T]] { nr ⇒ filter(nr.name, nr.um: FSOU[T]) } + implicit def forNDUR[T] = + extractParameter[NameDefaultUnmarshallerReceptacle[T], T] { nr ⇒ filter[T](nr.name, (nr.um: FSOU[T]) withDefaultValue nr.default) } + + //////////////////// required parameter support //////////////////// + + private def requiredFilter[T](paramName: String, fsou: FSOU[T], requiredValue: Any): Directive0 = + extractRequestContext flatMap { ctx ⇒ + import ctx.executionContext + onComplete(fsou(ctx.request.uri.query get paramName)) flatMap { + case Success(value) if value == requiredValue ⇒ pass + case _ ⇒ reject + } + } + implicit def forRVR[T](implicit fsu: FSU[T]) = + paramDef[RequiredValueReceptacle[T], Directive0] { rvr ⇒ requiredFilter(rvr.name, fsu, rvr.requiredValue) } + implicit def forRVDR[T] = + paramDef[RequiredValueUnmarshallerReceptacle[T], Directive0] { rvr ⇒ requiredFilter(rvr.name, rvr.um, rvr.requiredValue) } + + //////////////////// tuple support //////////////////// + + import akka.http.scaladsl.server.util.TupleOps._ + import akka.http.scaladsl.server.util.BinaryPolyFunc + + implicit def forTuple[T](implicit fold: FoldLeft[Directive0, T, ConvertParamDefAndConcatenate.type]) = + paramDef[T, fold.Out](fold(BasicDirectives.pass, _)) + + object ConvertParamDefAndConcatenate extends BinaryPolyFunc { + implicit def from[P, TA, TB](implicit pdef: ParamDef[P] { type Out = Directive[TB] }, ev: Join[TA, TB]) = + at[Directive[TA], P] { (a, t) ⇒ a & pdef(t) } + } + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/PathDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/PathDirectives.scala new file mode 100644 index 0000000000..3c2896056e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/PathDirectives.scala @@ -0,0 +1,169 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.common.ToNameReceptacleEnhancements +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.Uri.Path + +trait PathDirectives extends PathMatchers with ImplicitPathMatcherConstruction with ToNameReceptacleEnhancements { + import BasicDirectives._ + import RouteDirectives._ + import PathMatcher._ + + /** + * Consumes a leading slash from the unmatched path of the [[akka.http.scaladsl.server.RequestContext]] + * before applying the given matcher. The matcher has to match the remaining path completely + * or leave only a single trailing slash. + * If matched the value extracted by the PathMatcher is extracted on the directive level. + */ + def path[L](pm: PathMatcher[L]): Directive[L] = pathPrefix(pm ~ PathEnd) + + /** + * Consumes a leading slash from the unmatched path of the [[akka.http.scaladsl.server.RequestContext]] + * before applying the given matcher. The matcher has to match a prefix of the remaining path. + * If matched the value extracted by the PathMatcher is extracted on the directive level. + */ + def pathPrefix[L](pm: PathMatcher[L]): Directive[L] = rawPathPrefix(Slash ~ pm) + + /** + * Applies the given matcher directly to the unmatched path of the [[akka.http.scaladsl.server.RequestContext]] + * (i.e. without implicitly consuming a leading slash). + * The matcher has to match a prefix of the remaining path. + * If matched the value extracted by the PathMatcher is extracted on the directive level. + */ + def rawPathPrefix[L](pm: PathMatcher[L]): Directive[L] = { + implicit def LIsTuple = pm.ev + extract(ctx ⇒ pm(ctx.unmatchedPath)).flatMap { + case Matched(rest, values) ⇒ tprovide(values) & mapRequestContext(_ withUnmatchedPath rest) + case Unmatched ⇒ reject + } + } + + /** + * Checks whether the unmatchedPath of the [[akka.http.scaladsl.server.RequestContext]] has a prefix matched by the + * given PathMatcher. In analogy to the `pathPrefix` directive a leading slash is implied. + */ + def pathPrefixTest[L](pm: PathMatcher[L]): Directive[L] = rawPathPrefixTest(Slash ~ pm) + + /** + * Checks whether the unmatchedPath of the [[akka.http.scaladsl.server.RequestContext]] has a prefix matched by the + * given PathMatcher. However, as opposed to the `pathPrefix` directive the matched path is not + * actually "consumed". + */ + def rawPathPrefixTest[L](pm: PathMatcher[L]): Directive[L] = { + implicit def LIsTuple = pm.ev + extract(ctx ⇒ pm(ctx.unmatchedPath)).flatMap { + case Matched(_, values) ⇒ tprovide(values) + case Unmatched ⇒ reject + } + } + + /** + * Rejects the request if the unmatchedPath of the [[akka.http.scaladsl.server.RequestContext]] does not have a suffix + * matched the given PathMatcher. If matched the value extracted by the PathMatcher is extracted + * and the matched parts of the path are consumed. + * Note that, for efficiency reasons, the given PathMatcher must match the desired suffix in reversed-segment + * order, i.e. `pathSuffix("baz" / "bar")` would match `/foo/bar/baz`! + */ + def pathSuffix[L](pm: PathMatcher[L]): Directive[L] = { + implicit def LIsTuple = pm.ev + extract(ctx ⇒ pm(ctx.unmatchedPath.reverse)).flatMap { + case Matched(rest, values) ⇒ tprovide(values) & mapRequestContext(_.withUnmatchedPath(rest.reverse)) + case Unmatched ⇒ reject + } + } + + /** + * Checks whether the unmatchedPath of the [[akka.http.scaladsl.server.RequestContext]] has a suffix matched by the + * given PathMatcher. However, as opposed to the pathSuffix directive the matched path is not + * actually "consumed". + * Note that, for efficiency reasons, the given PathMatcher must match the desired suffix in reversed-segment + * order, i.e. `pathSuffixTest("baz" / "bar")` would match `/foo/bar/baz`! + */ + def pathSuffixTest[L](pm: PathMatcher[L]): Directive[L] = { + implicit def LIsTuple = pm.ev + extract(ctx ⇒ pm(ctx.unmatchedPath.reverse)).flatMap { + case Matched(_, values) ⇒ tprovide(values) + case Unmatched ⇒ reject + } + } + + /** + * Rejects the request if the unmatchedPath of the [[akka.http.scaladsl.server.RequestContext]] is non-empty, + * or said differently: only passes on the request to its inner route if the request path + * has been matched completely. + */ + def pathEnd: Directive0 = rawPathPrefix(PathEnd) + + /** + * Only passes on the request to its inner route if the request path has been matched + * completely or only consists of exactly one remaining slash. + * + * Note that trailing slash and non-trailing slash URLs are '''not''' the same, although they often serve + * the same content. It is recommended to serve only one URL version and make the other redirect to it using + * [[redirectToTrailingSlashIfMissing]] or [[redirectToNoTrailingSlashIfPresent]] directive. + * + * For example: + * {{{ + * def route = { + * // redirect '/users/' to '/users', '/users/:userId/' to '/users/:userId' + * redirectToNoTrailingSlashIfPresent(Found) { + * pathPrefix("users") { + * pathEnd { + * // user list ... + * } ~ + * path(UUID) { userId => + * // user profile ... + * } + * } + * } + * } + * }}} + * + * For further information, refer to: + * [[http://googlewebmastercentral.blogspot.de/2010/04/to-slash-or-not-to-slash.html]] + */ + def pathEndOrSingleSlash: Directive0 = rawPathPrefix(Slash.? ~ PathEnd) + + /** + * Only passes on the request to its inner route if the request path + * consists of exactly one remaining slash. + */ + def pathSingleSlash: Directive0 = pathPrefix(PathEnd) + + /** + * If the request path doesn't end with a slash, redirect to the same uri with trailing slash in the path. + * + * '''Caveat''': [[path]] without trailing slash and [[pathEnd]] directives will not match inside of this directive. + */ + def redirectToTrailingSlashIfMissing(redirectionType: StatusCodes.Redirection): Directive0 = + extractUri.flatMap { uri ⇒ + if (uri.path.endsWithSlash) pass + else { + val newPath = uri.path ++ Path.SingleSlash + val newUri = uri.withPath(newPath) + redirect(newUri, redirectionType) + } + } + + /** + * If the request path ends with a slash, redirect to the same uri without trailing slash in the path. + * + * '''Caveat''': [[pathSingleSlash]] directive will not match inside of this directive. + */ + def redirectToNoTrailingSlashIfPresent(redirectionType: StatusCodes.Redirection): Directive0 = + extractUri.flatMap { uri ⇒ + if (uri.path.endsWithSlash) { + val newPath = uri.path.reverse.tail.reverse + val newUri = uri.withPath(newPath) + redirect(newUri, redirectionType) + } else pass + } + +} + +object PathDirectives extends PathDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala new file mode 100644 index 0000000000..a210942055 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RangeDirectives.scala @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.StatusCodes._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.RouteResult.Complete +import akka.http.impl.util._ +import akka.stream.scaladsl.Source + +import scala.collection.immutable + +trait RangeDirectives { + import akka.http.scaladsl.server.directives.BasicDirectives._ + import akka.http.scaladsl.server.directives.RouteDirectives._ + + /** + * Answers GET requests with an `Accept-Ranges: bytes` header and converts HttpResponses coming back from its inner + * route into partial responses if the initial request contained a valid `Range` request header. The requested + * byte-ranges may be coalesced. + * This directive is transparent to non-GET requests + * Rejects requests with unsatisfiable ranges `UnsatisfiableRangeRejection`. + * Rejects requests with too many expected ranges. + * + * Note: if you want to combine this directive with `conditional(...)` you need to put + * it on the *inside* of the `conditional(...)` directive, i.e. `conditional(...)` must be + * on a higher level in your route structure in order to function correctly. + * + * @see https://tools.ietf.org/html/rfc7233 + */ + def withRangeSupport: Directive0 = + extractRequestContext.flatMap { ctx ⇒ + import ctx.flowMaterializer + val settings = ctx.settings + implicit val log = ctx.log + import settings.{ rangeCountLimit, rangeCoalescingThreshold } + + class IndexRange(val start: Long, val end: Long) { + def length = end - start + def apply(entity: UniversalEntity): UniversalEntity = entity.transformDataBytes(length, StreamUtils.sliceBytesTransformer(start, length)) + def distance(other: IndexRange) = mergedEnd(other) - mergedStart(other) - (length + other.length) + def mergeWith(other: IndexRange) = new IndexRange(mergedStart(other), mergedEnd(other)) + def contentRange(entityLength: Long) = ContentRange(start, end - 1, entityLength) + private def mergedStart(other: IndexRange) = math.min(start, other.start) + private def mergedEnd(other: IndexRange) = math.max(end, other.end) + } + + def indexRange(entityLength: Long)(range: ByteRange): IndexRange = + range match { + case ByteRange.Slice(start, end) ⇒ new IndexRange(start, math.min(end + 1, entityLength)) + case ByteRange.FromOffset(first) ⇒ new IndexRange(first, entityLength) + case ByteRange.Suffix(suffixLength) ⇒ new IndexRange(math.max(0, entityLength - suffixLength), entityLength) + } + + // See comment of the `range-coalescing-threshold` setting in `reference.conf` for the rationale of this behavior. + def coalesceRanges(iRanges: Seq[IndexRange]): Seq[IndexRange] = + iRanges.foldLeft(Seq.empty[IndexRange]) { (acc, iRange) ⇒ + val (mergeCandidates, otherCandidates) = acc.partition(_.distance(iRange) <= rangeCoalescingThreshold) + val merged = mergeCandidates.foldLeft(iRange)(_ mergeWith _) + otherCandidates :+ merged + } + + def multipartRanges(ranges: Seq[ByteRange], entity: UniversalEntity): Multipart.ByteRanges = { + val length = entity.contentLength + val iRanges: Seq[IndexRange] = ranges.map(indexRange(length)) + + // It's only possible to run once over the input entity data stream because it's not known if the + // source is reusable. + // Therefore, ranges need to be sorted to prevent that some selected ranges already start to accumulate data + // but cannot be sent out because another range is blocking the queue. + val coalescedRanges = coalesceRanges(iRanges).sortBy(_.start) + val bodyPartTransformers = coalescedRanges.map(ir ⇒ StreamUtils.sliceBytesTransformer(ir.start, ir.length)).toVector + val bodyPartByteStreams = StreamUtils.transformMultiple(entity.dataBytes, bodyPartTransformers) + val bodyParts = (coalescedRanges, bodyPartByteStreams).zipped.map { (range, bytes) ⇒ + Multipart.ByteRanges.BodyPart(range.contentRange(length), HttpEntity(entity.contentType, range.length, bytes)) + } + Multipart.ByteRanges(Source(bodyParts.toVector)) + } + + def rangeResponse(range: ByteRange, entity: UniversalEntity, length: Long, headers: immutable.Seq[HttpHeader]) = { + val aiRange = indexRange(length)(range) + HttpResponse(PartialContent, `Content-Range`(aiRange.contentRange(length)) +: headers, aiRange(entity)) + } + + def satisfiable(entityLength: Long)(range: ByteRange): Boolean = + range match { + case ByteRange.Slice(firstPos, _) ⇒ firstPos < entityLength + case ByteRange.FromOffset(firstPos) ⇒ firstPos < entityLength + case ByteRange.Suffix(length) ⇒ length > 0 + } + def universal(entity: HttpEntity): Option[UniversalEntity] = entity match { + case u: UniversalEntity ⇒ Some(u) + case _ ⇒ None + } + + def applyRanges(ranges: Seq[ByteRange]): Directive0 = + extractRequestContext.flatMap { ctx ⇒ + mapRouteResultWithPF { + case Complete(HttpResponse(OK, headers, entity, protocol)) ⇒ + universal(entity) match { + case Some(entity) ⇒ + val length = entity.contentLength + ranges.filter(satisfiable(length)) match { + case Nil ⇒ ctx.reject(UnsatisfiableRangeRejection(ranges, length)) + case Seq(satisfiableRange) ⇒ ctx.complete(rangeResponse(satisfiableRange, entity, length, headers)) + case satisfiableRanges ⇒ + ctx.complete(PartialContent, headers, multipartRanges(satisfiableRanges, entity)) + } + case None ⇒ + // Ranges not supported for Chunked or CloseDelimited responses + ctx.reject(UnsatisfiableRangeRejection(ranges, -1)) // FIXME: provide better error + } + } + } + + def rangeHeaderOfGetRequests(ctx: RequestContext): Option[Range] = + if (ctx.request.method == HttpMethods.GET) ctx.request.header[Range] else None + + extract(rangeHeaderOfGetRequests).flatMap { + case Some(Range(RangeUnits.Bytes, ranges)) ⇒ + if (ranges.size <= rangeCountLimit) applyRanges(ranges) & RangeDirectives.respondWithAcceptByteRangesHeader + else reject(TooManyRangesRejection(rangeCountLimit)) + case _ ⇒ MethodDirectives.get & RangeDirectives.respondWithAcceptByteRangesHeader | pass + } + } +} + +object RangeDirectives extends RangeDirectives { + private val respondWithAcceptByteRangesHeader: Directive0 = + RespondWithDirectives.respondWithHeader(`Accept-Ranges`(RangeUnits.Bytes)) +} + diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RespondWithDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RespondWithDirectives.scala new file mode 100644 index 0000000000..12557ad8fd --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RespondWithDirectives.scala @@ -0,0 +1,53 @@ +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model._ +import scala.collection.immutable + +trait RespondWithDirectives { + import BasicDirectives._ + + /** + * Overrides the given response status on all HTTP responses of its inner Route. + */ + def overrideStatusCode(responseStatus: StatusCode): Directive0 = + mapResponse(_.copy(status = responseStatus)) + + /** + * Unconditionally adds the given response header to all HTTP responses of its inner Route. + */ + def respondWithHeader(responseHeader: HttpHeader): Directive0 = respondWithHeaders(responseHeader) + + /** + * Adds the given response header to all HTTP responses of its inner Route, + * if the response from the inner Route doesn't already contain a header with the same name. + */ + def respondWithDefaultHeader(responseHeader: HttpHeader): Directive0 = respondWithDefaultHeaders(responseHeader) + + /** + * Unconditionally adds the given response headers to all HTTP responses of its inner Route. + */ + def respondWithHeaders(responseHeaders: HttpHeader*): Directive0 = + respondWithHeaders(responseHeaders.toList) + + /** + * Unconditionally adds the given response headers to all HTTP responses of its inner Route. + */ + def respondWithHeaders(responseHeaders: immutable.Seq[HttpHeader]): Directive0 = + mapResponseHeaders(responseHeaders ++ _) + + /** + * Adds the given response headers to all HTTP responses of its inner Route, + * if a header already exists it is not added again. + */ + def respondWithDefaultHeaders(responseHeaders: HttpHeader*): Directive0 = + respondWithDefaultHeaders(responseHeaders.toList) + + /* Adds the given response headers to all HTTP responses of its inner Route, + * if a header already exists it is not added again. + */ + def respondWithDefaultHeaders(responseHeaders: immutable.Seq[HttpHeader]): Directive0 = + mapResponse(_.withDefaultHeaders(responseHeaders)) +} + +object RespondWithDirectives extends RespondWithDirectives diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RouteDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RouteDirectives.scala new file mode 100644 index 0000000000..17dddac09c --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/RouteDirectives.scala @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.marshalling.ToResponseMarshallable +import akka.http.scaladsl.model._ +import StatusCodes._ + +trait RouteDirectives { + + /** + * Rejects the request with an empty set of rejections. + */ + def reject: StandardRoute = RouteDirectives._reject + + /** + * Rejects the request with the given rejections. + */ + def reject(rejections: Rejection*): StandardRoute = + StandardRoute(_.reject(rejections: _*)) + + /** + * Completes the request with redirection response of the given type to the given URI. + */ + def redirect(uri: Uri, redirectionType: Redirection): StandardRoute = + StandardRoute { + _. //# red-impl + complete { + HttpResponse( + status = redirectionType, + headers = headers.Location(uri) :: Nil, + entity = redirectionType.htmlTemplate match { + case "" ⇒ HttpEntity.Empty + case template ⇒ HttpEntity(MediaTypes.`text/html`, template format uri) + }) + } + //# + } + + /** + * Completes the request using the given arguments. + */ + def complete(m: ⇒ ToResponseMarshallable): StandardRoute = + StandardRoute(_.complete(m)) + + /** + * Bubbles the given error up the response chain, where it is dealt with by the closest `handleExceptions` + * directive and its ExceptionHandler. + */ + def failWith(error: Throwable): StandardRoute = + StandardRoute(_.fail(error)) +} + +object RouteDirectives extends RouteDirectives { + private val _reject = StandardRoute(_.reject()) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SchemeDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SchemeDirectives.scala new file mode 100644 index 0000000000..33b3c5b46e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SchemeDirectives.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +trait SchemeDirectives { + import BasicDirectives._ + + /** + * Extracts the Uri scheme from the request. + */ + def extractScheme: Directive1[String] = SchemeDirectives._extractScheme + + /** + * Rejects all requests whose Uri scheme does not match the given one. + */ + def scheme(name: String): Directive0 = + extractScheme.require(_ == name, SchemeRejection(name)) & cancelRejections(classOf[SchemeRejection]) +} + +object SchemeDirectives extends SchemeDirectives { + import BasicDirectives._ + + private val _extractScheme: Directive1[String] = extract(_.request.uri.scheme) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SecurityDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SecurityDirectives.scala new file mode 100644 index 0000000000..a15fa095b8 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/SecurityDirectives.scala @@ -0,0 +1,187 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import scala.reflect.ClassTag +import scala.concurrent.Future +import akka.http.impl.util._ +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.AuthenticationFailedRejection.{ CredentialsRejected, CredentialsMissing } + +/** + * Provides directives for securing an inner route using the standard Http authentication headers [[`WWW-Authenticate`]] + * and [[Authorization]]. Most prominently, HTTP Basic authentication as defined in RFC 2617. + */ +trait SecurityDirectives { + import BasicDirectives._ + import HeaderDirectives._ + import FutureDirectives._ + import RouteDirectives._ + + /** + * The result of an HTTP authentication attempt is either the user object or + * an HttpChallenge to present to the browser. + */ + type AuthenticationResult[+T] = Either[HttpChallenge, T] + + type Authenticator[T] = UserCredentials ⇒ Option[T] + type AsyncAuthenticator[T] = UserCredentials ⇒ Future[Option[T]] + type AuthenticatorPF[T] = PartialFunction[UserCredentials, T] + type AsyncAuthenticatorPF[T] = PartialFunction[UserCredentials, Future[T]] + + /** + * Extracts the potentially present [[HttpCredentials]] provided with the request's [[Authorization]] header. + */ + def extractCredentials: Directive1[Option[HttpCredentials]] = + optionalHeaderValueByType[Authorization]().map(_.map(_.credentials)) + + /** + * A directive that wraps the inner route with Http Basic authentication support. + * The given authenticator determines whether the credentials in the request are valid + * and, if so, which user object to supply to the inner route. + */ + def authenticateBasic[T](realm: String, authenticator: Authenticator[T]): AuthenticationDirective[T] = + authenticateBasicAsync(realm, cred ⇒ FastFuture.successful(authenticator(cred))) + + /** + * A directive that wraps the inner route with Http Basic authentication support. + * The given authenticator determines whether the credentials in the request are valid + * and, if so, which user object to supply to the inner route. + */ + def authenticateBasicAsync[T](realm: String, authenticator: AsyncAuthenticator[T]): AuthenticationDirective[T] = + extractExecutionContext.flatMap { implicit ec ⇒ + authenticateOrRejectWithChallenge[BasicHttpCredentials, T] { basic ⇒ + authenticator(UserCredentials(basic)).fast.map { + case Some(t) ⇒ AuthenticationResult.success(t) + case None ⇒ AuthenticationResult.failWithChallenge(challengeFor(realm)) + } + } + } + + /** + * A directive that wraps the inner route with Http Basic authentication support. + * The given authenticator determines whether the credentials in the request are valid + * and, if so, which user object to supply to the inner route. + */ + def authenticateBasicPF[T](realm: String, authenticator: AuthenticatorPF[T]): AuthenticationDirective[T] = + authenticateBasic(realm, authenticator.lift) + + /** + * A directive that wraps the inner route with Http Basic authentication support. + * The given authenticator determines whether the credentials in the request are valid + * and, if so, which user object to supply to the inner route. + */ + def authenticateBasicPFAsync[T](realm: String, authenticator: AsyncAuthenticatorPF[T]): AuthenticationDirective[T] = + extractExecutionContext.flatMap { implicit ec ⇒ + authenticateBasicAsync(realm, credentials ⇒ + if (authenticator isDefinedAt credentials) authenticator(credentials).fast.map(Some(_)) + else FastFuture.successful(None)) + } + + /** + * Lifts an authenticator function into a directive. The authenticator function gets passed in credentials from the + * [[Authorization]] header of the request. If the function returns ``Right(user)`` the user object is provided + * to the inner route. If the function returns ``Left(challenge)`` the request is rejected with an + * [[AuthenticationFailedRejection]] that contains this challenge to be added to the response. + * + */ + def authenticateOrRejectWithChallenge[T](authenticator: Option[HttpCredentials] ⇒ Future[AuthenticationResult[T]]): AuthenticationDirective[T] = + extractExecutionContext.flatMap { implicit ec ⇒ + extractCredentials.flatMap { cred ⇒ + onSuccess(authenticator(cred)).flatMap { + case Right(user) ⇒ provide(user) + case Left(challenge) ⇒ + val cause = if (cred.isEmpty) CredentialsMissing else CredentialsRejected + reject(AuthenticationFailedRejection(cause, challenge)): Directive1[T] + } + } + } + + /** + * Lifts an authenticator function into a directive. Same as ``authenticateOrRejectWithChallenge`` + * but only applies the authenticator function with a certain type of credentials. + */ + def authenticateOrRejectWithChallenge[C <: HttpCredentials: ClassTag, T]( + authenticator: Option[C] ⇒ Future[AuthenticationResult[T]]): AuthenticationDirective[T] = + authenticateOrRejectWithChallenge[T](cred ⇒ authenticator(cred collect { case c: C ⇒ c })) + + /** + * Applies the given authorization check to the request. + * If the check fails the route is rejected with an [[AuthorizationFailedRejection]]. + */ + def authorize(check: ⇒ Boolean): Directive0 = authorize(_ ⇒ check) + + /** + * Applies the given authorization check to the request. + * If the check fails the route is rejected with an [[AuthorizationFailedRejection]]. + */ + def authorize(check: RequestContext ⇒ Boolean): Directive0 = + extract(check).flatMap[Unit](if (_) pass else reject(AuthorizationFailedRejection)) & + cancelRejection(AuthorizationFailedRejection) + + /** + * Creates a ``Basic`` [[HttpChallenge]] for the given realm. + */ + def challengeFor(realm: String) = HttpChallenge(scheme = "Basic", realm = realm, params = Map.empty) +} + +object SecurityDirectives extends SecurityDirectives + +/** + * Represents authentication credentials supplied with a request. Credentials can either be + * [[UserCredentials.Missing]] or can be [[UserCredentials.Provided]] in which case a username is + * supplied and a function to check the known secret against the provided one in a secure fashion. + */ +sealed trait UserCredentials +object UserCredentials { + case object Missing extends UserCredentials + abstract case class Provided(username: String) extends UserCredentials { + def verifySecret(secret: String): Boolean + } + + def apply(cred: Option[BasicHttpCredentials]): UserCredentials = + cred match { + case Some(BasicHttpCredentials(username, receivedSecret)) ⇒ + new UserCredentials.Provided(username) { + def verifySecret(secret: String): Boolean = secret secure_== receivedSecret + } + case None ⇒ UserCredentials.Missing + } +} + +import SecurityDirectives._ + +object AuthenticationResult { + def success[T](user: T): AuthenticationResult[T] = Right(user) + def failWithChallenge(challenge: HttpChallenge): AuthenticationResult[Nothing] = Left(challenge) +} + +trait AuthenticationDirective[T] extends Directive1[T] { + import BasicDirectives._ + import RouteDirectives._ + + /** + * Returns a copy of this [[AuthenticationDirective]] that will provide ``Some(user)`` if credentials + * were supplied and otherwise ``None``. + */ + def optional: Directive1[Option[T]] = + this.map(Some(_): Option[T]) recover { + case AuthenticationFailedRejection(CredentialsMissing, _) +: _ ⇒ provide(None) + case rejs ⇒ reject(rejs: _*) + } + + /** + * Returns a copy of this [[AuthenticationDirective]] that uses the given object as the + * anonymous user which will be used if no credentials were supplied in the request. + */ + def withAnonymousUser(anonymous: T): Directive1[T] = optional map (_ getOrElse anonymous) +} +object AuthenticationDirective { + implicit def apply[T](other: Directive1[T]): AuthenticationDirective[T] = + new AuthenticationDirective[T] { def tapply(inner: Tuple1[T] ⇒ Route) = other.tapply(inner) } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala new file mode 100644 index 0000000000..ab02d28611 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.scaladsl.model.ws.{ UpgradeToWebsocket, Message } +import akka.stream.scaladsl.Flow + +trait WebsocketDirectives { + import BasicDirectives._ + import RouteDirectives._ + import HeaderDirectives._ + + /** + * Handles websocket requests with the given handler and rejects other requests with a + * [[ExpectedWebsocketRequestRejection]]. + */ + def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route = + extractFlowMaterializer { implicit mat ⇒ + optionalHeaderValueByType[UpgradeToWebsocket]() { + case Some(upgrade) ⇒ complete(upgrade.handleMessages(handler)) + case None ⇒ reject(ExpectedWebsocketRequestRejection) + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/package.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/package.scala new file mode 100644 index 0000000000..a28e0f9565 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/package.scala @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl + +import scala.concurrent.Future + +package object server { + + type Route = RequestContext ⇒ Future[RouteResult] + type RouteGenerator[T] = T ⇒ Route + type Directive0 = Directive[Unit] + type Directive1[T] = Directive[Tuple1[T]] + type PathMatcher0 = PathMatcher[Unit] + type PathMatcher1[T] = PathMatcher[Tuple1[T]] + + def FIXME = throw new RuntimeException("Not yet implemented") +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/ApplyConverter.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ApplyConverter.scala new file mode 100644 index 0000000000..36ae741324 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ApplyConverter.scala @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import akka.http.scaladsl.server._ + +/** + * ApplyConverter allows generic conversion of functions of type `(T1, T2, ...) => Route` to + * `(TupleX(T1, T2, ...)) => Route`. + */ +abstract class ApplyConverter[L] { + type In + def apply(f: In): L ⇒ Route +} + +object ApplyConverter extends ApplyConverterInstances \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/BinaryPolyFunc.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/BinaryPolyFunc.scala new file mode 100644 index 0000000000..851e85912d --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/BinaryPolyFunc.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +/** + * Allows the definition of binary poly-functions (e.g. for folding over tuples). + * + * Note: the poly-function implementation seen here is merely a stripped down version of + * what Miles Sabin made available with his awesome shapeless library. All credit goes to him! + */ +trait BinaryPolyFunc { + def at[A, B] = new CaseBuilder[A, B] + class CaseBuilder[A, B] { + def apply[R](f: (A, B) ⇒ R) = new BinaryPolyFunc.Case[A, B, BinaryPolyFunc.this.type] { + type Out = R + def apply(a: A, b: B) = f(a, b) + } + } +} + +object BinaryPolyFunc { + sealed trait Case[A, B, Op] { + type Out + def apply(a: A, b: B): Out + } +} + diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/ClassMagnet.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ClassMagnet.scala new file mode 100644 index 0000000000..ec53ec0fd8 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ClassMagnet.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +import scala.reflect.ClassTag + +/** A magnet that wraps a ClassTag */ +trait ClassMagnet[T] { + def classTag: ClassTag[T] + def runtimeClass: Class[T] + + /** + * Returns a partial function that checks if the input value is of runtime type + * T and returns the value if it does. Doesn't take erased information into account. + */ + def extractPF: PartialFunction[Any, T] +} +object ClassMagnet { + implicit def apply[T](c: Class[T]): ClassMagnet[T] = ClassMagnet()(ClassTag(c)) + + implicit def apply[T](u: Unit)(implicit tag: ClassTag[T]): ClassMagnet[T] = + new ClassMagnet[T] { + val classTag: ClassTag[T] = tag + val runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]] + val extractPF: PartialFunction[Any, T] = { + case x: T ⇒ x + } + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/ConstructFromTuple.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ConstructFromTuple.scala new file mode 100644 index 0000000000..02696bcee7 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/ConstructFromTuple.scala @@ -0,0 +1,12 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +/** + * Constructor for instances of type ``R`` which can be created from a tuple of type ``T``. + */ +trait ConstructFromTuple[T, R] extends (T ⇒ R) + +object ConstructFromTuple extends ConstructFromTupleInstances diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tuple.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tuple.scala new file mode 100644 index 0000000000..7ab96e8886 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tuple.scala @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +/** + * Phantom type providing implicit evidence that a given type is a Tuple or Unit. + */ +sealed trait Tuple[T] + +object Tuple { + /** + * Used to provide "is-Tuple" evidence where we know that a given value must be a tuple. + */ + def yes[T]: Tuple[T] = null + + implicit def forNothing[A]: Tuple[Nothing] = null + implicit def forUnit[A]: Tuple[Unit] = null + implicit def forTuple1[A]: Tuple[Tuple1[A]] = null + implicit def forTuple2[A, B]: Tuple[(A, B)] = null + implicit def forTuple3[A, B, C]: Tuple[(A, B, C)] = null + implicit def forTuple4[A, B, C, D]: Tuple[(A, B, C, D)] = null + implicit def forTuple5[A, B, C, D, E]: Tuple[(A, B, C, D, E)] = null + implicit def forTuple6[A, B, C, D, E, F]: Tuple[(A, B, C, D, E, F)] = null + implicit def forTuple7[A, B, C, D, E, F, G]: Tuple[(A, B, C, D, E, F, G)] = null + implicit def forTuple8[A, B, C, D, E, F, G, H]: Tuple[(A, B, C, D, E, F, G, H)] = null + implicit def forTuple9[A, B, C, D, E, F, G, H, I]: Tuple[(A, B, C, D, E, F, G, H, I)] = null + implicit def forTuple10[A, B, C, D, E, F, G, H, I, J]: Tuple[(A, B, C, D, E, F, G, H, I, J)] = null + implicit def forTuple11[A, B, C, D, E, F, G, H, I, J, K]: Tuple[(A, B, C, D, E, F, G, H, I, J, K)] = null + implicit def forTuple12[A, B, C, D, E, F, G, H, I, J, K, L]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L)] = null + implicit def forTuple13[A, B, C, D, E, F, G, H, I, J, K, L, M]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M)] = null + implicit def forTuple14[A, B, C, D, E, F, G, H, I, J, K, L, M, N]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N)] = null + implicit def forTuple15[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)] = null + implicit def forTuple16[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)] = null + implicit def forTuple17[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q)] = null + implicit def forTuple18[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R)] = null + implicit def forTuple19[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S)] = null + implicit def forTuple20[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T)] = null + implicit def forTuple21[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U)] = null + implicit def forTuple22[A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V]: Tuple[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V)] = null +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/TupleOps.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/TupleOps.scala new file mode 100644 index 0000000000..2df524f102 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/TupleOps.scala @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +class TupleOps[T](val tuple: T) extends AnyVal { + import TupleOps._ + + /** + * Appends the given value to the tuple producing a tuple of arity n + 1. + */ + def append[S](value: S)(implicit ao: AppendOne[T, S]): ao.Out = ao(tuple, value) + + /** + * Left-Folds over the tuple using the given binary poly-function. + */ + def foldLeft[In](zero: In)(op: BinaryPolyFunc)(implicit fold: FoldLeft[In, T, op.type]): fold.Out = fold(zero, tuple) + + /** + * Appends the given tuple to the underlying tuple producing a tuple of arity n + m. + */ + def join[S](suffixTuple: S)(implicit join: Join[T, S]): join.Out = join(tuple, suffixTuple) +} + +object TupleOps { + implicit def enhanceTuple[T: Tuple](tuple: T) = new TupleOps(tuple) + + trait AppendOne[P, S] { + type Out + def apply(prefix: P, last: S): Out + } + object AppendOne extends TupleAppendOneInstances + + trait FoldLeft[In, T, Op] { + type Out + def apply(zero: In, tuple: T): Out + } + object FoldLeft extends TupleFoldInstances + + trait Join[P, S] { + type Out + def apply(prefix: P, suffix: S): Out + } + object Join extends LowLevelJoinImplicits { + // O(1) shortcut for the Join[Unit, T] case to avoid O(n) runtime in this case + implicit def join0P[T] = + new Join[Unit, T] { + type Out = T + def apply(prefix: Unit, suffix: T): Out = suffix + } + // we implement the join by folding over the suffix with the prefix as growing accumulator + object Fold extends BinaryPolyFunc { + implicit def step[T, A](implicit append: AppendOne[T, A]) = at[T, A](append(_, _)) + } + } + sealed abstract class LowLevelJoinImplicits { + implicit def join[P, S](implicit fold: FoldLeft[P, S, Join.Fold.type]) = + new Join[P, S] { + type Out = fold.Out + def apply(prefix: P, suffix: S): Out = fold(prefix, suffix) + } + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tupler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tupler.scala new file mode 100644 index 0000000000..6a1e932fea --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/util/Tupler.scala @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.server.util + +/** + * Provides a way to convert a value into an Tuple. + * If the value is already a Tuple then it is returned unchanged, otherwise it's wrapped in a Tuple1 instance. + */ +trait Tupler[T] { + type Out + def OutIsTuple: Tuple[Out] + def apply(value: T): Out +} + +object Tupler extends LowerPriorityTupler { + implicit def forTuple[T: Tuple] = + new Tupler[T] { + type Out = T + def OutIsTuple = implicitly[Tuple[Out]] + def apply(value: T) = value + } +} + +private[server] abstract class LowerPriorityTupler { + implicit def forAnyRef[T] = + new Tupler[T] { + type Out = Tuple1[T] + def OutIsTuple = implicitly[Tuple[Out]] + def apply(value: T) = Tuple1(value) + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/GenericUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/GenericUnmarshallers.scala new file mode 100644 index 0000000000..c129410640 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/GenericUnmarshallers.scala @@ -0,0 +1,29 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import akka.http.scaladsl.util.FastFuture + +trait GenericUnmarshallers extends LowerPriorityGenericUnmarshallers { + + implicit def liftToTargetOptionUnmarshaller[A, B](um: Unmarshaller[A, B]): Unmarshaller[A, Option[B]] = + targetOptionUnmarshaller(um) + implicit def targetOptionUnmarshaller[A, B](implicit um: Unmarshaller[A, B]): Unmarshaller[A, Option[B]] = + um map (Some(_)) withDefaultValue None +} + +sealed trait LowerPriorityGenericUnmarshallers { + + implicit def messageUnmarshallerFromEntityUnmarshaller[T](implicit um: FromEntityUnmarshaller[T]): FromMessageUnmarshaller[T] = + Unmarshaller { implicit ec ⇒ request ⇒ um(request.entity) } + + implicit def liftToSourceOptionUnmarshaller[A, B](um: Unmarshaller[A, B]): Unmarshaller[Option[A], B] = + sourceOptionUnmarshaller(um) + implicit def sourceOptionUnmarshaller[A, B](implicit um: Unmarshaller[A, B]): Unmarshaller[Option[A], B] = + Unmarshaller(implicit ec ⇒ { + case Some(a) ⇒ um(a) + case None ⇒ FastFuture.failed(Unmarshaller.NoContentException) + }) +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala new file mode 100644 index 0000000000..e5a6dfe317 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/MultipartUnmarshallers.scala @@ -0,0 +1,105 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import scala.collection.immutable +import scala.collection.immutable.VectorBuilder +import scala.concurrent.ExecutionContext +import akka.util.ByteString +import akka.event.{ NoLogging, LoggingAdapter } +import akka.stream.impl.fusing.IteratorInterpreter +import akka.stream.scaladsl._ +import akka.http.impl.engine.parsing.BodyPartParser +import akka.http.impl.util._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.util.FastFuture +import MediaRanges._ +import MediaTypes._ +import HttpCharsets._ + +trait MultipartUnmarshallers { + + implicit def defaultMultipartGeneralUnmarshaller(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[Multipart.General] = + multipartGeneralUnmarshaller(`UTF-8`) + def multipartGeneralUnmarshaller(defaultCharset: HttpCharset)(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[Multipart.General] = + multipartUnmarshaller[Multipart.General, Multipart.General.BodyPart, Multipart.General.BodyPart.Strict]( + mediaRange = `multipart/*`, + defaultContentType = ContentTypes.`text/plain` withCharset defaultCharset, + createBodyPart = Multipart.General.BodyPart(_, _), + createStreamed = Multipart.General(_, _), + createStrictBodyPart = Multipart.General.BodyPart.Strict, + createStrict = Multipart.General.Strict) + + implicit def multipartFormDataUnmarshaller(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[Multipart.FormData] = + multipartUnmarshaller[Multipart.FormData, Multipart.FormData.BodyPart, Multipart.FormData.BodyPart.Strict]( + mediaRange = `multipart/form-data`, + defaultContentType = ContentTypes.`application/octet-stream`, + createBodyPart = (entity, headers) ⇒ Multipart.General.BodyPart(entity, headers).toFormDataBodyPart.get, + createStreamed = (_, parts) ⇒ Multipart.FormData(parts), + createStrictBodyPart = (entity, headers) ⇒ Multipart.General.BodyPart.Strict(entity, headers).toFormDataBodyPart.get, + createStrict = (_, parts) ⇒ Multipart.FormData.Strict(parts)) + + implicit def defaultMultipartByteRangesUnmarshaller(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[Multipart.ByteRanges] = + multipartByteRangesUnmarshaller(`UTF-8`) + def multipartByteRangesUnmarshaller(defaultCharset: HttpCharset)(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[Multipart.ByteRanges] = + multipartUnmarshaller[Multipart.ByteRanges, Multipart.ByteRanges.BodyPart, Multipart.ByteRanges.BodyPart.Strict]( + mediaRange = `multipart/byteranges`, + defaultContentType = ContentTypes.`text/plain` withCharset defaultCharset, + createBodyPart = (entity, headers) ⇒ Multipart.General.BodyPart(entity, headers).toByteRangesBodyPart.get, + createStreamed = (_, parts) ⇒ Multipart.ByteRanges(parts), + createStrictBodyPart = (entity, headers) ⇒ Multipart.General.BodyPart.Strict(entity, headers).toByteRangesBodyPart.get, + createStrict = (_, parts) ⇒ Multipart.ByteRanges.Strict(parts)) + + def multipartUnmarshaller[T <: Multipart, BP <: Multipart.BodyPart, BPS <: Multipart.BodyPart.Strict](mediaRange: MediaRange, + defaultContentType: ContentType, + createBodyPart: (BodyPartEntity, List[HttpHeader]) ⇒ BP, + createStreamed: (MultipartMediaType, Source[BP, Any]) ⇒ T, + createStrictBodyPart: (HttpEntity.Strict, List[HttpHeader]) ⇒ BPS, + createStrict: (MultipartMediaType, immutable.Seq[BPS]) ⇒ T)(implicit ec: ExecutionContext, log: LoggingAdapter = NoLogging): FromEntityUnmarshaller[T] = + Unmarshaller { implicit ec ⇒ + entity ⇒ + if (entity.contentType.mediaType.isMultipart && mediaRange.matches(entity.contentType.mediaType)) { + entity.contentType.mediaType.params.get("boundary") match { + case None ⇒ + FastFuture.failed(new RuntimeException("Content-Type with a multipart media type must have a 'boundary' parameter")) + case Some(boundary) ⇒ + import BodyPartParser._ + val parser = new BodyPartParser(defaultContentType, boundary, log) + FastFuture.successful { + entity match { + case HttpEntity.Strict(ContentType(mediaType: MultipartMediaType, _), data) ⇒ + val builder = new VectorBuilder[BPS]() + val iter = new IteratorInterpreter[ByteString, BodyPartParser.Output]( + Iterator.single(data), List(parser)).iterator + // note that iter.next() will throw exception if stream fails + iter.foreach { + case BodyPartStart(headers, createEntity) ⇒ + val entity = createEntity(Source.empty) match { + case x: HttpEntity.Strict ⇒ x + case x ⇒ throw new IllegalStateException("Unexpected entity type from strict BodyPartParser: " + x) + } + builder += createStrictBodyPart(entity, headers) + case ParseError(errorInfo) ⇒ throw ParsingException(errorInfo) + case x ⇒ throw new IllegalStateException(s"Unexpected BodyPartParser result $x in strict case") + } + createStrict(mediaType, builder.result()) + case _ ⇒ + val bodyParts = entity.dataBytes + .transform(() ⇒ parser) + .splitWhen(_.isInstanceOf[BodyPartStart]) + .via(headAndTailFlow) + .collect { + case (BodyPartStart(headers, createEntity), entityParts) ⇒ createBodyPart(createEntity(entityParts), headers) + case (ParseError(errorInfo), _) ⇒ throw ParsingException(errorInfo) + } + createStreamed(entity.contentType.mediaType.asInstanceOf[MultipartMediaType], bodyParts) + } + } + } + } else FastFuture.failed(Unmarshaller.UnsupportedContentTypeException(mediaRange)) + } +} + +object MultipartUnmarshallers extends MultipartUnmarshallers diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromEntityUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromEntityUnmarshallers.scala new file mode 100644 index 0000000000..8c587ea9ee --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromEntityUnmarshallers.scala @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import akka.stream.FlowMaterializer +import akka.util.ByteString +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.model._ + +trait PredefinedFromEntityUnmarshallers extends MultipartUnmarshallers { + + implicit def byteStringUnmarshaller(implicit fm: FlowMaterializer): FromEntityUnmarshaller[ByteString] = + Unmarshaller(_ ⇒ { + case HttpEntity.Strict(_, data) ⇒ FastFuture.successful(data) + case entity ⇒ entity.dataBytes.runFold(ByteString.empty)(_ ++ _) + }) + + implicit def byteArrayUnmarshaller(implicit fm: FlowMaterializer): FromEntityUnmarshaller[Array[Byte]] = + byteStringUnmarshaller.map(_.toArray[Byte]) + + implicit def charArrayUnmarshaller(implicit fm: FlowMaterializer): FromEntityUnmarshaller[Array[Char]] = + byteStringUnmarshaller(fm) mapWithInput { (entity, bytes) ⇒ + val charBuffer = entity.contentType.charset.nioCharset.decode(bytes.asByteBuffer) + val array = new Array[Char](charBuffer.length()) + charBuffer.get(array) + array + } + + implicit def stringUnmarshaller(implicit fm: FlowMaterializer): FromEntityUnmarshaller[String] = + byteStringUnmarshaller(fm) mapWithInput { (entity, bytes) ⇒ + // FIXME: add `ByteString::decodeString(java.nio.Charset): String` overload!!! + bytes.decodeString(entity.contentType.charset.nioCharset.name) // ouch!!! + } + + implicit def defaultUrlEncodedFormDataUnmarshaller(implicit fm: FlowMaterializer): FromEntityUnmarshaller[FormData] = + urlEncodedFormDataUnmarshaller(MediaTypes.`application/x-www-form-urlencoded`) + def urlEncodedFormDataUnmarshaller(ranges: ContentTypeRange*)(implicit fm: FlowMaterializer): FromEntityUnmarshaller[FormData] = + stringUnmarshaller.forContentTypes(ranges: _*).mapWithInput { (entity, string) ⇒ + try { + val nioCharset = entity.contentType.definedCharset.getOrElse(HttpCharsets.`UTF-8`).nioCharset + val query = Uri.Query(string, nioCharset) + FormData(query) + } catch { + case IllegalUriException(info) ⇒ + throw new IllegalArgumentException(info.formatPretty.replace("Query,", "form content,")) + } + } +} + +object PredefinedFromEntityUnmarshallers extends PredefinedFromEntityUnmarshallers \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala new file mode 100644 index 0000000000..c15fa2acd2 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala @@ -0,0 +1,74 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +trait PredefinedFromStringUnmarshallers { + + implicit val byteFromStringUnmarshaller = Unmarshaller.strict[String, Byte] { string ⇒ + try string.toByte + catch numberFormatError(string, "8-bit signed integer") + } + + implicit val shortFromStringUnmarshaller = Unmarshaller.strict[String, Short] { string ⇒ + try string.toShort + catch numberFormatError(string, "16-bit signed integer") + } + + implicit val intFromStringUnmarshaller = Unmarshaller.strict[String, Int] { string ⇒ + try string.toInt + catch numberFormatError(string, "32-bit signed integer") + } + + implicit val longFromStringUnmarshaller = Unmarshaller.strict[String, Long] { string ⇒ + try string.toLong + catch numberFormatError(string, "64-bit signed integer") + } + + val HexByte = Unmarshaller.strict[String, Byte] { string ⇒ + try java.lang.Byte.parseByte(string, 16) + catch numberFormatError(string, "8-bit hexadecimal integer") + } + + val HexShort = Unmarshaller.strict[String, Short] { string ⇒ + try java.lang.Short.parseShort(string, 16) + catch numberFormatError(string, "16-bit hexadecimal integer") + } + + val HexInt = Unmarshaller.strict[String, Int] { string ⇒ + try java.lang.Integer.parseInt(string, 16) + catch numberFormatError(string, "32-bit hexadecimal integer") + } + + val HexLong = Unmarshaller.strict[String, Long] { string ⇒ + try java.lang.Long.parseLong(string, 16) + catch numberFormatError(string, "64-bit hexadecimal integer") + } + + implicit val floatFromStringUnmarshaller = Unmarshaller.strict[String, Float] { string ⇒ + try string.toFloat + catch numberFormatError(string, "32-bit floating point") + } + + implicit val doubleFromStringUnmarshaller = Unmarshaller.strict[String, Double] { string ⇒ + try string.toDouble + catch numberFormatError(string, "64-bit floating point") + } + + implicit val booleanFromStringUnmarshaller = Unmarshaller.strict[String, Boolean] { string ⇒ + string.toLowerCase match { + case "true" | "yes" | "on" ⇒ true + case "false" | "no" | "off" ⇒ false + case "" ⇒ throw Unmarshaller.NoContentException + case x ⇒ throw new IllegalArgumentException(s"'$x' is not a valid Boolean value") + } + } + + private def numberFormatError(value: String, target: String): PartialFunction[Throwable, Nothing] = { + case e: NumberFormatException ⇒ + throw if (value.isEmpty) Unmarshaller.NoContentException else new IllegalArgumentException(s"'$value' is not a valid $target value", e) + } +} + +object PredefinedFromStringUnmarshallers extends PredefinedFromStringUnmarshallers diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshal.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshal.scala new file mode 100644 index 0000000000..c903f7914a --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshal.scala @@ -0,0 +1,18 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import scala.concurrent.{ ExecutionContext, Future } + +object Unmarshal { + def apply[T](value: T): Unmarshal[T] = new Unmarshal(value) +} + +class Unmarshal[A](val value: A) { + /** + * Unmarshals the value to the given Type using the in-scope Unmarshaller. + */ + def to[B](implicit um: Unmarshaller[A, B], ec: ExecutionContext): Future[B] = um(value) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshaller.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshaller.scala new file mode 100644 index 0000000000..8ec4080458 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/Unmarshaller.scala @@ -0,0 +1,122 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl.unmarshalling + +import scala.util.control.{ NoStackTrace, NonFatal } +import scala.concurrent.{ Future, ExecutionContext } +import akka.http.scaladsl.util.FastFuture +import akka.http.scaladsl.util.FastFuture._ +import akka.http.scaladsl.model._ + +trait Unmarshaller[-A, B] { + + def apply(value: A)(implicit ec: ExecutionContext): Future[B] + + def transform[C](f: ExecutionContext ⇒ Future[B] ⇒ Future[C]): Unmarshaller[A, C] = + Unmarshaller { implicit ec ⇒ a ⇒ f(ec)(this(a)) } + + def map[C](f: B ⇒ C): Unmarshaller[A, C] = + transform(implicit ec ⇒ _.fast map f) + + def flatMap[C](f: ExecutionContext ⇒ B ⇒ Future[C]): Unmarshaller[A, C] = + transform(implicit ec ⇒ _.fast flatMap f(ec)) + + def recover[C >: B](pf: ExecutionContext ⇒ PartialFunction[Throwable, C]): Unmarshaller[A, C] = + transform(implicit ec ⇒ _.fast recover pf(ec)) + + def withDefaultValue[BB >: B](defaultValue: BB): Unmarshaller[A, BB] = + recover(_ ⇒ { case Unmarshaller.NoContentException ⇒ defaultValue }) +} + +object Unmarshaller + extends GenericUnmarshallers + with PredefinedFromEntityUnmarshallers + with PredefinedFromStringUnmarshallers { + + /** + * Creates an `Unmarshaller` from the given function. + */ + def apply[A, B](f: ExecutionContext ⇒ A ⇒ Future[B]): Unmarshaller[A, B] = + new Unmarshaller[A, B] { + def apply(a: A)(implicit ec: ExecutionContext) = + try f(ec)(a) + catch { case NonFatal(e) ⇒ FastFuture.failed(e) } + } + + /** + * Helper for creating a synchronous `Unmarshaller` from the given function. + */ + def strict[A, B](f: A ⇒ B): Unmarshaller[A, B] = Unmarshaller(_ ⇒ a ⇒ FastFuture.successful(f(a))) + + /** + * Helper for creating a "super-unmarshaller" from a sequence of "sub-unmarshallers", which are tried + * in the given order. The first successful unmarshalling of a "sub-unmarshallers" is the one produced by the + * "super-unmarshaller". + */ + def firstOf[A, B](unmarshallers: Unmarshaller[A, B]*): Unmarshaller[A, B] = + Unmarshaller { implicit ec ⇒ + a ⇒ + def rec(ix: Int, supported: Set[ContentTypeRange]): Future[B] = + if (ix < unmarshallers.size) { + unmarshallers(ix)(a).fast.recoverWith { + case Unmarshaller.UnsupportedContentTypeException(supp) ⇒ rec(ix + 1, supported ++ supp) + } + } else FastFuture.failed(Unmarshaller.UnsupportedContentTypeException(supported)) + rec(0, Set.empty) + } + + implicit def identityUnmarshaller[T]: Unmarshaller[T, T] = Unmarshaller(_ ⇒ FastFuture.successful) + + // we don't define these methods directly on `Unmarshaller` due to variance constraints + implicit class EnhancedUnmarshaller[A, B](val um: Unmarshaller[A, B]) extends AnyVal { + def mapWithInput[C](f: (A, B) ⇒ C): Unmarshaller[A, C] = + Unmarshaller(implicit ec ⇒ a ⇒ um(a).fast.map(f(a, _))) + + def flatMapWithInput[C](f: (A, B) ⇒ Future[C]): Unmarshaller[A, C] = + Unmarshaller(implicit ec ⇒ a ⇒ um(a).fast.flatMap(f(a, _))) + } + + implicit class EnhancedFromEntityUnmarshaller[A](val underlying: FromEntityUnmarshaller[A]) extends AnyVal { + def mapWithCharset[B](f: (A, HttpCharset) ⇒ B): FromEntityUnmarshaller[B] = + underlying.mapWithInput { (entity, data) ⇒ f(data, entity.contentType.charset) } + + /** + * Modifies the underlying [[Unmarshaller]] to only accept content-types matching one of the given ranges. + * If the underlying [[Unmarshaller]] already contains a content-type filter (also wrapped at some level), + * this filter is *replaced* by this method, not stacked! + */ + def forContentTypes(ranges: ContentTypeRange*): FromEntityUnmarshaller[A] = + Unmarshaller { implicit ec ⇒ + entity ⇒ + if (entity.contentType == ContentTypes.NoContentType || ranges.exists(_ matches entity.contentType)) { + underlying(entity).fast recoverWith retryWithPatchedContentType(underlying, entity) + } else FastFuture.failed(UnsupportedContentTypeException(ranges: _*)) + } + } + + // must be moved out of the the [[EnhancedFromEntityUnmarshaller]] value class due to bug in scala 2.10: + // https://issues.scala-lang.org/browse/SI-8018 + private def retryWithPatchedContentType[T](underlying: FromEntityUnmarshaller[T], entity: HttpEntity)( + implicit ec: ExecutionContext): PartialFunction[Throwable, Future[T]] = { + case UnsupportedContentTypeException(supported) ⇒ underlying(entity withContentType supported.head.specimen) + } + + /** + * Signals that unmarshalling failed because the entity was unexpectedly empty. + */ + case object NoContentException extends RuntimeException("Message entity must not be empty") with NoStackTrace + + /** + * Signals that unmarshalling failed because the entity content-type did not match one of the supported ranges. + * This error cannot be thrown by custom code, you need to use the `forContentTypes` modifier on a base + * [[akka.http.scaladsl.unmarshalling.Unmarshaller]] instead. + */ + final case class UnsupportedContentTypeException(supported: Set[ContentTypeRange]) + extends RuntimeException(supported.mkString("Unsupported Content-Type, supported: ", ", ", "")) + + object UnsupportedContentTypeException { + def apply(supported: ContentTypeRange*): UnsupportedContentTypeException = UnsupportedContentTypeException(Set(supported: _*)) + } +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/package.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/package.scala new file mode 100644 index 0000000000..79b4ecf3d5 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/package.scala @@ -0,0 +1,17 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.scaladsl + +import akka.http.scaladsl.common.StrictForm +import akka.http.scaladsl.model._ + +package object unmarshalling { + type FromEntityUnmarshaller[T] = Unmarshaller[HttpEntity, T] + type FromMessageUnmarshaller[T] = Unmarshaller[HttpMessage, T] + type FromResponseUnmarshaller[T] = Unmarshaller[HttpResponse, T] + type FromRequestUnmarshaller[T] = Unmarshaller[HttpRequest, T] + type FromStringUnmarshaller[T] = Unmarshaller[String, T] + type FromStrictFormFieldUnmarshaller[T] = Unmarshaller[StrictForm.Field, T] +}