diff --git a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClient.scala b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClient.scala index 9020237811..89be13f9be 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClient.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClient.scala @@ -5,16 +5,17 @@ package akka.http.engine.client import java.net.InetSocketAddress -import scala.collection.immutable.Queue +import scala.annotation.tailrec +import scala.collection.mutable.ListBuffer +import akka.stream.stage._ import akka.util.ByteString import akka.event.LoggingAdapter import akka.stream.FlattenStrategy import akka.stream.scaladsl._ import akka.stream.scaladsl.OperationAttributes._ -import akka.http.model.{ HttpMethod, HttpRequest, HttpResponse } +import akka.http.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse } import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory } -import akka.http.engine.parsing.{ HttpHeaderParser, HttpResponseParser } -import akka.http.engine.parsing.ParserOutput._ +import akka.http.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser } import akka.http.util._ /** @@ -37,39 +38,212 @@ private[http] object HttpClient { }) val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log) - val requestMethodByPass = new RequestMethodByPass(remoteAddress) - Flow[HttpRequest] - .map(requestMethodByPass) + /* + Basic Stream Setup + ================== + + requestIn +----------+ + +-----------------------------------------------+--->| Termi- | requestRendering + | | nation +---------------------> | + +-------------------------------------->| Merge | | + | Termination Backchannel | +----------+ | TCP- + | | | level + | | Method | client + | +------------+ | Bypass | flow + responseOut | responsePrep | Response |<---+ | + <------------+----------------| Parsing | | + | Merge |<------------------------------------------ V + +------------+ + */ + + val requestIn = UndefinedSource[HttpRequest] + val responseOut = UndefinedSink[HttpResponse] + + val methodBypassFanout = Broadcast[HttpRequest] + val responseParsingMerge = new ResponseParsingMerge(rootParser) + + val terminationFanout = Broadcast[HttpResponse] + val terminationMerge = new TerminationMerge + + val requestRendering = Flow[HttpRequest] + .map(RequestRenderingContext(_, remoteAddress)) .section(name("renderer"))(_.transform(() ⇒ requestRendererFactory.newRenderer)) .flatten(FlattenStrategy.concat) + + val transportFlow = Flow[ByteString] .section(name("errorLogger"))(_.transform(() ⇒ errorLogger(log, "Outgoing request stream error"))) .via(transport) - .section(name("rootParser"))(_.transform(() ⇒ - // each connection uses a single (private) response parser instance for all its responses - // which builds a cache of all header instances seen on that connection - rootParser.createShallowCopy(requestMethodByPass))) - .splitWhen(_.isInstanceOf[MessageStart]) + + val methodBypass = Flow[HttpRequest].map(_.method) + + import ParserOutput._ + val responsePrep = Flow[List[ResponseOutput]] + .transform(recover { case x: ResponseParsingError ⇒ x.error :: Nil }) // FIXME after #16565 + .mapConcat(identityFunc) + .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .headAndTail .collect { case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts) ⇒ HttpResponse(statusCode, headers, createEntity(entityParts), protocol) + case (MessageStartError(_, info), _) ⇒ throw IllegalResponseException(info) } + + import FlowGraphImplicits._ + + Flow() { implicit b ⇒ + requestIn ~> methodBypassFanout ~> terminationMerge.requestInput ~> requestRendering ~> transportFlow ~> + responseParsingMerge.dataInput ~> responsePrep ~> terminationFanout ~> responseOut + methodBypassFanout ~> methodBypass ~> responseParsingMerge.methodBypassInput + terminationFanout ~> terminationMerge.terminationBackchannelInput + + b.allowCycles() + + requestIn -> responseOut + } } - // FIXME: refactor to a pure-stream design that allows us to get rid of this ad-hoc queue here - class RequestMethodByPass(serverAddress: InetSocketAddress) - extends (HttpRequest ⇒ RequestRenderingContext) with (() ⇒ HttpMethod) { - private[this] var requestMethods = Queue.empty[HttpMethod] - def apply(request: HttpRequest) = { - requestMethods = requestMethods.enqueue(request.method) - RequestRenderingContext(request, serverAddress) + // a simple merge stage that simply forwards its first input and ignores its second input + // (the terminationBackchannelInput), but applies a special completion handling + class TerminationMerge extends FlexiMerge[HttpRequest] { + import FlexiMerge._ + val requestInput = createInputPort[HttpRequest]() + val terminationBackchannelInput = createInputPort[HttpResponse]() + + def createMergeLogic() = new MergeLogic[HttpRequest] { + override def inputHandles(inputCount: Int) = { + require(inputCount == 2, s"TerminationMerge must have 2 connected inputs, was $inputCount") + Vector(requestInput, terminationBackchannelInput) + } + + override def initialState = State[Any](ReadAny(requestInput, terminationBackchannelInput)) { + case (ctx, _, request: HttpRequest) ⇒ { ctx.emit(request); SameState } + case _ ⇒ SameState // simply drop all responses, we are only interested in the completion of the response input + } + + override def initialCompletionHandling = CompletionHandling( + onComplete = { + case (ctx, `requestInput`) ⇒ SameState + case (ctx, `terminationBackchannelInput`) ⇒ + ctx.complete() + SameState + }, + onError = defaultCompletionHandling.onError) } - def apply(): HttpMethod = - if (requestMethods.nonEmpty) { - val method = requestMethods.head - requestMethods = requestMethods.tail - method - } else HttpResponseParser.NoMethod + } + + import ParserOutput._ + + /** + * A FlexiMerge that follows this logic: + * 1. Wait on the methodBypass for the method of the request corresponding to the next response to be received + * 2. Read from the dataInput until exactly one response has been fully received + * 3. Go back to 1. + */ + class ResponseParsingMerge(rootParser: HttpResponseParser) extends FlexiMerge[List[ResponseOutput]] { + import FlexiMerge._ + val dataInput = createInputPort[ByteString]() + val methodBypassInput = createInputPort[HttpMethod]() + + def createMergeLogic() = new MergeLogic[List[ResponseOutput]] { + // each connection uses a single (private) response parser instance for all its responses + // which builds a cache of all header instances seen on that connection + val parser = rootParser.createShallowCopy() + var methodBypassCompleted = false + + override def inputHandles(inputCount: Int) = { + require(inputCount == 2, s"ResponseParsingMerge must have 2 connected inputs, was $inputCount") + Vector(dataInput, methodBypassInput) + } + + override val initialState: State[HttpMethod] = + State(Read(methodBypassInput)) { + case (ctx, _, method) ⇒ + parser.setRequestMethodForNextResponse(method) + drainParser(parser.onPush(ByteString.empty), ctx, + onNeedNextMethod = () ⇒ SameState, + onNeedMoreData = () ⇒ { + ctx.changeCompletionHandling(responseReadingCompletionHandling) + responseReadingState + }) + } + + val responseReadingState: State[ByteString] = + State(Read(dataInput)) { + case (ctx, _, bytes) ⇒ + drainParser(parser.onPush(bytes), ctx, + onNeedNextMethod = () ⇒ { + if (methodBypassCompleted) { + ctx.complete() + SameState + } else { + ctx.changeCompletionHandling(initialCompletionHandling) + initialState + } + }, + onNeedMoreData = () ⇒ SameState) + } + + @tailrec def drainParser(current: ResponseOutput, ctx: MergeLogicContext, + onNeedNextMethod: () ⇒ State[_], onNeedMoreData: () ⇒ State[_], + b: ListBuffer[ResponseOutput] = ListBuffer.empty): State[_] = { + def emit(output: List[ResponseOutput]): Unit = if (output.nonEmpty) ctx.emit(output) + current match { + case NeedNextRequestMethod ⇒ + emit(b.result()) + onNeedNextMethod() + case StreamEnd ⇒ + emit(b.result()) + ctx.complete() + SameState + case NeedMoreData ⇒ + emit(b.result()) + onNeedMoreData() + case x ⇒ drainParser(parser.onPull(), ctx, onNeedNextMethod, onNeedMoreData, b += x) + } + } + + override val initialCompletionHandling = CompletionHandling( + onComplete = (ctx, _) ⇒ { ctx.complete(); SameState }, + onError = defaultCompletionHandling.onError) + + val responseReadingCompletionHandling = CompletionHandling( + onComplete = { + case (ctx, `methodBypassInput`) ⇒ + methodBypassCompleted = true + SameState + case (ctx, `dataInput`) ⇒ + if (parser.onUpstreamFinish()) { + ctx.complete() + } else { + // not pretty but because the FlexiMerge doesn't let us emit from here (#16565) + // we need to funnel the error through the error channel + ctx.error(new ResponseParsingError(parser.onPull().asInstanceOf[ErrorOutput])) + } + SameState + }, + onError = defaultCompletionHandling.onError) + } + } + + private class ResponseParsingError(val error: ErrorOutput) extends RuntimeException + + // TODO: remove after #16394 is cleared + def recover[A, B >: A](pf: PartialFunction[Throwable, B]): () ⇒ PushPullStage[A, B] = { + val stage = new PushPullStage[A, B] { + var recovery: Option[B] = None + def onPush(elem: A, ctx: Context[B]): Directive = ctx.push(elem) + def onPull(ctx: Context[B]): Directive = recovery match { + case None ⇒ ctx.pull() + case Some(x) ⇒ { recovery = null; ctx.push(x) } + case null ⇒ ctx.finish() + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[B]): TerminationDirective = + if (pf isDefinedAt cause) { + recovery = Some(pf(cause)) + ctx.absorbTermination() + } else super.onUpstreamFailure(cause, ctx) + } + () ⇒ stage } } \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala index ef3367ee8a..c2288ccdf0 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala @@ -10,7 +10,6 @@ import akka.parboiled2.CharUtils import akka.util.ByteString import akka.stream.scaladsl.Source import akka.stream.stage._ -import akka.http.Http.StreamException import akka.http.model.parser.CharacterClasses import akka.http.util._ import akka.http.model._ @@ -22,21 +21,33 @@ import ParserOutput._ * INTERNAL API */ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings, - val headerParser: HttpHeaderParser) - extends PushPullStage[ByteString, Output] { + val headerParser: HttpHeaderParser) { self ⇒ import HttpMessageParser._ import settings._ - sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup - final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult - private[this] val result = new ListBuffer[Output] private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0) private[this] var protocol: HttpProtocol = `HTTP/1.1` private[this] var completionHandling: CompletionHandling = CompletionOk private[this] var terminated = false - override def onPush(input: ByteString, ctx: Context[Output]): Directive = { + def isTerminated = terminated + + val stage: PushPullStage[ByteString, Output] = + new PushPullStage[ByteString, Output] { + def onPush(elem: ByteString, ctx: Context[Output]) = handleParserOutput(self.onPush(elem), ctx) + def onPull(ctx: Context[Output]) = handleParserOutput(self.onPull(), ctx) + private def handleParserOutput(output: Output, ctx: Context[Output]): Directive = + output match { + case StreamEnd ⇒ ctx.finish() + case NeedMoreData ⇒ ctx.pull() + case x ⇒ ctx.push(x) + } + override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective = + if (self.onUpstreamFinish()) ctx.finish() else ctx.absorbTermination() + } + + final def onPush(input: ByteString): Output = { @tailrec def run(next: ByteString ⇒ StateResult): StateResult = (try next(input) catch { @@ -51,37 +62,32 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`") run(state) - pushResultHeadAndFinishOrPull(ctx) + onPull() } - def onPull(ctx: Context[Output]): Directive = pushResultHeadAndFinishOrPull(ctx) - - def pushResultHeadAndFinishOrPull(ctx: Context[Output]): Directive = + final def onPull(): Output = if (result.nonEmpty) { val head = result.head result.remove(0) // faster than `ListBuffer::drop` - ctx.push(head) - } else if (terminated) ctx.finish() else ctx.pull() + head + } else if (terminated) StreamEnd else NeedMoreData - override def onUpstreamFinish(ctx: Context[Output]) = { + final def onUpstreamFinish(): Boolean = { completionHandling() match { - case Some(x) ⇒ emit(x.asInstanceOf[Output]) + case Some(x) ⇒ emit(x) case None ⇒ // nothing to do } terminated = true - if (result.isEmpty) ctx.finish() else ctx.absorbTermination() + result.isEmpty } - def startNewMessage(input: ByteString, offset: Int): StateResult = { - def _startNewMessage(input: ByteString, offset: Int): StateResult = - try parseMessage(input, offset) - catch { case NotEnoughDataException ⇒ continue(input, offset)(_startNewMessage) } - + protected final def startNewMessage(input: ByteString, offset: Int): StateResult = { if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError) - _startNewMessage(input, offset) + try parseMessage(input, offset) + catch { case NotEnoughDataException ⇒ continue(input, offset)(startNewMessage) } } - def parseMessage(input: ByteString, offset: Int): StateResult + protected def parseMessage(input: ByteString, offset: Int): StateResult def parseProtocol(input: ByteString, cursor: Int): Int = { def c(ix: Int) = byteChar(input, cursor + ix) @@ -204,7 +210,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser val chunkBodyEnd = cursor + chunkSize def result(terminatorLen: Int) = { emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) - trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) + Trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) } byteChar(input, chunkBodyEnd) match { case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2) @@ -255,7 +261,6 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser state = next(_, 0) done() } - def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next) def failMessageStart(summary: String): StateResult = failMessageStart(summary, "") def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail) @@ -299,7 +304,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser transformData: Source[ByteString] ⇒ Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = { val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes - case EntityStreamError(info) ⇒ throw new StreamException(info) + case EntityStreamError(info) ⇒ throw EntityStreamException(info) } HttpEntity.Default(contentType(cth), contentLength, transformData(data)) } @@ -308,7 +313,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser transformChunks: Source[HttpEntity.ChunkStreamPart] ⇒ Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = { val chunks = entityChunks.collect { case EntityChunk(chunk) ⇒ chunk - case EntityStreamError(info) ⇒ throw new StreamException(info) + case EntityStreamError(info) ⇒ throw EntityStreamException(info) } HttpEntity.Chunked(contentType(cth), transformChunks(chunks)) } @@ -324,7 +329,10 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser } private[http] object HttpMessageParser { - type CompletionHandling = () ⇒ Option[ParserOutput] + sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup + final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult + + type CompletionHandling = () ⇒ Option[ErrorOutput] val CompletionOk: CompletionHandling = () ⇒ None val CompletionIsMessageStartError: CompletionHandling = () ⇒ Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start"))) diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala index b5059c5355..792b18675c 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala @@ -11,9 +11,9 @@ import akka.stream.scaladsl.OperationAttributes._ import akka.stream.stage.{ Context, PushPullStage } import akka.stream.scaladsl.Source import akka.util.ByteString -import akka.http.engine.server.OneHundredContinue import akka.http.model.parser.CharacterClasses import akka.http.util.identityFunc +import akka.http.engine.TokenSourceActor import akka.http.model._ import headers._ import StatusCodes._ @@ -27,6 +27,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, _headerParser: HttpHeaderParser, oneHundredContinueRef: () ⇒ Option[ActorRef] = () ⇒ None) extends HttpMessageParser[RequestOutput](_settings, _headerParser) { + import HttpMessageParser._ import settings._ private[this] var method: HttpMethod = _ @@ -105,7 +106,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, uriBytes = input.iterator.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here? uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode) } catch { - case e: IllegalUriException ⇒ throw new ParsingException(BadRequest, e.info) + case IllegalUriException(info) ⇒ throw new ParsingException(BadRequest, info) } uriEnd + 1 } @@ -133,7 +134,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, def onPull(ctx: Context[T]) = { if (!oneHundredContinueSent) { val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable")) - ref ! OneHundredContinue + ref ! TokenSourceActor.Trigger oneHundredContinueSent = true } ctx.pull() diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala index 86629da8b2..2e9eabfde8 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpResponseParser.scala @@ -10,38 +10,41 @@ import akka.stream.scaladsl.Source import akka.util.ByteString import akka.http.model._ import headers._ -import HttpResponseParser.NoMethod import ParserOutput._ /** * INTERNAL API */ -private[http] class HttpResponseParser(_settings: ParserSettings, - _headerParser: HttpHeaderParser, - dequeueRequestMethodForNextResponse: () ⇒ HttpMethod = () ⇒ NoMethod) +private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: HttpHeaderParser) extends HttpMessageParser[ResponseOutput](_settings, _headerParser) { + import HttpMessageParser._ import settings._ - private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod + private[this] var requestMethodForCurrentResponse: Option[HttpMethod] = None private[this] var statusCode: StatusCode = StatusCodes.OK - def createShallowCopy(dequeueRequestMethodForNextResponse: () ⇒ HttpMethod): HttpResponseParser = - new HttpResponseParser(settings, headerParser.createShallowCopy(), dequeueRequestMethodForNextResponse) + def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy()) - override def startNewMessage(input: ByteString, offset: Int): StateResult = { - requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse() - super.startNewMessage(input, offset) - } + def setRequestMethodForNextResponse(requestMethod: HttpMethod): Unit = + if (requestMethodForCurrentResponse.isEmpty) requestMethodForCurrentResponse = Some(requestMethod) - def parseMessage(input: ByteString, offset: Int): StateResult = - if (requestMethodForCurrentResponse ne NoMethod) { + protected def parseMessage(input: ByteString, offset: Int): StateResult = + if (requestMethodForCurrentResponse.isDefined) { var cursor = parseProtocol(input, offset) if (byteChar(input, cursor) == ' ') { cursor = parseStatusCode(input, cursor + 1) cursor = parseReason(input, cursor)() parseHeaderLines(input, cursor) } else badProtocol - } else failMessageStart("Unexpected server response", input.drop(offset).utf8String) + } else { + emit(NeedNextRequestMethod) + done() + } + + override def emit(output: ResponseOutput): Unit = { + if (output == MessageEnd) requestMethodForCurrentResponse = None + super.emit(output) + } def badProtocol = throw new ParsingException("The server-side HTTP version is not supported") @@ -81,10 +84,11 @@ private[http] class HttpResponseParser(_settings: ParserSettings, def finishEmptyResponse() = { emitResponseStart(emptyEntity(cth)) setCompletionHandling(HttpMessageParser.CompletionOk) + emit(MessageEnd) startNewMessage(input, bodyStart) } - if (statusCode.allowsEntity && (requestMethodForCurrentResponse ne HttpMethods.HEAD)) { + if (statusCode.allowsEntity && (requestMethodForCurrentResponse.get != HttpMethods.HEAD)) { teh match { case None ⇒ clh match { case Some(`Content-Length`(contentLength)) ⇒ @@ -95,6 +99,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, val cl = contentLength.toInt emitResponseStart(strictEntity(cth, input, bodyStart, cl)) setCompletionHandling(HttpMessageParser.CompletionOk) + emit(MessageEnd) startNewMessage(input, bodyStart + cl) } else { emitResponseStart(defaultEntity(cth, contentLength)) @@ -128,11 +133,4 @@ private[http] class HttpResponseParser(_settings: ParserSettings, emit(EntityPart(input drop bodyStart)) continue(parseToCloseBody) } -} - -/** - * INTERNAL API - */ -private[http] object HttpResponseParser { - val NoMethod = HttpMethod.custom("NONE", safe = false, idempotent = false, entityAccepted = false) } \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala index 07daa875f9..ea264f715d 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/ParserOutput.scala @@ -21,6 +21,7 @@ private[http] object ParserOutput { sealed trait ResponseOutput extends ParserOutput sealed trait MessageStart extends ParserOutput sealed trait MessageOutput extends RequestOutput with ResponseOutput + sealed trait ErrorOutput extends MessageOutput final case class RequestStart( method: HttpMethod, @@ -44,7 +45,15 @@ private[http] object ParserOutput { final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput - final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput + final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with ErrorOutput - final case class EntityStreamError(info: ErrorInfo) extends MessageOutput + final case class EntityStreamError(info: ErrorInfo) extends ErrorOutput + + //////////// meta messages /////////// + + case object StreamEnd extends MessageOutput + + case object NeedMoreData extends MessageOutput + + case object NeedNextRequestMethod extends ResponseOutput } diff --git a/akka-http-core/src/main/scala/akka/http/util/package.scala b/akka-http-core/src/main/scala/akka/http/util/package.scala index d7808d8ac3..996d3e7347 100644 --- a/akka-http-core/src/main/scala/akka/http/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/util/package.scala @@ -55,31 +55,25 @@ package object util { .flatten(FlattenStrategy.concat) } - private[http] implicit class EnhancedSource[T](val underlying: Source[T]) { - def printEvent(marker: String): Source[T] = - underlying.transform(() ⇒ new PushStage[T, T] { - override def onPush(element: T, ctx: Context[T]): Directive = { - println(s"$marker: $element") - ctx.push(element) - } - override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { - println(s"$marker: Failure $cause") - super.onUpstreamFailure(cause, ctx) - } - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { - println(s"$marker: Terminated") - super.onUpstreamFinish(ctx) - } - }) - - /** - * Drain this stream into a Vector and provide it as a future value. - * - * FIXME: Should be part of akka-streams - */ - def collectAll(implicit materializer: FlowMaterializer): Future[immutable.Seq[T]] = - underlying.fold(Vector.empty[T])(_ :+ _) - } + def printEvent[T](marker: String): Flow[T, T] = + Flow[T].transform(() ⇒ new PushStage[T, T] { + override def onPush(element: T, ctx: Context[T]): Directive = { + println(s"$marker: $element") + ctx.push(element) + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + println(s"$marker: Error $cause") + super.onUpstreamFailure(cause, ctx) + } + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + println(s"$marker: Complete") + super.onUpstreamFinish(ctx) + } + override def onDownstreamFinish(ctx: Context[T]): TerminationDirective = { + println(s"$marker: Cancel") + super.onDownstreamFinish(ctx) + } + }) private[http] implicit class AddFutureAwaitResult[T](future: Future[T]) { /** "Safe" Await.result that doesn't throw away half of the stacktrace */ diff --git a/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala index 6b4fceea45..602b849209 100644 --- a/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala @@ -36,7 +36,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { implicit val materializer = FlowMaterializer() - "The server-side HTTP infrastructure" should { + "The low-level HTTP infrastructure" should { "properly bind a server" in { val (hostname, port) = temporaryServerHostnameAndPort() @@ -70,6 +70,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val (serverIn, serverOut) = acceptConnection() val clientOutSub = clientOut.expectSubscription() + clientOutSub.expectRequest() clientOutSub.sendNext(HttpRequest(uri = "/abc")) val serverInSub = serverIn.expectSubscription() @@ -77,12 +78,20 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc") val serverOutSub = serverOut.expectSubscription() + serverOutSub.expectRequest() serverOutSub.sendNext(HttpResponse(entity = "yeah")) val clientInSub = clientIn.expectSubscription() clientInSub.request(1) val response = clientIn.expectNext() toStrict(response.entity) shouldEqual HttpEntity("yeah") + + clientOutSub.sendComplete() + serverInSub.request(1) // work-around for #16552 + serverIn.expectComplete() + serverOutSub.expectCancellation() + clientInSub.request(1) // work-around for #16552 + clientIn.expectComplete() } "properly complete a chunked request/response cycle" in new TestSetup { @@ -104,6 +113,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { Await.result(chunkStream.grouped(4).runWith(Sink.head), 100.millis) shouldEqual chunks val serverOutSub = serverOut.expectSubscription() + serverOutSub.expectRequest() serverOutSub.sendNext(HttpResponse(206, List(RawHeader("Age", "42")), chunkedEntity)) val clientInSub = clientIn.expectSubscription() @@ -111,8 +121,42 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val HttpResponse(StatusCodes.PartialContent, List(RawHeader("Age", "42"), Server(_), Date(_)), Chunked(`chunkedContentType`, chunkStream2), HttpProtocols.`HTTP/1.1`) = clientIn.expectNext() Await.result(chunkStream2.grouped(1000).runWith(Sink.head), 100.millis) shouldEqual chunks + + clientOutSub.sendComplete() + serverInSub.request(1) // work-around for #16552 + serverIn.expectComplete() + serverOutSub.expectCancellation() + clientInSub.request(1) // work-around for #16552 + clientIn.expectComplete() } + "be able to deal with eager closing of the request stream on the client side" in new TestSetup { + val (clientOut, clientIn) = openNewClientConnection() + val (serverIn, serverOut) = acceptConnection() + + val clientOutSub = clientOut.expectSubscription() + clientOutSub.sendNext(HttpRequest(uri = "/abc")) + clientOutSub.sendComplete() // complete early + + val serverInSub = serverIn.expectSubscription() + serverInSub.request(1) + serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc") + + val serverOutSub = serverOut.expectSubscription() + serverOutSub.expectRequest() + serverOutSub.sendNext(HttpResponse(entity = "yeah")) + + val clientInSub = clientIn.expectSubscription() + clientInSub.request(1) + val response = clientIn.expectNext() + toStrict(response.entity) shouldEqual HttpEntity("yeah") + + serverInSub.request(1) // work-around for #16552 + serverIn.expectComplete() + serverOutSub.expectCancellation() + clientInSub.request(1) // work-around for #16552 + clientIn.expectComplete() + } } override def afterAll() = system.shutdown() diff --git a/akka-http-core/src/test/scala/akka/http/engine/client/HttpClientSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/client/HttpClientSpec.scala new file mode 100644 index 0000000000..510b3b9490 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/client/HttpClientSpec.scala @@ -0,0 +1,330 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.engine.client + +import java.net.InetSocketAddress +import org.scalatest.Inside +import akka.util.ByteString +import akka.event.NoLogging +import akka.stream.FlowMaterializer +import akka.stream.testkit.{ AkkaSpec, StreamTestKit } +import akka.stream.scaladsl._ +import akka.http.model.HttpEntity._ +import akka.http.model.HttpMethods._ +import akka.http.model._ +import akka.http.model.headers._ +import akka.http.util._ + +class HttpClientSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside { + implicit val materializer = FlowMaterializer() + + "The client implementation" should { + + "properly handle a request/response round-trip" which { + + "has a request with empty entity" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.1 200 OK + |Content-Length: 0 + | + |""") + + responsesSub.request(1) + responses.expectNext(HttpResponse()) + + requestsSub.sendComplete() + netOut.expectComplete() + netInSub.sendComplete() + responses.expectComplete() + } + + "has a request with default entity" in new TestSetup { + val probe = StreamTestKit.PublisherProbe[ByteString]() + requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe)))) + expectWireData( + """PUT / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + |Content-Type: application/octet-stream + |Content-Length: 8 + | + |""") + val sub = probe.expectSubscription() + sub.expectRequest(4) + sub.sendNext(ByteString("ABC")) + expectWireData("ABC") + sub.sendNext(ByteString("DEF")) + expectWireData("DEF") + sub.sendNext(ByteString("XY")) + expectWireData("XY") + sub.sendComplete() + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.1 200 OK + |Content-Length: 0 + | + |""") + + responsesSub.request(1) + responses.expectNext(HttpResponse()) + + requestsSub.sendComplete() + netOut.expectComplete() + netInSub.sendComplete() + responses.expectComplete() + } + + "has a response with a default entity" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.1 200 OK + |Transfer-Encoding: chunked + | + |""") + + responsesSub.request(1) + val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext() + ct shouldEqual ContentTypes.`application/octet-stream` + + val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]() + chunks.runWith(Sink(probe)) + val sub = probe.expectSubscription() + + sendWireData("3\nABC\n") + sub.request(1) + probe.expectNext(HttpEntity.Chunk("ABC")) + + sendWireData("4\nDEFX\n") + sub.request(1) + probe.expectNext(HttpEntity.Chunk("DEFX")) + + sendWireData("0\n\n") + sub.request(1) + probe.expectNext(HttpEntity.LastChunk) + probe.expectComplete() + + requestsSub.sendComplete() + netOut.expectComplete() + netInSub.sendComplete() + responses.expectComplete() + } + + "exhibits eager request stream completion" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + requestsSub.sendComplete() + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.1 200 OK + |Content-Length: 0 + | + |""") + + responsesSub.request(1) + responses.expectNext(HttpResponse()) + + netOut.expectComplete() + netInSub.sendComplete() + responses.expectComplete() + } + } + + "produce proper errors" which { + + "catch the entity stream being shorter than the Content-Length" in new TestSetup { + val probe = StreamTestKit.PublisherProbe[ByteString]() + requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe)))) + expectWireData( + """PUT / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + |Content-Type: application/octet-stream + |Content-Length: 8 + | + |""") + val sub = probe.expectSubscription() + sub.expectRequest(4) + sub.sendNext(ByteString("ABC")) + expectWireData("ABC") + sub.sendNext(ByteString("DEF")) + expectWireData("DEF") + sub.sendComplete() + + val InvalidContentLengthException(info) = netOut.expectError() + info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to 2 bytes less" + netInSub.sendComplete() + responses.expectComplete() + netInSub.expectCancellation() + } + + "catch the entity stream being longer than the Content-Length" in new TestSetup { + val probe = StreamTestKit.PublisherProbe[ByteString]() + requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe)))) + expectWireData( + """PUT / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + |Content-Type: application/octet-stream + |Content-Length: 8 + | + |""") + val sub = probe.expectSubscription() + sub.expectRequest(4) + sub.sendNext(ByteString("ABC")) + expectWireData("ABC") + sub.sendNext(ByteString("DEF")) + expectWireData("DEF") + sub.sendNext(ByteString("XYZ")) + + val InvalidContentLengthException(info) = netOut.expectError() + info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to more bytes" + netInSub.sendComplete() + responses.expectComplete() + netInSub.expectCancellation() + } + + "catch illegal response starts" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.2 200 OK + | + |""") + + val error @ IllegalResponseException(info) = responses.expectError() + info.summary shouldEqual "The server-side HTTP version is not supported" + netOut.expectError(error) + requestsSub.expectCancellation() + } + + "catch illegal response chunks" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData( + """HTTP/1.1 200 OK + |Transfer-Encoding: chunked + | + |""") + + responsesSub.request(1) + val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext() + ct shouldEqual ContentTypes.`application/octet-stream` + + val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]() + chunks.runWith(Sink(probe)) + val sub = probe.expectSubscription() + + sendWireData("3\nABC\n") + sub.request(1) + probe.expectNext(HttpEntity.Chunk("ABC")) + + sendWireData("4\nDEFXX") + sub.request(1) + val error @ EntityStreamException(info) = probe.expectError() + info.summary shouldEqual "Illegal chunk termination" + + responses.expectComplete() + netOut.expectComplete() + requestsSub.expectCancellation() + } + + "catch a response start truncation" in new TestSetup { + requestsSub.sendNext(HttpRequest()) + expectWireData( + """GET / HTTP/1.1 + |Host: example.com:80 + |User-Agent: akka-http/test + | + |""") + + netInSub.expectRequest(16) + sendWireData("HTTP/1.1 200 OK") + netInSub.sendComplete() + + val error @ IllegalResponseException(info) = responses.expectError() + info.summary shouldEqual "Illegal HTTP message start" + netOut.expectError(error) + requestsSub.expectCancellation() + } + } + } + + class TestSetup { + val requests = StreamTestKit.PublisherProbe[HttpRequest] + val responses = StreamTestKit.SubscriberProbe[HttpResponse] + val remoteAddress = new InetSocketAddress("example.com", 80) + + def settings = ClientConnectionSettings(system) + .copy(userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test"))))) + + val (netOut, netIn) = { + val netOut = StreamTestKit.SubscriberProbe[ByteString] + val netIn = StreamTestKit.PublisherProbe[ByteString] + val clientFlow = HttpClient.transportToConnectionClientFlow( + Flow(Sink(netOut), Source(netIn)), remoteAddress, settings, NoLogging) + Source(requests).via(clientFlow).runWith(Sink(responses)) + netOut -> netIn + } + + def wipeDate(string: String) = + string.fastSplit('\n').map { + case s if s.startsWith("Date:") ⇒ "Date: XXXX\r" + case s ⇒ s + }.mkString("\n") + + val netInSub = netIn.expectSubscription() + val netOutSub = netOut.expectSubscription() + val requestsSub = requests.expectSubscription() + val responsesSub = responses.expectSubscription() + + def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII")) + def sendWireData(data: ByteString): Unit = netInSub.sendNext(data) + + def expectWireData(s: String) = { + netOutSub.request(1) + netOut.expectNext().utf8String shouldEqual s.stripMarginWithNewline("\r\n") + } + + def closeNetworkInput(): Unit = netInSub.sendComplete() + } +} \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala index 377db41b8d..ad92cd0ea1 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/parsing/RequestParserSpec.scala @@ -120,8 +120,9 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { |Host: x | |ABCDPATCH""" - }.toCharArray.map(_.toString).toSeq should rawMultiParseTo( - HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes)) + }.toCharArray.map(_.toString).toSeq should generalRawMultiParseTo( + Right(HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes)), + Left(MessageStartError(400, ErrorInfo("Illegal HTTP message start")))) closeAfterResponseCompletion shouldEqual Seq(false) } @@ -232,7 +233,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { val parser = newParser val result = multiParse(newParser)(Seq(prep(start + manyChunks))) val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity - val strictChunks = chunks.collectAll.awaitResult(awaitAtMost) + val strictChunks = chunks.grouped(100000).runWith(Sink.head).awaitResult(awaitAtMost) strictChunks.size shouldEqual numChunks } } @@ -442,7 +443,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] = Source(input.toList) .map(ByteString.apply) - .section(name("parser"))(_.transform(() ⇒ parser)) + .section(name("parser"))(_.transform(() ⇒ parser.stage)) .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .headAndTail .collect { @@ -461,7 +462,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } .flatten(FlattenStrategy.concat) .map(strictEqualify) - .collectAll + .grouped(100000).runWith(Sink.head) .awaitResult(awaitAtMost) protected def parserSettings: ParserSettings = ParserSettings(system) @@ -474,7 +475,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] = - data.collectAll + data.grouped(100000).runWith(Sink.head) .fast.recover { case _: NoSuchElementException ⇒ Nil } def prep(response: String) = response.stripMarginWithNewline("\r\n") diff --git a/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala index 305e36422f..eb067d749d 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/parsing/ResponseParserSpec.scala @@ -261,7 +261,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { val future = Source(input.toList) .map(ByteString.apply) - .section(name("parser"))(_.transform(() ⇒ newParser(requestMethod))) + .section(name("parser"))(_.transform(() ⇒ newParserStage(requestMethod))) .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .headAndTail .collect { @@ -279,14 +279,16 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } .flatten(FlattenStrategy.concat) .map(strictEqualify) - .grouped(1000).runWith(Sink.head) + .grouped(100000).runWith(Sink.head) Await.result(future, 500.millis) } def parserSettings: ParserSettings = ParserSettings(system) - def newParser(requestMethod: HttpMethod = GET) = { - val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)(), () ⇒ requestMethod) - parser + + def newParserStage(requestMethod: HttpMethod = GET) = { + val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)()) + parser.setRequestMethodForNextResponse(requestMethod) + parser.stage } private def compactEntity(entity: ResponseEntity): Future[ResponseEntity] = @@ -296,7 +298,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Source[ChunkStreamPart]] = - data.grouped(1000).runWith(Sink.head) + data.grouped(100000).runWith(Sink.head) .fast.map(source(_: _*)) .fast.recover { case _: NoSuchElementException ⇒ source() } diff --git a/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerSpec.scala index b94425def2..d189628e50 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/server/HttpServerSpec.scala @@ -5,7 +5,7 @@ package akka.http.engine.server import scala.concurrent.duration._ -import org.scalatest.{ Inside, BeforeAndAfterAll, Matchers } +import org.scalatest.Inside import akka.event.NoLogging import akka.util.ByteString import akka.stream.scaladsl._ @@ -18,7 +18,7 @@ import HttpEntity._ import MediaTypes._ import HttpMethods._ -class HttpServerSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside { +class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside { implicit val materializer = FlowMaterializer() "The server implementation" should { diff --git a/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala b/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala index 373c864821..c7637f4dcb 100644 --- a/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala +++ b/akka-http-testkit/src/main/scala/akka/http/testkit/RouteTestResultComponent.scala @@ -96,6 +96,6 @@ trait RouteTestResultComponent { failTest("Request was neither completed nor rejected within " + timeout) private def awaitAllElements[T](data: Source[T]): immutable.Seq[T] = - data.collectAll.awaitResult(timeout) + data.grouped(100000).runWith(Sink.head).awaitResult(timeout) } } \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala index 8401186b83..2a10f0a06b 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/RangeDirectivesSpec.scala @@ -5,17 +5,16 @@ package akka.http.server package directives +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.http.model.StatusCodes._ import akka.http.model._ import akka.http.model.headers._ import akka.http.util._ -import akka.stream.scaladsl.Source +import akka.stream.scaladsl.{ Sink, Source } import akka.util.ByteString import org.scalatest.{ Inside, Inspectors } -import scala.concurrent.Await -import scala.concurrent.duration._ - class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { lazy val wrs = mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) & @@ -100,7 +99,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { wrs { complete("Some random and not super short entity.") } } ~> check { header[`Content-Range`] should be(None) - val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second) + val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second) parts.size shouldEqual 2 inside(parts(0)) { case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒ @@ -125,7 +124,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { wrs { complete(HttpEntity.Default(MediaTypes.`text/plain`, content.length, entityData())) } } ~> check { header[`Content-Range`] should be(None) - val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second) + val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second) parts.size shouldEqual 2 } }