diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/ParameterDirectivesExamplesSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/ParameterDirectivesExamplesSpec.scala old mode 100644 new mode 100755 index 42c4785dc1..036bbb2c12 --- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/ParameterDirectivesExamplesSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/ParameterDirectivesExamplesSpec.scala @@ -192,9 +192,9 @@ class ParameterDirectivesExamplesSpec extends RoutingSpec with PredefinedFromStr responseAs[String] shouldEqual "The parameters are x = '1', x = '2'" } } - "csv" in { + "csv string sequence" in { val route = - parameter("names".as(CsvString)) { names => + parameter("names".as(CsvStringSeq)) { names => complete(s"The parameters are ${names.mkString(", ")}") } @@ -206,4 +206,48 @@ class ParameterDirectivesExamplesSpec extends RoutingSpec with PredefinedFromStr responseAs[String] shouldEqual "The parameters are Caplin, John" } } + "csv byte sequence" in { + val route = + parameter("numbers".as(CsvByteSeq)) { bytes => + complete(s"The numbers are ${bytes.mkString(", ")}") + } + + // tests: + Get(s"/?numbers=2,${Byte.MaxValue}") ~> route ~> check { + responseAs[String] shouldEqual s"The numbers are 2, ${Byte.MaxValue}" + } + } + "csv short sequence" in { + val route = + parameter("numbers".as(CsvShortSeq)) { shorts => + complete(s"The numbers are ${shorts.mkString(", ")}") + } + + // tests: + Get(s"/?numbers=2,${Short.MaxValue}") ~> route ~> check { + responseAs[String] shouldEqual s"The numbers are 2, ${Short.MaxValue}" + } + } + "csv int sequence" in { + val route = + parameter("numbers".as(CsvIntSeq)) { ints => + complete(s"The numbers are ${ints.mkString(", ")}") + } + + // tests: + Get(s"/?numbers=2,${Int.MaxValue}") ~> route ~> check { + responseAs[String] shouldEqual s"The numbers are 2, ${Int.MaxValue}" + } + } + "csv long sequence" in { + val route = + parameter("numbers".as(CsvLongSeq)) { longs => + complete(s"The numbers are ${longs.mkString(", ")}") + } + + // tests: + Get(s"/?numbers=2,${Long.MaxValue}") ~> route ~> check { + responseAs[String] shouldEqual s"The numbers are 2, ${Long.MaxValue}" + } + } } diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala old mode 100644 new mode 100755 index fbca6989e3..c4332e0f00 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ParameterDirectivesSpec.scala @@ -7,7 +7,7 @@ package directives import org.scalatest.{ FreeSpec, Inside } import akka.http.scaladsl.unmarshalling.Unmarshaller.HexInt -import akka.http.scaladsl.unmarshalling.Unmarshaller.CsvString +import akka.http.scaladsl.unmarshalling.Unmarshaller.CsvStringSeq class ParameterDirectivesSpec extends FreeSpec with GenericRoutingSpec with Inside { "when used with 'as[Int]' the parameter directive should" - { @@ -53,9 +53,9 @@ class ParameterDirectivesSpec extends FreeSpec with GenericRoutingSpec with Insi } } - "when used with 'as(CsvString)' the parameter directive should" - { + "when used with 'as(CsvStringSeq)' the parameter directive should" - { val route = - parameter("names".as(CsvString)) { names ⇒ + parameter("names".as(CsvStringSeq)) { names ⇒ complete(s"The parameters are ${names.mkString(", ")}") } diff --git a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala old mode 100644 new mode 100755 index 8894ce96d3..ff0b635796 --- a/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/unmarshalling/PredefinedFromStringUnmarshallers.scala @@ -9,28 +9,16 @@ import scala.collection.immutable trait PredefinedFromStringUnmarshallers { implicit val byteFromStringUnmarshaller: Unmarshaller[String, Byte] = - Unmarshaller.strict[String, Byte] { string ⇒ - try string.toByte - catch numberFormatError(string, "8-bit signed integer") - } + Unmarshaller.strict(byteFromString) implicit val shortFromStringUnmarshaller: Unmarshaller[String, Short] = - Unmarshaller.strict[String, Short] { string ⇒ - try string.toShort - catch numberFormatError(string, "16-bit signed integer") - } + Unmarshaller.strict(shortFromString) implicit val intFromStringUnmarshaller: Unmarshaller[String, Int] = - Unmarshaller.strict[String, Int] { string ⇒ - try string.toInt - catch numberFormatError(string, "32-bit signed integer") - } + Unmarshaller.strict(intFromString) implicit val longFromStringUnmarshaller: Unmarshaller[String, Long] = - Unmarshaller.strict[String, Long] { string ⇒ - try string.toLong - catch numberFormatError(string, "64-bit signed integer") - } + Unmarshaller.strict(longFromString) val HexByte: Unmarshaller[String, Byte] = Unmarshaller.strict[String, Byte] { string ⇒ @@ -78,11 +66,39 @@ trait PredefinedFromStringUnmarshallers { } } - val CsvString: Unmarshaller[String, immutable.Seq[String]] = + val CsvStringSeq: Unmarshaller[String, immutable.Seq[String]] = Unmarshaller.strict[String, immutable.Seq[String]] { string ⇒ string.split(",").toList } + val CsvByteSeq: Unmarshaller[String, immutable.Seq[Byte]] = + CsvStringSeq.map(_.map(byteFromString)) + + val CsvShortSeq: Unmarshaller[String, immutable.Seq[Short]] = + CsvStringSeq.map(_.map(shortFromString)) + + val CsvIntSeq: Unmarshaller[String, immutable.Seq[Int]] = + CsvStringSeq.map(_.map(intFromString)) + + val CsvLongSeq: Unmarshaller[String, immutable.Seq[Long]] = + CsvStringSeq.map(_.map(longFromString)) + + private def byteFromString(string: String) = + try string.toByte + catch numberFormatError(string, "8-bit signed integer") + + private def shortFromString(string: String) = + try string.toShort + catch numberFormatError(string, "16-bit signed integer") + + private def intFromString(string: String) = + try string.toInt + catch numberFormatError(string, "32-bit signed integer") + + private def longFromString(string: String) = + try string.toLong + catch numberFormatError(string, "64-bit signed integer") + private def numberFormatError(value: String, target: String): PartialFunction[Throwable, Nothing] = { case e: NumberFormatException ⇒ throw if (value.isEmpty) Unmarshaller.NoContentException else new IllegalArgumentException(s"'$value' is not a valid $target value", e)