=htc refactor HttpClient stream setup, closes #16510

This commit is contained in:
Mathias 2014-12-18 17:05:50 +01:00
parent 4d907bfb50
commit 968e9cc5a7
13 changed files with 687 additions and 127 deletions

View file

@ -5,16 +5,17 @@
package akka.http.engine.client package akka.http.engine.client
import java.net.InetSocketAddress import java.net.InetSocketAddress
import scala.collection.immutable.Queue import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import akka.stream.stage._
import akka.util.ByteString import akka.util.ByteString
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.stream.FlattenStrategy import akka.stream.FlattenStrategy
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.scaladsl.OperationAttributes._ import akka.stream.scaladsl.OperationAttributes._
import akka.http.model.{ HttpMethod, HttpRequest, HttpResponse } import akka.http.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse }
import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory } import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
import akka.http.engine.parsing.{ HttpHeaderParser, HttpResponseParser } import akka.http.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
import akka.http.engine.parsing.ParserOutput._
import akka.http.util._ import akka.http.util._
/** /**
@ -37,39 +38,212 @@ private[http] object HttpClient {
}) })
val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log) val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
val requestMethodByPass = new RequestMethodByPass(remoteAddress)
Flow[HttpRequest] /*
.map(requestMethodByPass) Basic Stream Setup
==================
requestIn +----------+
+-----------------------------------------------+--->| Termi- | requestRendering
| | nation +---------------------> |
+-------------------------------------->| Merge | |
| Termination Backchannel | +----------+ | TCP-
| | | level
| | Method | client
| +------------+ | Bypass | flow
responseOut | responsePrep | Response |<---+ |
<------------+----------------| Parsing | |
| Merge |<------------------------------------------ V
+------------+
*/
val requestIn = UndefinedSource[HttpRequest]
val responseOut = UndefinedSink[HttpResponse]
val methodBypassFanout = Broadcast[HttpRequest]
val responseParsingMerge = new ResponseParsingMerge(rootParser)
val terminationFanout = Broadcast[HttpResponse]
val terminationMerge = new TerminationMerge
val requestRendering = Flow[HttpRequest]
.map(RequestRenderingContext(_, remoteAddress))
.section(name("renderer"))(_.transform(() requestRendererFactory.newRenderer)) .section(name("renderer"))(_.transform(() requestRendererFactory.newRenderer))
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
val transportFlow = Flow[ByteString]
.section(name("errorLogger"))(_.transform(() errorLogger(log, "Outgoing request stream error"))) .section(name("errorLogger"))(_.transform(() errorLogger(log, "Outgoing request stream error")))
.via(transport) .via(transport)
.section(name("rootParser"))(_.transform(()
// each connection uses a single (private) response parser instance for all its responses val methodBypass = Flow[HttpRequest].map(_.method)
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy(requestMethodByPass))) import ParserOutput._
.splitWhen(_.isInstanceOf[MessageStart]) val responsePrep = Flow[List[ResponseOutput]]
.transform(recover { case x: ResponseParsingError x.error :: Nil }) // FIXME after #16565
.mapConcat(identityFunc)
.splitWhen(x x.isInstanceOf[MessageStart] || x == MessageEnd)
.headAndTail .headAndTail
.collect { .collect {
case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts) case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts)
HttpResponse(statusCode, headers, createEntity(entityParts), protocol) HttpResponse(statusCode, headers, createEntity(entityParts), protocol)
case (MessageStartError(_, info), _) throw IllegalResponseException(info)
} }
import FlowGraphImplicits._
Flow() { implicit b
requestIn ~> methodBypassFanout ~> terminationMerge.requestInput ~> requestRendering ~> transportFlow ~>
responseParsingMerge.dataInput ~> responsePrep ~> terminationFanout ~> responseOut
methodBypassFanout ~> methodBypass ~> responseParsingMerge.methodBypassInput
terminationFanout ~> terminationMerge.terminationBackchannelInput
b.allowCycles()
requestIn -> responseOut
}
} }
// FIXME: refactor to a pure-stream design that allows us to get rid of this ad-hoc queue here // a simple merge stage that simply forwards its first input and ignores its second input
class RequestMethodByPass(serverAddress: InetSocketAddress) // (the terminationBackchannelInput), but applies a special completion handling
extends (HttpRequest RequestRenderingContext) with (() HttpMethod) { class TerminationMerge extends FlexiMerge[HttpRequest] {
private[this] var requestMethods = Queue.empty[HttpMethod] import FlexiMerge._
def apply(request: HttpRequest) = { val requestInput = createInputPort[HttpRequest]()
requestMethods = requestMethods.enqueue(request.method) val terminationBackchannelInput = createInputPort[HttpResponse]()
RequestRenderingContext(request, serverAddress)
def createMergeLogic() = new MergeLogic[HttpRequest] {
override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"TerminationMerge must have 2 connected inputs, was $inputCount")
Vector(requestInput, terminationBackchannelInput)
}
override def initialState = State[Any](ReadAny(requestInput, terminationBackchannelInput)) {
case (ctx, _, request: HttpRequest) { ctx.emit(request); SameState }
case _ SameState // simply drop all responses, we are only interested in the completion of the response input
}
override def initialCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `requestInput`) SameState
case (ctx, `terminationBackchannelInput`)
ctx.complete()
SameState
},
onError = defaultCompletionHandling.onError)
} }
def apply(): HttpMethod = }
if (requestMethods.nonEmpty) {
val method = requestMethods.head import ParserOutput._
requestMethods = requestMethods.tail
method /**
} else HttpResponseParser.NoMethod * A FlexiMerge that follows this logic:
* 1. Wait on the methodBypass for the method of the request corresponding to the next response to be received
* 2. Read from the dataInput until exactly one response has been fully received
* 3. Go back to 1.
*/
class ResponseParsingMerge(rootParser: HttpResponseParser) extends FlexiMerge[List[ResponseOutput]] {
import FlexiMerge._
val dataInput = createInputPort[ByteString]()
val methodBypassInput = createInputPort[HttpMethod]()
def createMergeLogic() = new MergeLogic[List[ResponseOutput]] {
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
val parser = rootParser.createShallowCopy()
var methodBypassCompleted = false
override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"ResponseParsingMerge must have 2 connected inputs, was $inputCount")
Vector(dataInput, methodBypassInput)
}
override val initialState: State[HttpMethod] =
State(Read(methodBypassInput)) {
case (ctx, _, method)
parser.setRequestMethodForNextResponse(method)
drainParser(parser.onPush(ByteString.empty), ctx,
onNeedNextMethod = () SameState,
onNeedMoreData = () {
ctx.changeCompletionHandling(responseReadingCompletionHandling)
responseReadingState
})
}
val responseReadingState: State[ByteString] =
State(Read(dataInput)) {
case (ctx, _, bytes)
drainParser(parser.onPush(bytes), ctx,
onNeedNextMethod = () {
if (methodBypassCompleted) {
ctx.complete()
SameState
} else {
ctx.changeCompletionHandling(initialCompletionHandling)
initialState
}
},
onNeedMoreData = () SameState)
}
@tailrec def drainParser(current: ResponseOutput, ctx: MergeLogicContext,
onNeedNextMethod: () State[_], onNeedMoreData: () State[_],
b: ListBuffer[ResponseOutput] = ListBuffer.empty): State[_] = {
def emit(output: List[ResponseOutput]): Unit = if (output.nonEmpty) ctx.emit(output)
current match {
case NeedNextRequestMethod
emit(b.result())
onNeedNextMethod()
case StreamEnd
emit(b.result())
ctx.complete()
SameState
case NeedMoreData
emit(b.result())
onNeedMoreData()
case x drainParser(parser.onPull(), ctx, onNeedNextMethod, onNeedMoreData, b += x)
}
}
override val initialCompletionHandling = CompletionHandling(
onComplete = (ctx, _) { ctx.complete(); SameState },
onError = defaultCompletionHandling.onError)
val responseReadingCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `methodBypassInput`)
methodBypassCompleted = true
SameState
case (ctx, `dataInput`)
if (parser.onUpstreamFinish()) {
ctx.complete()
} else {
// not pretty but because the FlexiMerge doesn't let us emit from here (#16565)
// we need to funnel the error through the error channel
ctx.error(new ResponseParsingError(parser.onPull().asInstanceOf[ErrorOutput]))
}
SameState
},
onError = defaultCompletionHandling.onError)
}
}
private class ResponseParsingError(val error: ErrorOutput) extends RuntimeException
// TODO: remove after #16394 is cleared
def recover[A, B >: A](pf: PartialFunction[Throwable, B]): () PushPullStage[A, B] = {
val stage = new PushPullStage[A, B] {
var recovery: Option[B] = None
def onPush(elem: A, ctx: Context[B]): Directive = ctx.push(elem)
def onPull(ctx: Context[B]): Directive = recovery match {
case None ctx.pull()
case Some(x) { recovery = null; ctx.push(x) }
case null ctx.finish()
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[B]): TerminationDirective =
if (pf isDefinedAt cause) {
recovery = Some(pf(cause))
ctx.absorbTermination()
} else super.onUpstreamFailure(cause, ctx)
}
() stage
} }
} }

View file

@ -10,7 +10,6 @@ import akka.parboiled2.CharUtils
import akka.util.ByteString import akka.util.ByteString
import akka.stream.scaladsl.Source import akka.stream.scaladsl.Source
import akka.stream.stage._ import akka.stream.stage._
import akka.http.Http.StreamException
import akka.http.model.parser.CharacterClasses import akka.http.model.parser.CharacterClasses
import akka.http.util._ import akka.http.util._
import akka.http.model._ import akka.http.model._
@ -22,21 +21,33 @@ import ParserOutput._
* INTERNAL API * INTERNAL API
*/ */
private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings, private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings,
val headerParser: HttpHeaderParser) val headerParser: HttpHeaderParser) { self
extends PushPullStage[ByteString, Output] {
import HttpMessageParser._ import HttpMessageParser._
import settings._ import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString StateResult) extends StateResult
private[this] val result = new ListBuffer[Output] private[this] val result = new ListBuffer[Output]
private[this] var state: ByteString StateResult = startNewMessage(_, 0) private[this] var state: ByteString StateResult = startNewMessage(_, 0)
private[this] var protocol: HttpProtocol = `HTTP/1.1` private[this] var protocol: HttpProtocol = `HTTP/1.1`
private[this] var completionHandling: CompletionHandling = CompletionOk private[this] var completionHandling: CompletionHandling = CompletionOk
private[this] var terminated = false private[this] var terminated = false
override def onPush(input: ByteString, ctx: Context[Output]): Directive = { def isTerminated = terminated
val stage: PushPullStage[ByteString, Output] =
new PushPullStage[ByteString, Output] {
def onPush(elem: ByteString, ctx: Context[Output]) = handleParserOutput(self.onPush(elem), ctx)
def onPull(ctx: Context[Output]) = handleParserOutput(self.onPull(), ctx)
private def handleParserOutput(output: Output, ctx: Context[Output]): Directive =
output match {
case StreamEnd ctx.finish()
case NeedMoreData ctx.pull()
case x ctx.push(x)
}
override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective =
if (self.onUpstreamFinish()) ctx.finish() else ctx.absorbTermination()
}
final def onPush(input: ByteString): Output = {
@tailrec def run(next: ByteString StateResult): StateResult = @tailrec def run(next: ByteString StateResult): StateResult =
(try next(input) (try next(input)
catch { catch {
@ -51,37 +62,32 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`") if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`")
run(state) run(state)
pushResultHeadAndFinishOrPull(ctx) onPull()
} }
def onPull(ctx: Context[Output]): Directive = pushResultHeadAndFinishOrPull(ctx) final def onPull(): Output =
def pushResultHeadAndFinishOrPull(ctx: Context[Output]): Directive =
if (result.nonEmpty) { if (result.nonEmpty) {
val head = result.head val head = result.head
result.remove(0) // faster than `ListBuffer::drop` result.remove(0) // faster than `ListBuffer::drop`
ctx.push(head) head
} else if (terminated) ctx.finish() else ctx.pull() } else if (terminated) StreamEnd else NeedMoreData
override def onUpstreamFinish(ctx: Context[Output]) = { final def onUpstreamFinish(): Boolean = {
completionHandling() match { completionHandling() match {
case Some(x) emit(x.asInstanceOf[Output]) case Some(x) emit(x)
case None // nothing to do case None // nothing to do
} }
terminated = true terminated = true
if (result.isEmpty) ctx.finish() else ctx.absorbTermination() result.isEmpty
} }
def startNewMessage(input: ByteString, offset: Int): StateResult = { protected final def startNewMessage(input: ByteString, offset: Int): StateResult = {
def _startNewMessage(input: ByteString, offset: Int): StateResult =
try parseMessage(input, offset)
catch { case NotEnoughDataException continue(input, offset)(_startNewMessage) }
if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError) if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError)
_startNewMessage(input, offset) try parseMessage(input, offset)
catch { case NotEnoughDataException continue(input, offset)(startNewMessage) }
} }
def parseMessage(input: ByteString, offset: Int): StateResult protected def parseMessage(input: ByteString, offset: Int): StateResult
def parseProtocol(input: ByteString, cursor: Int): Int = { def parseProtocol(input: ByteString, cursor: Int): Int = {
def c(ix: Int) = byteChar(input, cursor + ix) def c(ix: Int) = byteChar(input, cursor + ix)
@ -204,7 +210,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
val chunkBodyEnd = cursor + chunkSize val chunkBodyEnd = cursor + chunkSize
def result(terminatorLen: Int) = { def result(terminatorLen: Int) = {
emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension))) emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
trampoline(_ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage)) Trampoline(_ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
} }
byteChar(input, chunkBodyEnd) match { byteChar(input, chunkBodyEnd) match {
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2) case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' result(2)
@ -255,7 +261,6 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
state = next(_, 0) state = next(_, 0)
done() done()
} }
def trampoline(next: ByteString StateResult): StateResult = Trampoline(next)
def failMessageStart(summary: String): StateResult = failMessageStart(summary, "") def failMessageStart(summary: String): StateResult = failMessageStart(summary, "")
def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail) def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail)
@ -299,7 +304,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
transformData: Source[ByteString] Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = { transformData: Source[ByteString] Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = {
val data = entityParts.collect { val data = entityParts.collect {
case EntityPart(bytes) bytes case EntityPart(bytes) bytes
case EntityStreamError(info) throw new StreamException(info) case EntityStreamError(info) throw EntityStreamException(info)
} }
HttpEntity.Default(contentType(cth), contentLength, transformData(data)) HttpEntity.Default(contentType(cth), contentLength, transformData(data))
} }
@ -308,7 +313,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
transformChunks: Source[HttpEntity.ChunkStreamPart] Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = { transformChunks: Source[HttpEntity.ChunkStreamPart] Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = {
val chunks = entityChunks.collect { val chunks = entityChunks.collect {
case EntityChunk(chunk) chunk case EntityChunk(chunk) chunk
case EntityStreamError(info) throw new StreamException(info) case EntityStreamError(info) throw EntityStreamException(info)
} }
HttpEntity.Chunked(contentType(cth), transformChunks(chunks)) HttpEntity.Chunked(contentType(cth), transformChunks(chunks))
} }
@ -324,7 +329,10 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
} }
private[http] object HttpMessageParser { private[http] object HttpMessageParser {
type CompletionHandling = () Option[ParserOutput] sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString StateResult) extends StateResult
type CompletionHandling = () Option[ErrorOutput]
val CompletionOk: CompletionHandling = () None val CompletionOk: CompletionHandling = () None
val CompletionIsMessageStartError: CompletionHandling = val CompletionIsMessageStartError: CompletionHandling =
() Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start"))) () Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start")))

View file

@ -11,9 +11,9 @@ import akka.stream.scaladsl.OperationAttributes._
import akka.stream.stage.{ Context, PushPullStage } import akka.stream.stage.{ Context, PushPullStage }
import akka.stream.scaladsl.Source import akka.stream.scaladsl.Source
import akka.util.ByteString import akka.util.ByteString
import akka.http.engine.server.OneHundredContinue
import akka.http.model.parser.CharacterClasses import akka.http.model.parser.CharacterClasses
import akka.http.util.identityFunc import akka.http.util.identityFunc
import akka.http.engine.TokenSourceActor
import akka.http.model._ import akka.http.model._
import headers._ import headers._
import StatusCodes._ import StatusCodes._
@ -27,6 +27,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
_headerParser: HttpHeaderParser, _headerParser: HttpHeaderParser,
oneHundredContinueRef: () Option[ActorRef] = () None) oneHundredContinueRef: () Option[ActorRef] = () None)
extends HttpMessageParser[RequestOutput](_settings, _headerParser) { extends HttpMessageParser[RequestOutput](_settings, _headerParser) {
import HttpMessageParser._
import settings._ import settings._
private[this] var method: HttpMethod = _ private[this] var method: HttpMethod = _
@ -105,7 +106,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
uriBytes = input.iterator.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here? uriBytes = input.iterator.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here?
uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode) uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode)
} catch { } catch {
case e: IllegalUriException throw new ParsingException(BadRequest, e.info) case IllegalUriException(info) throw new ParsingException(BadRequest, info)
} }
uriEnd + 1 uriEnd + 1
} }
@ -133,7 +134,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
def onPull(ctx: Context[T]) = { def onPull(ctx: Context[T]) = {
if (!oneHundredContinueSent) { if (!oneHundredContinueSent) {
val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable")) val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable"))
ref ! OneHundredContinue ref ! TokenSourceActor.Trigger
oneHundredContinueSent = true oneHundredContinueSent = true
} }
ctx.pull() ctx.pull()

View file

@ -10,38 +10,41 @@ import akka.stream.scaladsl.Source
import akka.util.ByteString import akka.util.ByteString
import akka.http.model._ import akka.http.model._
import headers._ import headers._
import HttpResponseParser.NoMethod
import ParserOutput._ import ParserOutput._
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] class HttpResponseParser(_settings: ParserSettings, private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: HttpHeaderParser)
_headerParser: HttpHeaderParser,
dequeueRequestMethodForNextResponse: () HttpMethod = () NoMethod)
extends HttpMessageParser[ResponseOutput](_settings, _headerParser) { extends HttpMessageParser[ResponseOutput](_settings, _headerParser) {
import HttpMessageParser._
import settings._ import settings._
private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod private[this] var requestMethodForCurrentResponse: Option[HttpMethod] = None
private[this] var statusCode: StatusCode = StatusCodes.OK private[this] var statusCode: StatusCode = StatusCodes.OK
def createShallowCopy(dequeueRequestMethodForNextResponse: () HttpMethod): HttpResponseParser = def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy())
new HttpResponseParser(settings, headerParser.createShallowCopy(), dequeueRequestMethodForNextResponse)
override def startNewMessage(input: ByteString, offset: Int): StateResult = { def setRequestMethodForNextResponse(requestMethod: HttpMethod): Unit =
requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse() if (requestMethodForCurrentResponse.isEmpty) requestMethodForCurrentResponse = Some(requestMethod)
super.startNewMessage(input, offset)
}
def parseMessage(input: ByteString, offset: Int): StateResult = protected def parseMessage(input: ByteString, offset: Int): StateResult =
if (requestMethodForCurrentResponse ne NoMethod) { if (requestMethodForCurrentResponse.isDefined) {
var cursor = parseProtocol(input, offset) var cursor = parseProtocol(input, offset)
if (byteChar(input, cursor) == ' ') { if (byteChar(input, cursor) == ' ') {
cursor = parseStatusCode(input, cursor + 1) cursor = parseStatusCode(input, cursor + 1)
cursor = parseReason(input, cursor)() cursor = parseReason(input, cursor)()
parseHeaderLines(input, cursor) parseHeaderLines(input, cursor)
} else badProtocol } else badProtocol
} else failMessageStart("Unexpected server response", input.drop(offset).utf8String) } else {
emit(NeedNextRequestMethod)
done()
}
override def emit(output: ResponseOutput): Unit = {
if (output == MessageEnd) requestMethodForCurrentResponse = None
super.emit(output)
}
def badProtocol = throw new ParsingException("The server-side HTTP version is not supported") def badProtocol = throw new ParsingException("The server-side HTTP version is not supported")
@ -81,10 +84,11 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
def finishEmptyResponse() = { def finishEmptyResponse() = {
emitResponseStart(emptyEntity(cth)) emitResponseStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk) setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart) startNewMessage(input, bodyStart)
} }
if (statusCode.allowsEntity && (requestMethodForCurrentResponse ne HttpMethods.HEAD)) { if (statusCode.allowsEntity && (requestMethodForCurrentResponse.get != HttpMethods.HEAD)) {
teh match { teh match {
case None clh match { case None clh match {
case Some(`Content-Length`(contentLength)) case Some(`Content-Length`(contentLength))
@ -95,6 +99,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
val cl = contentLength.toInt val cl = contentLength.toInt
emitResponseStart(strictEntity(cth, input, bodyStart, cl)) emitResponseStart(strictEntity(cth, input, bodyStart, cl))
setCompletionHandling(HttpMessageParser.CompletionOk) setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart + cl) startNewMessage(input, bodyStart + cl)
} else { } else {
emitResponseStart(defaultEntity(cth, contentLength)) emitResponseStart(defaultEntity(cth, contentLength))
@ -128,11 +133,4 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
emit(EntityPart(input drop bodyStart)) emit(EntityPart(input drop bodyStart))
continue(parseToCloseBody) continue(parseToCloseBody)
} }
}
/**
* INTERNAL API
*/
private[http] object HttpResponseParser {
val NoMethod = HttpMethod.custom("NONE", safe = false, idempotent = false, entityAccepted = false)
} }

View file

@ -21,6 +21,7 @@ private[http] object ParserOutput {
sealed trait ResponseOutput extends ParserOutput sealed trait ResponseOutput extends ParserOutput
sealed trait MessageStart extends ParserOutput sealed trait MessageStart extends ParserOutput
sealed trait MessageOutput extends RequestOutput with ResponseOutput sealed trait MessageOutput extends RequestOutput with ResponseOutput
sealed trait ErrorOutput extends MessageOutput
final case class RequestStart( final case class RequestStart(
method: HttpMethod, method: HttpMethod,
@ -44,7 +45,15 @@ private[http] object ParserOutput {
final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput
final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with ErrorOutput
final case class EntityStreamError(info: ErrorInfo) extends MessageOutput final case class EntityStreamError(info: ErrorInfo) extends ErrorOutput
//////////// meta messages ///////////
case object StreamEnd extends MessageOutput
case object NeedMoreData extends MessageOutput
case object NeedNextRequestMethod extends ResponseOutput
} }

View file

@ -55,31 +55,25 @@ package object util {
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
} }
private[http] implicit class EnhancedSource[T](val underlying: Source[T]) { def printEvent[T](marker: String): Flow[T, T] =
def printEvent(marker: String): Source[T] = Flow[T].transform(() new PushStage[T, T] {
underlying.transform(() new PushStage[T, T] { override def onPush(element: T, ctx: Context[T]): Directive = {
override def onPush(element: T, ctx: Context[T]): Directive = { println(s"$marker: $element")
println(s"$marker: $element") ctx.push(element)
ctx.push(element) }
} override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { println(s"$marker: Error $cause")
println(s"$marker: Failure $cause") super.onUpstreamFailure(cause, ctx)
super.onUpstreamFailure(cause, ctx) }
} override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = {
override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { println(s"$marker: Complete")
println(s"$marker: Terminated") super.onUpstreamFinish(ctx)
super.onUpstreamFinish(ctx) }
} override def onDownstreamFinish(ctx: Context[T]): TerminationDirective = {
}) println(s"$marker: Cancel")
super.onDownstreamFinish(ctx)
/** }
* Drain this stream into a Vector and provide it as a future value. })
*
* FIXME: Should be part of akka-streams
*/
def collectAll(implicit materializer: FlowMaterializer): Future[immutable.Seq[T]] =
underlying.fold(Vector.empty[T])(_ :+ _)
}
private[http] implicit class AddFutureAwaitResult[T](future: Future[T]) { private[http] implicit class AddFutureAwaitResult[T](future: Future[T]) {
/** "Safe" Await.result that doesn't throw away half of the stacktrace */ /** "Safe" Await.result that doesn't throw away half of the stacktrace */

View file

@ -36,7 +36,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
implicit val materializer = FlowMaterializer() implicit val materializer = FlowMaterializer()
"The server-side HTTP infrastructure" should { "The low-level HTTP infrastructure" should {
"properly bind a server" in { "properly bind a server" in {
val (hostname, port) = temporaryServerHostnameAndPort() val (hostname, port) = temporaryServerHostnameAndPort()
@ -70,6 +70,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
val (serverIn, serverOut) = acceptConnection() val (serverIn, serverOut) = acceptConnection()
val clientOutSub = clientOut.expectSubscription() val clientOutSub = clientOut.expectSubscription()
clientOutSub.expectRequest()
clientOutSub.sendNext(HttpRequest(uri = "/abc")) clientOutSub.sendNext(HttpRequest(uri = "/abc"))
val serverInSub = serverIn.expectSubscription() val serverInSub = serverIn.expectSubscription()
@ -77,12 +78,20 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc") serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc")
val serverOutSub = serverOut.expectSubscription() val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(entity = "yeah")) serverOutSub.sendNext(HttpResponse(entity = "yeah"))
val clientInSub = clientIn.expectSubscription() val clientInSub = clientIn.expectSubscription()
clientInSub.request(1) clientInSub.request(1)
val response = clientIn.expectNext() val response = clientIn.expectNext()
toStrict(response.entity) shouldEqual HttpEntity("yeah") toStrict(response.entity) shouldEqual HttpEntity("yeah")
clientOutSub.sendComplete()
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
} }
"properly complete a chunked request/response cycle" in new TestSetup { "properly complete a chunked request/response cycle" in new TestSetup {
@ -104,6 +113,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
Await.result(chunkStream.grouped(4).runWith(Sink.head), 100.millis) shouldEqual chunks Await.result(chunkStream.grouped(4).runWith(Sink.head), 100.millis) shouldEqual chunks
val serverOutSub = serverOut.expectSubscription() val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(206, List(RawHeader("Age", "42")), chunkedEntity)) serverOutSub.sendNext(HttpResponse(206, List(RawHeader("Age", "42")), chunkedEntity))
val clientInSub = clientIn.expectSubscription() val clientInSub = clientIn.expectSubscription()
@ -111,8 +121,42 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
val HttpResponse(StatusCodes.PartialContent, List(RawHeader("Age", "42"), Server(_), Date(_)), val HttpResponse(StatusCodes.PartialContent, List(RawHeader("Age", "42"), Server(_), Date(_)),
Chunked(`chunkedContentType`, chunkStream2), HttpProtocols.`HTTP/1.1`) = clientIn.expectNext() Chunked(`chunkedContentType`, chunkStream2), HttpProtocols.`HTTP/1.1`) = clientIn.expectNext()
Await.result(chunkStream2.grouped(1000).runWith(Sink.head), 100.millis) shouldEqual chunks Await.result(chunkStream2.grouped(1000).runWith(Sink.head), 100.millis) shouldEqual chunks
clientOutSub.sendComplete()
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
} }
"be able to deal with eager closing of the request stream on the client side" in new TestSetup {
val (clientOut, clientIn) = openNewClientConnection()
val (serverIn, serverOut) = acceptConnection()
val clientOutSub = clientOut.expectSubscription()
clientOutSub.sendNext(HttpRequest(uri = "/abc"))
clientOutSub.sendComplete() // complete early
val serverInSub = serverIn.expectSubscription()
serverInSub.request(1)
serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc")
val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(entity = "yeah"))
val clientInSub = clientIn.expectSubscription()
clientInSub.request(1)
val response = clientIn.expectNext()
toStrict(response.entity) shouldEqual HttpEntity("yeah")
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
}
} }
override def afterAll() = system.shutdown() override def afterAll() = system.shutdown()

View file

@ -0,0 +1,330 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.client
import java.net.InetSocketAddress
import org.scalatest.Inside
import akka.util.ByteString
import akka.event.NoLogging
import akka.stream.FlowMaterializer
import akka.stream.testkit.{ AkkaSpec, StreamTestKit }
import akka.stream.scaladsl._
import akka.http.model.HttpEntity._
import akka.http.model.HttpMethods._
import akka.http.model._
import akka.http.model.headers._
import akka.http.util._
class HttpClientSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside {
implicit val materializer = FlowMaterializer()
"The client implementation" should {
"properly handle a request/response round-trip" which {
"has a request with empty entity" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"has a request with default entity" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendNext(ByteString("XY"))
expectWireData("XY")
sub.sendComplete()
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"has a response with a default entity" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Transfer-Encoding: chunked
|
|""")
responsesSub.request(1)
val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext()
ct shouldEqual ContentTypes.`application/octet-stream`
val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]()
chunks.runWith(Sink(probe))
val sub = probe.expectSubscription()
sendWireData("3\nABC\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("ABC"))
sendWireData("4\nDEFX\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("DEFX"))
sendWireData("0\n\n")
sub.request(1)
probe.expectNext(HttpEntity.LastChunk)
probe.expectComplete()
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"exhibits eager request stream completion" in new TestSetup {
requestsSub.sendNext(HttpRequest())
requestsSub.sendComplete()
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
}
"produce proper errors" which {
"catch the entity stream being shorter than the Content-Length" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendComplete()
val InvalidContentLengthException(info) = netOut.expectError()
info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to 2 bytes less"
netInSub.sendComplete()
responses.expectComplete()
netInSub.expectCancellation()
}
"catch the entity stream being longer than the Content-Length" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendNext(ByteString("XYZ"))
val InvalidContentLengthException(info) = netOut.expectError()
info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to more bytes"
netInSub.sendComplete()
responses.expectComplete()
netInSub.expectCancellation()
}
"catch illegal response starts" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.2 200 OK
|
|""")
val error @ IllegalResponseException(info) = responses.expectError()
info.summary shouldEqual "The server-side HTTP version is not supported"
netOut.expectError(error)
requestsSub.expectCancellation()
}
"catch illegal response chunks" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Transfer-Encoding: chunked
|
|""")
responsesSub.request(1)
val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext()
ct shouldEqual ContentTypes.`application/octet-stream`
val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]()
chunks.runWith(Sink(probe))
val sub = probe.expectSubscription()
sendWireData("3\nABC\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("ABC"))
sendWireData("4\nDEFXX")
sub.request(1)
val error @ EntityStreamException(info) = probe.expectError()
info.summary shouldEqual "Illegal chunk termination"
responses.expectComplete()
netOut.expectComplete()
requestsSub.expectCancellation()
}
"catch a response start truncation" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData("HTTP/1.1 200 OK")
netInSub.sendComplete()
val error @ IllegalResponseException(info) = responses.expectError()
info.summary shouldEqual "Illegal HTTP message start"
netOut.expectError(error)
requestsSub.expectCancellation()
}
}
}
class TestSetup {
val requests = StreamTestKit.PublisherProbe[HttpRequest]
val responses = StreamTestKit.SubscriberProbe[HttpResponse]
val remoteAddress = new InetSocketAddress("example.com", 80)
def settings = ClientConnectionSettings(system)
.copy(userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test")))))
val (netOut, netIn) = {
val netOut = StreamTestKit.SubscriberProbe[ByteString]
val netIn = StreamTestKit.PublisherProbe[ByteString]
val clientFlow = HttpClient.transportToConnectionClientFlow(
Flow(Sink(netOut), Source(netIn)), remoteAddress, settings, NoLogging)
Source(requests).via(clientFlow).runWith(Sink(responses))
netOut -> netIn
}
def wipeDate(string: String) =
string.fastSplit('\n').map {
case s if s.startsWith("Date:") "Date: XXXX\r"
case s s
}.mkString("\n")
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val requestsSub = requests.expectSubscription()
val responsesSub = responses.expectSubscription()
def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII"))
def sendWireData(data: ByteString): Unit = netInSub.sendNext(data)
def expectWireData(s: String) = {
netOutSub.request(1)
netOut.expectNext().utf8String shouldEqual s.stripMarginWithNewline("\r\n")
}
def closeNetworkInput(): Unit = netInSub.sendComplete()
}
}

View file

@ -120,8 +120,9 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|Host: x |Host: x
| |
|ABCDPATCH""" |ABCDPATCH"""
}.toCharArray.map(_.toString).toSeq should rawMultiParseTo( }.toCharArray.map(_.toString).toSeq should generalRawMultiParseTo(
HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes)) Right(HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes)),
Left(MessageStartError(400, ErrorInfo("Illegal HTTP message start"))))
closeAfterResponseCompletion shouldEqual Seq(false) closeAfterResponseCompletion shouldEqual Seq(false)
} }
@ -232,7 +233,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val parser = newParser val parser = newParser
val result = multiParse(newParser)(Seq(prep(start + manyChunks))) val result = multiParse(newParser)(Seq(prep(start + manyChunks)))
val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity
val strictChunks = chunks.collectAll.awaitResult(awaitAtMost) val strictChunks = chunks.grouped(100000).runWith(Sink.head).awaitResult(awaitAtMost)
strictChunks.size shouldEqual numChunks strictChunks.size shouldEqual numChunks
} }
} }
@ -442,7 +443,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] = def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] =
Source(input.toList) Source(input.toList)
.map(ByteString.apply) .map(ByteString.apply)
.section(name("parser"))(_.transform(() parser)) .section(name("parser"))(_.transform(() parser.stage))
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.headAndTail .headAndTail
.collect { .collect {
@ -461,7 +462,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
.map(strictEqualify) .map(strictEqualify)
.collectAll .grouped(100000).runWith(Sink.head)
.awaitResult(awaitAtMost) .awaitResult(awaitAtMost)
protected def parserSettings: ParserSettings = ParserSettings(system) protected def parserSettings: ParserSettings = ParserSettings(system)
@ -474,7 +475,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] = private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
data.collectAll data.grouped(100000).runWith(Sink.head)
.fast.recover { case _: NoSuchElementException Nil } .fast.recover { case _: NoSuchElementException Nil }
def prep(response: String) = response.stripMarginWithNewline("\r\n") def prep(response: String) = response.stripMarginWithNewline("\r\n")

View file

@ -261,7 +261,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val future = val future =
Source(input.toList) Source(input.toList)
.map(ByteString.apply) .map(ByteString.apply)
.section(name("parser"))(_.transform(() newParser(requestMethod))) .section(name("parser"))(_.transform(() newParserStage(requestMethod)))
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.headAndTail .headAndTail
.collect { .collect {
@ -279,14 +279,16 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
.flatten(FlattenStrategy.concat) .flatten(FlattenStrategy.concat)
.map(strictEqualify) .map(strictEqualify)
.grouped(1000).runWith(Sink.head) .grouped(100000).runWith(Sink.head)
Await.result(future, 500.millis) Await.result(future, 500.millis)
} }
def parserSettings: ParserSettings = ParserSettings(system) def parserSettings: ParserSettings = ParserSettings(system)
def newParser(requestMethod: HttpMethod = GET) = {
val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)(), () requestMethod) def newParserStage(requestMethod: HttpMethod = GET) = {
parser val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)())
parser.setRequestMethodForNextResponse(requestMethod)
parser.stage
} }
private def compactEntity(entity: ResponseEntity): Future[ResponseEntity] = private def compactEntity(entity: ResponseEntity): Future[ResponseEntity] =
@ -296,7 +298,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
} }
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Source[ChunkStreamPart]] = private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Source[ChunkStreamPart]] =
data.grouped(1000).runWith(Sink.head) data.grouped(100000).runWith(Sink.head)
.fast.map(source(_: _*)) .fast.map(source(_: _*))
.fast.recover { case _: NoSuchElementException source() } .fast.recover { case _: NoSuchElementException source() }

View file

@ -5,7 +5,7 @@
package akka.http.engine.server package akka.http.engine.server
import scala.concurrent.duration._ import scala.concurrent.duration._
import org.scalatest.{ Inside, BeforeAndAfterAll, Matchers } import org.scalatest.Inside
import akka.event.NoLogging import akka.event.NoLogging
import akka.util.ByteString import akka.util.ByteString
import akka.stream.scaladsl._ import akka.stream.scaladsl._
@ -18,7 +18,7 @@ import HttpEntity._
import MediaTypes._ import MediaTypes._
import HttpMethods._ import HttpMethods._
class HttpServerSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside { class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside {
implicit val materializer = FlowMaterializer() implicit val materializer = FlowMaterializer()
"The server implementation" should { "The server implementation" should {

View file

@ -96,6 +96,6 @@ trait RouteTestResultComponent {
failTest("Request was neither completed nor rejected within " + timeout) failTest("Request was neither completed nor rejected within " + timeout)
private def awaitAllElements[T](data: Source[T]): immutable.Seq[T] = private def awaitAllElements[T](data: Source[T]): immutable.Seq[T] =
data.collectAll.awaitResult(timeout) data.grouped(100000).runWith(Sink.head).awaitResult(timeout)
} }
} }

View file

@ -5,17 +5,16 @@
package akka.http.server package akka.http.server
package directives package directives
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.http.model.StatusCodes._ import akka.http.model.StatusCodes._
import akka.http.model._ import akka.http.model._
import akka.http.model.headers._ import akka.http.model.headers._
import akka.http.util._ import akka.http.util._
import akka.stream.scaladsl.Source import akka.stream.scaladsl.{ Sink, Source }
import akka.util.ByteString import akka.util.ByteString
import org.scalatest.{ Inside, Inspectors } import org.scalatest.{ Inside, Inspectors }
import scala.concurrent.Await
import scala.concurrent.duration._
class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside { class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
lazy val wrs = lazy val wrs =
mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) & mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) &
@ -100,7 +99,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
wrs { complete("Some random and not super short entity.") } wrs { complete("Some random and not super short entity.") }
} ~> check { } ~> check {
header[`Content-Range`] should be(None) header[`Content-Range`] should be(None)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second) val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second)
parts.size shouldEqual 2 parts.size shouldEqual 2
inside(parts(0)) { inside(parts(0)) {
case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) case Multipart.ByteRanges.BodyPart(range, entity, unit, headers)
@ -125,7 +124,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
wrs { complete(HttpEntity.Default(MediaTypes.`text/plain`, content.length, entityData())) } wrs { complete(HttpEntity.Default(MediaTypes.`text/plain`, content.length, entityData())) }
} ~> check { } ~> check {
header[`Content-Range`] should be(None) header[`Content-Range`] should be(None)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second) val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second)
parts.size shouldEqual 2 parts.size shouldEqual 2
} }
} }