=htc refactor HttpClient stream setup, closes #16510

This commit is contained in:
Mathias 2014-12-18 17:05:50 +01:00
parent 4d907bfb50
commit 968e9cc5a7
13 changed files with 687 additions and 127 deletions

View file

@ -5,16 +5,17 @@
package akka.http.engine.client
import java.net.InetSocketAddress
import scala.collection.immutable.Queue
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import akka.stream.stage._
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream.FlattenStrategy
import akka.stream.scaladsl._
import akka.stream.scaladsl.OperationAttributes._
import akka.http.model.{ HttpMethod, HttpRequest, HttpResponse }
import akka.http.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse }
import akka.http.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
import akka.http.engine.parsing.{ HttpHeaderParser, HttpResponseParser }
import akka.http.engine.parsing.ParserOutput._
import akka.http.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
import akka.http.util._
/**
@ -37,39 +38,212 @@ private[http] object HttpClient {
})
val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
val requestMethodByPass = new RequestMethodByPass(remoteAddress)
Flow[HttpRequest]
.map(requestMethodByPass)
/*
Basic Stream Setup
==================
requestIn +----------+
+-----------------------------------------------+--->| Termi- | requestRendering
| | nation +---------------------> |
+-------------------------------------->| Merge | |
| Termination Backchannel | +----------+ | TCP-
| | | level
| | Method | client
| +------------+ | Bypass | flow
responseOut | responsePrep | Response |<---+ |
<------------+----------------| Parsing | |
| Merge |<------------------------------------------ V
+------------+
*/
val requestIn = UndefinedSource[HttpRequest]
val responseOut = UndefinedSink[HttpResponse]
val methodBypassFanout = Broadcast[HttpRequest]
val responseParsingMerge = new ResponseParsingMerge(rootParser)
val terminationFanout = Broadcast[HttpResponse]
val terminationMerge = new TerminationMerge
val requestRendering = Flow[HttpRequest]
.map(RequestRenderingContext(_, remoteAddress))
.section(name("renderer"))(_.transform(() requestRendererFactory.newRenderer))
.flatten(FlattenStrategy.concat)
val transportFlow = Flow[ByteString]
.section(name("errorLogger"))(_.transform(() errorLogger(log, "Outgoing request stream error")))
.via(transport)
.section(name("rootParser"))(_.transform(()
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy(requestMethodByPass)))
.splitWhen(_.isInstanceOf[MessageStart])
val methodBypass = Flow[HttpRequest].map(_.method)
import ParserOutput._
val responsePrep = Flow[List[ResponseOutput]]
.transform(recover { case x: ResponseParsingError ⇒ x.error :: Nil }) // FIXME after #16565
.mapConcat(identityFunc)
.splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd)
.headAndTail
.collect {
case (ResponseStart(statusCode, protocol, headers, createEntity, _), entityParts) ⇒
HttpResponse(statusCode, headers, createEntity(entityParts), protocol)
case (MessageStartError(_, info), _) ⇒ throw IllegalResponseException(info)
}
import FlowGraphImplicits._
Flow() { implicit b ⇒
requestIn ~> methodBypassFanout ~> terminationMerge.requestInput ~> requestRendering ~> transportFlow ~>
responseParsingMerge.dataInput ~> responsePrep ~> terminationFanout ~> responseOut
methodBypassFanout ~> methodBypass ~> responseParsingMerge.methodBypassInput
terminationFanout ~> terminationMerge.terminationBackchannelInput
b.allowCycles()
requestIn -> responseOut
}
}
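Taken together, the graph above yields a Flow[HttpRequest, HttpResponse] joined over the TCP-level transport. A minimal usage sketch, mirroring the TestSetup in HttpClientSpec further down; the transport flow, probes, settings and log are assumed to be supplied by the caller:

// transport: Flow[ByteString, ByteString] representing one TCP connection
val clientFlow = HttpClient.transportToConnectionClientFlow(transport, remoteAddress, settings, log)
// feed requests in, collect responses out
Source(requestsProbe).via(clientFlow).runWith(Sink(responsesProbe))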
// FIXME: refactor to a pure-stream design that allows us to get rid of this ad-hoc queue here
class RequestMethodByPass(serverAddress: InetSocketAddress)
extends (HttpRequest ⇒ RequestRenderingContext) with (() ⇒ HttpMethod) {
private[this] var requestMethods = Queue.empty[HttpMethod]
def apply(request: HttpRequest) = {
requestMethods = requestMethods.enqueue(request.method)
RequestRenderingContext(request, serverAddress)
// a merge stage that forwards its first input (the requests) and ignores its second input
// (the terminationBackchannelInput), but applies special completion handling
class TerminationMerge extends FlexiMerge[HttpRequest] {
import FlexiMerge._
val requestInput = createInputPort[HttpRequest]()
val terminationBackchannelInput = createInputPort[HttpResponse]()
def createMergeLogic() = new MergeLogic[HttpRequest] {
override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"TerminationMerge must have 2 connected inputs, was $inputCount")
Vector(requestInput, terminationBackchannelInput)
}
override def initialState = State[Any](ReadAny(requestInput, terminationBackchannelInput)) {
case (ctx, _, request: HttpRequest) ⇒ { ctx.emit(request); SameState }
case _ ⇒ SameState // simply drop all responses, we are only interested in the completion of the response input
}
override def initialCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `requestInput`) ⇒ SameState
case (ctx, `terminationBackchannelInput`) ⇒
ctx.complete()
SameState
},
onError = defaultCompletionHandling.onError)
}
def apply(): HttpMethod =
if (requestMethods.nonEmpty) {
val method = requestMethods.head
requestMethods = requestMethods.tail
method
} else HttpResponseParser.NoMethod
}
import ParserOutput._
/**
* A FlexiMerge that follows this logic:
* 1. Wait on the methodBypass for the method of the request corresponding to the next response to be received
* 2. Read from the dataInput until exactly one response has been fully received
* 3. Go back to 1.
*/
class ResponseParsingMerge(rootParser: HttpResponseParser) extends FlexiMerge[List[ResponseOutput]] {
import FlexiMerge._
val dataInput = createInputPort[ByteString]()
val methodBypassInput = createInputPort[HttpMethod]()
def createMergeLogic() = new MergeLogic[List[ResponseOutput]] {
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
val parser = rootParser.createShallowCopy()
var methodBypassCompleted = false
override def inputHandles(inputCount: Int) = {
require(inputCount == 2, s"ResponseParsingMerge must have 2 connected inputs, was $inputCount")
Vector(dataInput, methodBypassInput)
}
override val initialState: State[HttpMethod] =
State(Read(methodBypassInput)) {
case (ctx, _, method) ⇒
parser.setRequestMethodForNextResponse(method)
drainParser(parser.onPush(ByteString.empty), ctx,
onNeedNextMethod = () ⇒ SameState,
onNeedMoreData = () ⇒ {
ctx.changeCompletionHandling(responseReadingCompletionHandling)
responseReadingState
})
}
val responseReadingState: State[ByteString] =
State(Read(dataInput)) {
case (ctx, _, bytes) ⇒
drainParser(parser.onPush(bytes), ctx,
onNeedNextMethod = () ⇒ {
if (methodBypassCompleted) {
ctx.complete()
SameState
} else {
ctx.changeCompletionHandling(initialCompletionHandling)
initialState
}
},
onNeedMoreData = () ⇒ SameState)
}
@tailrec def drainParser(current: ResponseOutput, ctx: MergeLogicContext,
onNeedNextMethod: () ⇒ State[_], onNeedMoreData: () ⇒ State[_],
b: ListBuffer[ResponseOutput] = ListBuffer.empty): State[_] = {
def emit(output: List[ResponseOutput]): Unit = if (output.nonEmpty) ctx.emit(output)
current match {
case NeedNextRequestMethod ⇒
emit(b.result())
onNeedNextMethod()
case StreamEnd ⇒
emit(b.result())
ctx.complete()
SameState
case NeedMoreData ⇒
emit(b.result())
onNeedMoreData()
case x ⇒ drainParser(parser.onPull(), ctx, onNeedNextMethod, onNeedMoreData, b += x)
}
}
override val initialCompletionHandling = CompletionHandling(
onComplete = (ctx, _) ⇒ { ctx.complete(); SameState },
onError = defaultCompletionHandling.onError)
val responseReadingCompletionHandling = CompletionHandling(
onComplete = {
case (ctx, `methodBypassInput`) ⇒
methodBypassCompleted = true
SameState
case (ctx, `dataInput`) ⇒
if (parser.onUpstreamFinish()) {
ctx.complete()
} else {
// not pretty, but since the FlexiMerge doesn't let us emit from here (#16565)
// we need to funnel the error through the error channel
ctx.error(new ResponseParsingError(parser.onPull().asInstanceOf[ErrorOutput]))
}
SameState
},
onError = defaultCompletionHandling.onError)
}
}
private class ResponseParsingError(val error: ErrorOutput) extends RuntimeException
// TODO: remove after #16394 is cleared
def recover[A, B >: A](pf: PartialFunction[Throwable, B]): () ⇒ PushPullStage[A, B] = {
val stage = new PushPullStage[A, B] {
var recovery: Option[B] = None
def onPush(elem: A, ctx: Context[B]): Directive = ctx.push(elem)
def onPull(ctx: Context[B]): Directive = recovery match {
case None ⇒ ctx.pull()
case Some(x) ⇒ { recovery = null; ctx.push(x) }
case null ⇒ ctx.finish()
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[B]): TerminationDirective =
if (pf isDefinedAt cause) {
recovery = Some(pf(cause))
ctx.absorbTermination()
} else super.onUpstreamFailure(cause, ctx)
}
() ⇒ stage
}
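A short usage note for the recover helper above: the partial function's result is emitted as one final element after the upstream failure has been absorbed, and the stream then completes normally. A sketch with an illustrative element type, mirroring the responsePrep usage earlier in this object:

// hypothetical example: turn an upstream ArithmeticException into a sentinel element instead of a failure
val tolerant = Flow[Int].transform(recover[Int, Int] { case _: ArithmeticException ⇒ -1 })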
}

View file

@ -10,7 +10,6 @@ import akka.parboiled2.CharUtils
import akka.util.ByteString
import akka.stream.scaladsl.Source
import akka.stream.stage._
import akka.http.Http.StreamException
import akka.http.model.parser.CharacterClasses
import akka.http.util._
import akka.http.model._
@ -22,21 +21,33 @@ import ParserOutput._
* INTERNAL API
*/
private[http] abstract class HttpMessageParser[Output >: MessageOutput <: ParserOutput](val settings: ParserSettings,
val headerParser: HttpHeaderParser)
extends PushPullStage[ByteString, Output] {
val headerParser: HttpHeaderParser) { self ⇒
import HttpMessageParser._
import settings._
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult
private[this] val result = new ListBuffer[Output]
private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0)
private[this] var protocol: HttpProtocol = `HTTP/1.1`
private[this] var completionHandling: CompletionHandling = CompletionOk
private[this] var terminated = false
override def onPush(input: ByteString, ctx: Context[Output]): Directive = {
def isTerminated = terminated
val stage: PushPullStage[ByteString, Output] =
new PushPullStage[ByteString, Output] {
def onPush(elem: ByteString, ctx: Context[Output]) = handleParserOutput(self.onPush(elem), ctx)
def onPull(ctx: Context[Output]) = handleParserOutput(self.onPull(), ctx)
private def handleParserOutput(output: Output, ctx: Context[Output]): Directive =
output match {
case StreamEnd ⇒ ctx.finish()
case NeedMoreData ⇒ ctx.pull()
case x ⇒ ctx.push(x)
}
override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective =
if (self.onUpstreamFinish()) ctx.finish() else ctx.absorbTermination()
}
final def onPush(input: ByteString): Output = {
@tailrec def run(next: ByteString ⇒ StateResult): StateResult =
(try next(input)
catch {
@ -51,37 +62,32 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
if (result.nonEmpty) throw new IllegalStateException("Unexpected `onPush`")
run(state)
pushResultHeadAndFinishOrPull(ctx)
onPull()
}
def onPull(ctx: Context[Output]): Directive = pushResultHeadAndFinishOrPull(ctx)
def pushResultHeadAndFinishOrPull(ctx: Context[Output]): Directive =
final def onPull(): Output =
if (result.nonEmpty) {
val head = result.head
result.remove(0) // faster than `ListBuffer::drop`
ctx.push(head)
} else if (terminated) ctx.finish() else ctx.pull()
head
} else if (terminated) StreamEnd else NeedMoreData
override def onUpstreamFinish(ctx: Context[Output]) = {
final def onUpstreamFinish(): Boolean = {
completionHandling() match {
case Some(x) ⇒ emit(x.asInstanceOf[Output])
case Some(x) ⇒ emit(x)
case None ⇒ // nothing to do
}
terminated = true
if (result.isEmpty) ctx.finish() else ctx.absorbTermination()
result.isEmpty
}
def startNewMessage(input: ByteString, offset: Int): StateResult = {
def _startNewMessage(input: ByteString, offset: Int): StateResult =
try parseMessage(input, offset)
catch { case NotEnoughDataException ⇒ continue(input, offset)(_startNewMessage) }
protected final def startNewMessage(input: ByteString, offset: Int): StateResult = {
if (offset < input.length) setCompletionHandling(CompletionIsMessageStartError)
_startNewMessage(input, offset)
try parseMessage(input, offset)
catch { case NotEnoughDataException ⇒ continue(input, offset)(startNewMessage) }
}
def parseMessage(input: ByteString, offset: Int): StateResult
protected def parseMessage(input: ByteString, offset: Int): StateResult
def parseProtocol(input: ByteString, cursor: Int): Int = {
def c(ix: Int) = byteChar(input, cursor + ix)
@ -204,7 +210,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
val chunkBodyEnd = cursor + chunkSize
def result(terminatorLen: Int) = {
emit(EntityChunk(HttpEntity.Chunk(input.slice(cursor, chunkBodyEnd), extension)))
trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
Trampoline(_ ⇒ parseChunk(input, chunkBodyEnd + terminatorLen, isLastMessage))
}
byteChar(input, chunkBodyEnd) match {
case '\r' if byteChar(input, chunkBodyEnd + 1) == '\n' ⇒ result(2)
@ -255,7 +261,6 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
state = next(_, 0)
done()
}
def trampoline(next: ByteString ⇒ StateResult): StateResult = Trampoline(next)
def failMessageStart(summary: String): StateResult = failMessageStart(summary, "")
def failMessageStart(summary: String, detail: String): StateResult = failMessageStart(StatusCodes.BadRequest, summary, detail)
@ -299,7 +304,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
transformData: Source[ByteString] ⇒ Source[ByteString] = identityFunc)(entityParts: Source[_ <: ParserOutput]): UniversalEntity = {
val data = entityParts.collect {
case EntityPart(bytes) ⇒ bytes
case EntityStreamError(info) ⇒ throw new StreamException(info)
case EntityStreamError(info) ⇒ throw EntityStreamException(info)
}
HttpEntity.Default(contentType(cth), contentLength, transformData(data))
}
@ -308,7 +313,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
transformChunks: Source[HttpEntity.ChunkStreamPart] ⇒ Source[HttpEntity.ChunkStreamPart] = identityFunc)(entityChunks: Source[_ <: ParserOutput]): RequestEntity = {
val chunks = entityChunks.collect {
case EntityChunk(chunk) ⇒ chunk
case EntityStreamError(info) ⇒ throw new StreamException(info)
case EntityStreamError(info) ⇒ throw EntityStreamException(info)
}
HttpEntity.Chunked(contentType(cth), transformChunks(chunks))
}
@ -324,7 +329,10 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
}
private[http] object HttpMessageParser {
type CompletionHandling = () ⇒ Option[ParserOutput]
sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup
final case class Trampoline(f: ByteString ⇒ StateResult) extends StateResult
type CompletionHandling = () ⇒ Option[ErrorOutput]
val CompletionOk: CompletionHandling = () ⇒ None
val CompletionIsMessageStartError: CompletionHandling =
() ⇒ Some(ParserOutput.MessageStartError(StatusCodes.BadRequest, ErrorInfo("Illegal HTTP message start")))
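After this refactoring HttpMessageParser is a plain class: stream code wraps parser.stage into a transform section (as the parser specs below do), while non-stream code can drive onPush/onPull directly and interpret the NeedMoreData/StreamEnd meta outputs, as ResponseParsingMerge.drainParser does. A minimal direct-driving sketch; parser and bytes are assumed to be in scope and the response-specific NeedNextRequestMethod signal is ignored here:

// collect everything the parser can produce for one chunk of input
@tailrec def drain(out: ParserOutput, acc: List[ParserOutput] = Nil): List[ParserOutput] =
  out match {
    case StreamEnd | NeedMoreData ⇒ acc.reverse
    case x                        ⇒ drain(parser.onPull(), x :: acc)
  }
val outputs = drain(parser.onPush(bytes)) // bytes: ByteString holding (part of) a message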

View file

@ -11,9 +11,9 @@ import akka.stream.scaladsl.OperationAttributes._
import akka.stream.stage.{ Context, PushPullStage }
import akka.stream.scaladsl.Source
import akka.util.ByteString
import akka.http.engine.server.OneHundredContinue
import akka.http.model.parser.CharacterClasses
import akka.http.util.identityFunc
import akka.http.engine.TokenSourceActor
import akka.http.model._
import headers._
import StatusCodes._
@ -27,6 +27,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
_headerParser: HttpHeaderParser,
oneHundredContinueRef: () ⇒ Option[ActorRef] = () ⇒ None)
extends HttpMessageParser[RequestOutput](_settings, _headerParser) {
import HttpMessageParser._
import settings._
private[this] var method: HttpMethod = _
@ -105,7 +106,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
uriBytes = input.iterator.slice(uriStart, uriEnd).toArray[Byte] // TODO: can we reduce allocations here?
uri = Uri.parseHttpRequestTarget(uriBytes, mode = uriParsingMode)
} catch {
case e: IllegalUriException ⇒ throw new ParsingException(BadRequest, e.info)
case IllegalUriException(info) ⇒ throw new ParsingException(BadRequest, info)
}
uriEnd + 1
}
@ -133,7 +134,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
def onPull(ctx: Context[T]) = {
if (!oneHundredContinueSent) {
val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable"))
ref ! OneHundredContinue
ref ! TokenSourceActor.Trigger
oneHundredContinueSent = true
}
ctx.pull()

View file

@ -10,38 +10,41 @@ import akka.stream.scaladsl.Source
import akka.util.ByteString
import akka.http.model._
import headers._
import HttpResponseParser.NoMethod
import ParserOutput._
/**
* INTERNAL API
*/
private[http] class HttpResponseParser(_settings: ParserSettings,
_headerParser: HttpHeaderParser,
dequeueRequestMethodForNextResponse: () ⇒ HttpMethod = () ⇒ NoMethod)
private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: HttpHeaderParser)
extends HttpMessageParser[ResponseOutput](_settings, _headerParser) {
import HttpMessageParser._
import settings._
private[this] var requestMethodForCurrentResponse: HttpMethod = NoMethod
private[this] var requestMethodForCurrentResponse: Option[HttpMethod] = None
private[this] var statusCode: StatusCode = StatusCodes.OK
def createShallowCopy(dequeueRequestMethodForNextResponse: () ⇒ HttpMethod): HttpResponseParser =
new HttpResponseParser(settings, headerParser.createShallowCopy(), dequeueRequestMethodForNextResponse)
def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy())
override def startNewMessage(input: ByteString, offset: Int): StateResult = {
requestMethodForCurrentResponse = dequeueRequestMethodForNextResponse()
super.startNewMessage(input, offset)
}
def setRequestMethodForNextResponse(requestMethod: HttpMethod): Unit =
if (requestMethodForCurrentResponse.isEmpty) requestMethodForCurrentResponse = Some(requestMethod)
def parseMessage(input: ByteString, offset: Int): StateResult =
if (requestMethodForCurrentResponse ne NoMethod) {
protected def parseMessage(input: ByteString, offset: Int): StateResult =
if (requestMethodForCurrentResponse.isDefined) {
var cursor = parseProtocol(input, offset)
if (byteChar(input, cursor) == ' ') {
cursor = parseStatusCode(input, cursor + 1)
cursor = parseReason(input, cursor)()
parseHeaderLines(input, cursor)
} else badProtocol
} else failMessageStart("Unexpected server response", input.drop(offset).utf8String)
} else {
emit(NeedNextRequestMethod)
done()
}
override def emit(output: ResponseOutput): Unit = {
if (output == MessageEnd) requestMethodForCurrentResponse = None
super.emit(output)
}
def badProtocol = throw new ParsingException("The server-side HTTP version is not supported")
@ -81,10 +84,11 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
def finishEmptyResponse() = {
emitResponseStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart)
}
if (statusCode.allowsEntity && (requestMethodForCurrentResponse ne HttpMethods.HEAD)) {
if (statusCode.allowsEntity && (requestMethodForCurrentResponse.get != HttpMethods.HEAD)) {
teh match {
case None ⇒ clh match {
case Some(`Content-Length`(contentLength)) ⇒
@ -95,6 +99,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
val cl = contentLength.toInt
emitResponseStart(strictEntity(cth, input, bodyStart, cl))
setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart + cl)
} else {
emitResponseStart(defaultEntity(cth, contentLength))
@ -128,11 +133,4 @@ private[http] class HttpResponseParser(_settings: ParserSettings,
emit(EntityPart(input drop bodyStart))
continue(parseToCloseBody)
}
}
/**
* INTERNAL API
*/
private[http] object HttpResponseParser {
val NoMethod = HttpMethod.custom("NONE", safe = false, idempotent = false, entityAccepted = false)
}

View file

@ -21,6 +21,7 @@ private[http] object ParserOutput {
sealed trait ResponseOutput extends ParserOutput
sealed trait MessageStart extends ParserOutput
sealed trait MessageOutput extends RequestOutput with ResponseOutput
sealed trait ErrorOutput extends MessageOutput
final case class RequestStart(
method: HttpMethod,
@ -44,7 +45,15 @@ private[http] object ParserOutput {
final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput
final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with MessageOutput
final case class MessageStartError(status: StatusCode, info: ErrorInfo) extends MessageStart with ErrorOutput
final case class EntityStreamError(info: ErrorInfo) extends MessageOutput
final case class EntityStreamError(info: ErrorInfo) extends ErrorOutput
//////////// meta messages ///////////
case object StreamEnd extends MessageOutput
case object NeedMoreData extends MessageOutput
case object NeedNextRequestMethod extends ResponseOutput
}
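The new ErrorOutput marker lets parse failures travel as values (for example wrapped in ResponseParsingError above) rather than only as stream errors; a small illustrative handler, purely hypothetical and not part of the commit:

def describe(out: ResponseOutput): String = out match {
  case MessageStartError(status, info) ⇒ s"bad message start: $status: ${info.summary}"
  case EntityStreamError(info)         ⇒ s"entity stream error: ${info.summary}"
  case other                           ⇒ other.toString
}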

View file

@ -55,31 +55,25 @@ package object util {
.flatten(FlattenStrategy.concat)
}
private[http] implicit class EnhancedSource[T](val underlying: Source[T]) {
def printEvent(marker: String): Source[T] =
underlying.transform(() ⇒ new PushStage[T, T] {
override def onPush(element: T, ctx: Context[T]): Directive = {
println(s"$marker: $element")
ctx.push(element)
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
println(s"$marker: Failure $cause")
super.onUpstreamFailure(cause, ctx)
}
override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = {
println(s"$marker: Terminated")
super.onUpstreamFinish(ctx)
}
})
/**
* Drain this stream into a Vector and provide it as a future value.
*
* FIXME: Should be part of akka-streams
*/
def collectAll(implicit materializer: FlowMaterializer): Future[immutable.Seq[T]] =
underlying.fold(Vector.empty[T])(_ :+ _)
}
def printEvent[T](marker: String): Flow[T, T] =
Flow[T].transform(() ⇒ new PushStage[T, T] {
override def onPush(element: T, ctx: Context[T]): Directive = {
println(s"$marker: $element")
ctx.push(element)
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
println(s"$marker: Error $cause")
super.onUpstreamFailure(cause, ctx)
}
override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = {
println(s"$marker: Complete")
super.onUpstreamFinish(ctx)
}
override def onDownstreamFinish(ctx: Context[T]): TerminationDirective = {
println(s"$marker: Cancel")
super.onDownstreamFinish(ctx)
}
})
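Since printEvent is now a standalone Flow rather than a Source extension, it can be spliced into any point of a pipeline for debugging; a usage sketch with an illustrative marker string:

// print every element, completion, failure and cancellation passing this point
Flow[ByteString].via(printEvent[ByteString]("client-wire"))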
private[http] implicit class AddFutureAwaitResult[T](future: Future[T]) {
/** "Safe" Await.result that doesn't throw away half of the stacktrace */

View file

@ -36,7 +36,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
implicit val materializer = FlowMaterializer()
"The server-side HTTP infrastructure" should {
"The low-level HTTP infrastructure" should {
"properly bind a server" in {
val (hostname, port) = temporaryServerHostnameAndPort()
@ -70,6 +70,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
val (serverIn, serverOut) = acceptConnection()
val clientOutSub = clientOut.expectSubscription()
clientOutSub.expectRequest()
clientOutSub.sendNext(HttpRequest(uri = "/abc"))
val serverInSub = serverIn.expectSubscription()
@ -77,12 +78,20 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc")
val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(entity = "yeah"))
val clientInSub = clientIn.expectSubscription()
clientInSub.request(1)
val response = clientIn.expectNext()
toStrict(response.entity) shouldEqual HttpEntity("yeah")
clientOutSub.sendComplete()
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
}
"properly complete a chunked request/response cycle" in new TestSetup {
@ -104,6 +113,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
Await.result(chunkStream.grouped(4).runWith(Sink.head), 100.millis) shouldEqual chunks
val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(206, List(RawHeader("Age", "42")), chunkedEntity))
val clientInSub = clientIn.expectSubscription()
@ -111,8 +121,42 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll {
val HttpResponse(StatusCodes.PartialContent, List(RawHeader("Age", "42"), Server(_), Date(_)),
Chunked(`chunkedContentType`, chunkStream2), HttpProtocols.`HTTP/1.1`) = clientIn.expectNext()
Await.result(chunkStream2.grouped(1000).runWith(Sink.head), 100.millis) shouldEqual chunks
clientOutSub.sendComplete()
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
}
"be able to deal with eager closing of the request stream on the client side" in new TestSetup {
val (clientOut, clientIn) = openNewClientConnection()
val (serverIn, serverOut) = acceptConnection()
val clientOutSub = clientOut.expectSubscription()
clientOutSub.sendNext(HttpRequest(uri = "/abc"))
clientOutSub.sendComplete() // complete early
val serverInSub = serverIn.expectSubscription()
serverInSub.request(1)
serverIn.expectNext().uri shouldEqual Uri(s"http://$hostname:$port/abc")
val serverOutSub = serverOut.expectSubscription()
serverOutSub.expectRequest()
serverOutSub.sendNext(HttpResponse(entity = "yeah"))
val clientInSub = clientIn.expectSubscription()
clientInSub.request(1)
val response = clientIn.expectNext()
toStrict(response.entity) shouldEqual HttpEntity("yeah")
serverInSub.request(1) // work-around for #16552
serverIn.expectComplete()
serverOutSub.expectCancellation()
clientInSub.request(1) // work-around for #16552
clientIn.expectComplete()
}
}
override def afterAll() = system.shutdown()

View file

@ -0,0 +1,330 @@
/*
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.client
import java.net.InetSocketAddress
import org.scalatest.Inside
import akka.util.ByteString
import akka.event.NoLogging
import akka.stream.FlowMaterializer
import akka.stream.testkit.{ AkkaSpec, StreamTestKit }
import akka.stream.scaladsl._
import akka.http.model.HttpEntity._
import akka.http.model.HttpMethods._
import akka.http.model._
import akka.http.model.headers._
import akka.http.util._
class HttpClientSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside {
implicit val materializer = FlowMaterializer()
"The client implementation" should {
"properly handle a request/response round-trip" which {
"has a request with empty entity" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"has a request with default entity" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendNext(ByteString("XY"))
expectWireData("XY")
sub.sendComplete()
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"has a response with a default entity" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Transfer-Encoding: chunked
|
|""")
responsesSub.request(1)
val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext()
ct shouldEqual ContentTypes.`application/octet-stream`
val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]()
chunks.runWith(Sink(probe))
val sub = probe.expectSubscription()
sendWireData("3\nABC\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("ABC"))
sendWireData("4\nDEFX\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("DEFX"))
sendWireData("0\n\n")
sub.request(1)
probe.expectNext(HttpEntity.LastChunk)
probe.expectComplete()
requestsSub.sendComplete()
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
"exhibits eager request stream completion" in new TestSetup {
requestsSub.sendNext(HttpRequest())
requestsSub.sendComplete()
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Content-Length: 0
|
|""")
responsesSub.request(1)
responses.expectNext(HttpResponse())
netOut.expectComplete()
netInSub.sendComplete()
responses.expectComplete()
}
}
"produce proper errors" which {
"catch the entity stream being shorter than the Content-Length" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendComplete()
val InvalidContentLengthException(info) = netOut.expectError()
info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to 2 bytes less"
netInSub.sendComplete()
responses.expectComplete()
netInSub.expectCancellation()
}
"catch the entity stream being longer than the Content-Length" in new TestSetup {
val probe = StreamTestKit.PublisherProbe[ByteString]()
requestsSub.sendNext(HttpRequest(PUT, entity = HttpEntity(ContentTypes.`application/octet-stream`, 8, Source(probe))))
expectWireData(
"""PUT / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|Content-Type: application/octet-stream
|Content-Length: 8
|
|""")
val sub = probe.expectSubscription()
sub.expectRequest(4)
sub.sendNext(ByteString("ABC"))
expectWireData("ABC")
sub.sendNext(ByteString("DEF"))
expectWireData("DEF")
sub.sendNext(ByteString("XYZ"))
val InvalidContentLengthException(info) = netOut.expectError()
info.summary shouldEqual "HTTP message had declared Content-Length 8 but entity data stream amounts to more bytes"
netInSub.sendComplete()
responses.expectComplete()
netInSub.expectCancellation()
}
"catch illegal response starts" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.2 200 OK
|
|""")
val error @ IllegalResponseException(info) = responses.expectError()
info.summary shouldEqual "The server-side HTTP version is not supported"
netOut.expectError(error)
requestsSub.expectCancellation()
}
"catch illegal response chunks" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData(
"""HTTP/1.1 200 OK
|Transfer-Encoding: chunked
|
|""")
responsesSub.request(1)
val HttpResponse(_, _, HttpEntity.Chunked(ct, chunks), _) = responses.expectNext()
ct shouldEqual ContentTypes.`application/octet-stream`
val probe = StreamTestKit.SubscriberProbe[ChunkStreamPart]()
chunks.runWith(Sink(probe))
val sub = probe.expectSubscription()
sendWireData("3\nABC\n")
sub.request(1)
probe.expectNext(HttpEntity.Chunk("ABC"))
sendWireData("4\nDEFXX")
sub.request(1)
val error @ EntityStreamException(info) = probe.expectError()
info.summary shouldEqual "Illegal chunk termination"
responses.expectComplete()
netOut.expectComplete()
requestsSub.expectCancellation()
}
"catch a response start truncation" in new TestSetup {
requestsSub.sendNext(HttpRequest())
expectWireData(
"""GET / HTTP/1.1
|Host: example.com:80
|User-Agent: akka-http/test
|
|""")
netInSub.expectRequest(16)
sendWireData("HTTP/1.1 200 OK")
netInSub.sendComplete()
val error @ IllegalResponseException(info) = responses.expectError()
info.summary shouldEqual "Illegal HTTP message start"
netOut.expectError(error)
requestsSub.expectCancellation()
}
}
}
class TestSetup {
val requests = StreamTestKit.PublisherProbe[HttpRequest]
val responses = StreamTestKit.SubscriberProbe[HttpResponse]
val remoteAddress = new InetSocketAddress("example.com", 80)
def settings = ClientConnectionSettings(system)
.copy(userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test")))))
val (netOut, netIn) = {
val netOut = StreamTestKit.SubscriberProbe[ByteString]
val netIn = StreamTestKit.PublisherProbe[ByteString]
val clientFlow = HttpClient.transportToConnectionClientFlow(
Flow(Sink(netOut), Source(netIn)), remoteAddress, settings, NoLogging)
Source(requests).via(clientFlow).runWith(Sink(responses))
netOut -> netIn
}
def wipeDate(string: String) =
string.fastSplit('\n').map {
case s if s.startsWith("Date:") "Date: XXXX\r"
case s s
}.mkString("\n")
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val requestsSub = requests.expectSubscription()
val responsesSub = responses.expectSubscription()
def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII"))
def sendWireData(data: ByteString): Unit = netInSub.sendNext(data)
def expectWireData(s: String) = {
netOutSub.request(1)
netOut.expectNext().utf8String shouldEqual s.stripMarginWithNewline("\r\n")
}
def closeNetworkInput(): Unit = netInSub.sendComplete()
}
}

View file

@ -120,8 +120,9 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
|Host: x
|
|ABCDPATCH"""
}.toCharArray.map(_.toString).toSeq should rawMultiParseTo(
HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes))
}.toCharArray.map(_.toString).toSeq should generalRawMultiParseTo(
Right(HttpRequest(PUT, "/resource/yes", List(Host("x")), "ABCD".getBytes)),
Left(MessageStartError(400, ErrorInfo("Illegal HTTP message start"))))
closeAfterResponseCompletion shouldEqual Seq(false)
}
@ -232,7 +233,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val parser = newParser
val result = multiParse(newParser)(Seq(prep(start + manyChunks)))
val HttpEntity.Chunked(_, chunks) = result.head.right.get.req.entity
val strictChunks = chunks.collectAll.awaitResult(awaitAtMost)
val strictChunks = chunks.grouped(100000).runWith(Sink.head).awaitResult(awaitAtMost)
strictChunks.size shouldEqual numChunks
}
}
@ -442,7 +443,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] =
Source(input.toList)
.map(ByteString.apply)
.section(name("parser"))(_.transform(() parser))
.section(name("parser"))(_.transform(() parser.stage))
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.headAndTail
.collect {
@ -461,7 +462,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.collectAll
.grouped(100000).runWith(Sink.head)
.awaitResult(awaitAtMost)
protected def parserSettings: ParserSettings = ParserSettings(system)
@ -474,7 +475,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
}
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Seq[ChunkStreamPart]] =
data.collectAll
data.grouped(100000).runWith(Sink.head)
.fast.recover { case _: NoSuchElementException ⇒ Nil }
def prep(response: String) = response.stripMarginWithNewline("\r\n")

View file

@ -261,7 +261,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val future =
Source(input.toList)
.map(ByteString.apply)
.section(name("parser"))(_.transform(() newParser(requestMethod)))
.section(name("parser"))(_.transform(() newParserStage(requestMethod)))
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.headAndTail
.collect {
@ -279,14 +279,16 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
}
.flatten(FlattenStrategy.concat)
.map(strictEqualify)
.grouped(1000).runWith(Sink.head)
.grouped(100000).runWith(Sink.head)
Await.result(future, 500.millis)
}
def parserSettings: ParserSettings = ParserSettings(system)
def newParser(requestMethod: HttpMethod = GET) = {
val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)(), () ⇒ requestMethod)
parser
def newParserStage(requestMethod: HttpMethod = GET) = {
val parser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings)())
parser.setRequestMethodForNextResponse(requestMethod)
parser.stage
}
private def compactEntity(entity: ResponseEntity): Future[ResponseEntity] =
@ -296,7 +298,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
}
private def compactEntityChunks(data: Source[ChunkStreamPart]): Future[Source[ChunkStreamPart]] =
data.grouped(1000).runWith(Sink.head)
data.grouped(100000).runWith(Sink.head)
.fast.map(source(_: _*))
.fast.recover { case _: NoSuchElementException ⇒ source() }

View file

@ -5,7 +5,7 @@
package akka.http.engine.server
import scala.concurrent.duration._
import org.scalatest.{ Inside, BeforeAndAfterAll, Matchers }
import org.scalatest.Inside
import akka.event.NoLogging
import akka.util.ByteString
import akka.stream.scaladsl._
@ -18,7 +18,7 @@ import HttpEntity._
import MediaTypes._
import HttpMethods._
class HttpServerSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside {
class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") with Inside {
implicit val materializer = FlowMaterializer()
"The server implementation" should {

View file

@ -96,6 +96,6 @@ trait RouteTestResultComponent {
failTest("Request was neither completed nor rejected within " + timeout)
private def awaitAllElements[T](data: Source[T]): immutable.Seq[T] =
data.collectAll.awaitResult(timeout)
data.grouped(100000).runWith(Sink.head).awaitResult(timeout)
}
}
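With the collectAll helper gone, the tests drain bounded streams by grouping into one large batch and taking the head, as above; a minimal sketch of the idiom, where the group size only needs to exceed the expected element count:

// drain a bounded source into a Future of all its elements
def drainToSeq[T](src: Source[T])(implicit mat: FlowMaterializer): Future[immutable.Seq[T]] =
  src.grouped(100000).runWith(Sink.head)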

View file

@ -5,17 +5,16 @@
package akka.http.server
package directives
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.http.model.StatusCodes._
import akka.http.model._
import akka.http.model.headers._
import akka.http.util._
import akka.stream.scaladsl.Source
import akka.stream.scaladsl.{ Sink, Source }
import akka.util.ByteString
import org.scalatest.{ Inside, Inspectors }
import scala.concurrent.Await
import scala.concurrent.duration._
class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
lazy val wrs =
mapSettings(_.copy(rangeCountLimit = 10, rangeCoalescingThreshold = 1L)) &
@ -100,7 +99,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
wrs { complete("Some random and not super short entity.") }
} ~> check {
header[`Content-Range`] should be(None)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second)
parts.size shouldEqual 2
inside(parts(0)) {
case Multipart.ByteRanges.BodyPart(range, entity, unit, headers) ⇒
@ -125,7 +124,7 @@ class RangeDirectivesSpec extends RoutingSpec with Inspectors with Inside {
wrs { complete(HttpEntity.Default(MediaTypes.`text/plain`, content.length, entityData())) }
} ~> check {
header[`Content-Range`] should be(None)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.collectAll, 1.second)
val parts = Await.result(responseAs[Multipart.ByteRanges].parts.grouped(1000).runWith(Sink.head), 1.second)
parts.size shouldEqual 2
}
}