diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/BodyPartRenderer.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/BodyPartRenderer.scala index d0768924b9..752686ae49 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/BodyPartRenderer.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/BodyPartRenderer.scala @@ -5,6 +5,8 @@ package akka.http.impl.engine.rendering import java.nio.charset.Charset +import akka.parboiled2.util.Base64 + import scala.collection.immutable import akka.event.LoggingAdapter import akka.http.scaladsl.model._ @@ -16,6 +18,8 @@ import akka.stream.stage._ import akka.util.ByteString import HttpEntity._ +import scala.concurrent.forkjoin.ThreadLocalRandom + /** * INTERNAL API */ @@ -110,4 +114,13 @@ private[http] object BodyPartRenderer { case x ⇒ r ~~ x ~~ CrLf } + + /** + * Creates a new random number of the given length and base64 encodes it (using a custom "safe" alphabet). + */ + def randomBoundary(length: Int = 18, random: java.util.Random = ThreadLocalRandom.current()): String = { + val array = new Array[Byte](length) + random.nextBytes(array) + Base64.custom.encodeToString(array, false) + } } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala index 5cc1f17906..b107cafef1 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala @@ -4,15 +4,21 @@ package akka.http.scaladsl.model +import java.io.File + +import akka.event.{ NoLogging, LoggingAdapter } + import scala.collection.immutable.VectorBuilder import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ Future, ExecutionContext } import scala.collection.immutable import scala.util.{ Failure, Success, Try } import akka.stream.FlowMaterializer -import akka.stream.scaladsl.Source +import akka.stream.io.SynchronousFileSource +import akka.stream.scaladsl.{ FlattenStrategy, Source } import akka.http.scaladsl.util.FastFuture import akka.http.scaladsl.model.headers._ +import akka.http.impl.engine.rendering.BodyPartRenderer import FastFuture._ trait Multipart { @@ -25,12 +31,29 @@ trait Multipart { * The Future is failed with an TimeoutException if one part isn't read completely after the given timeout. */ def toStrict(timeout: FiniteDuration)(implicit ec: ExecutionContext, fm: FlowMaterializer): Future[Multipart.Strict] + + /** + * Creates an entity from this multipart object. + */ + def toEntity(charset: HttpCharset = HttpCharsets.`UTF-8`, + boundary: String = BodyPartRenderer.randomBoundary())(implicit log: LoggingAdapter = NoLogging): MessageEntity = { + val chunks = + parts + .transform(() ⇒ BodyPartRenderer.streamed(boundary, charset.nioCharset, partHeadersSizeHint = 128, log)) + .flatten(FlattenStrategy.concat) + HttpEntity.Chunked(mediaType withBoundary boundary, chunks) + } } object Multipart { trait Strict extends Multipart { def strictParts: immutable.Seq[BodyPart.Strict] + + override def toEntity(charset: HttpCharset, boundary: String)(implicit log: LoggingAdapter = NoLogging): HttpEntity.Strict = { + val data = BodyPartRenderer.strict(strictParts, boundary, charset.nioCharset, partHeadersSizeHint = 128, log) + HttpEntity(mediaType withBoundary boundary, data) + } } trait BodyPart { @@ -154,6 +177,7 @@ object Multipart { } object FormData { def apply(parts: BodyPart.Strict*): Strict = Strict(parts.toVector) + def apply(parts: BodyPart*): FormData = FormData(Source(parts.toVector)) def apply(fields: Map[String, HttpEntity.Strict]): Strict = Strict { fields.map { case (name, entity) ⇒ BodyPart.Strict(name, entity) }(collection.breakOut) @@ -164,6 +188,15 @@ object Multipart { override def toString = s"FormData($parts)" } + /** + * Creates a FormData instance that contains a single part backed by the given file. + * + * To create an instance with several parts or for multiple files, use + * ``FormData(BodyPart.fromFile("field1", ...), BodyPart.fromFile("field2", ...)`` + */ + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): FormData = + FormData(Source.single(BodyPart.fromFile(name, contentType, file, chunkSize))) + /** * Strict [[FormData]]. */ @@ -201,6 +234,12 @@ object Multipart { override def toString = s"FormData.BodyPart($name, $entity, $additionalDispositionParams, $additionalHeaders)" } + /** + * Creates a BodyPart backed by a File that will be streamed using a SynchronousFileSource. + */ + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): BodyPart = + BodyPart(name, HttpEntity(contentType, file, chunkSize), Map("filename" -> file.getName)) + def unapply(value: BodyPart): Option[(String, BodyPartEntity, Map[String, String], immutable.Seq[HttpHeader])] = Some((value.name, value.entity, value.additionalDispositionParams, value.additionalHeaders)) diff --git a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala index 8649f12b92..b38c4afff9 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/marshalling/MultipartMarshallers.scala @@ -5,41 +5,37 @@ package akka.http.scaladsl.marshalling import scala.concurrent.forkjoin.ThreadLocalRandom -import akka.parboiled2.util.Base64 import akka.event.{ NoLogging, LoggingAdapter } -import akka.stream.scaladsl.FlattenStrategy import akka.http.impl.engine.rendering.BodyPartRenderer import akka.http.scaladsl.model._ trait MultipartMarshallers { - protected val multipartBoundaryRandom: java.util.Random = ThreadLocalRandom.current() - - /** - * Creates a new random 144-bit number and base64 encodes it (using a custom "safe" alphabet, yielding 24 characters). - */ - def randomBoundary: String = { - val array = new Array[Byte](18) - multipartBoundaryRandom.nextBytes(array) - Base64.custom.encodeToString(array, false) - } - implicit def multipartMarshaller[T <: Multipart](implicit log: LoggingAdapter = NoLogging): ToEntityMarshaller[T] = Marshaller strict { value ⇒ - val boundary = randomBoundary + val boundary = randomBoundary() val contentType = ContentType(value.mediaType withBoundary boundary) Marshalling.WithOpenCharset(contentType.mediaType, { charset ⇒ - value match { - case x: Multipart.Strict ⇒ - val data = BodyPartRenderer.strict(x.strictParts, boundary, charset.nioCharset, partHeadersSizeHint = 128, log) - HttpEntity(contentType, data) - case _ ⇒ - val chunks = value.parts - .transform(() ⇒ BodyPartRenderer.streamed(boundary, charset.nioCharset, partHeadersSizeHint = 128, log)) - .flatten(FlattenStrategy.concat) - HttpEntity.Chunked(contentType, chunks) - } + value.toEntity(charset, boundary)(log) }) } + + /** + * The random instance that is used to create multipart boundaries. This can be overriden (e.g. in tests) to + * choose how a boundary is created. + */ + protected def multipartBoundaryRandom: java.util.Random = ThreadLocalRandom.current() + + /** + * The length of randomly generated multipart boundaries (before base64 encoding). Can be overridden + * to configure. + */ + protected def multipartBoundaryLength: Int = 18 + + /** + * The method used to create boundaries in `multipartMarshaller`. Can be overridden to create custom boundaries. + */ + protected def randomBoundary(): String = + BodyPartRenderer.randomBoundary(length = multipartBoundaryLength, random = multipartBoundaryRandom) } object MultipartMarshallers extends MultipartMarshallers