diff --git a/akka-http-core/src/main/resources/reference.conf b/akka-http-core/src/main/resources/reference.conf index 7c0a147b21..35ae9ead17 100644 --- a/akka-http-core/src/main/resources/reference.conf +++ b/akka-http-core/src/main/resources/reference.conf @@ -29,6 +29,18 @@ akka.http { # deliver an unlimited backpressured stream of incoming connections. max-connections = 1024 + # The maximum number of requests that are accepted (and dispatched to + # the application) on one single connection before the first request + # has to be completed. + # Incoming requests that would cause the pipelining limit to be exceeded + # are not read from the connections socket so as to build up "back-pressure" + # to the client via TCP flow control. + # A setting of 1 disables HTTP pipelining, since only one request per + # connection can be "open" (i.e. being processed by the application) at any + # time. Set to higher values to enable HTTP pipelining. + # This value must be > 0 and <= 1024. + pipelining-limit = 16 + # Enables/disables the addition of a `Remote-Address` header # holding the clients (remote) IP address. remote-address-header = off diff --git a/akka-http-core/src/main/scala/akka/http/ServerSettings.scala b/akka-http-core/src/main/scala/akka/http/ServerSettings.scala index 7d73c46f68..fd2a887b4d 100644 --- a/akka-http-core/src/main/scala/akka/http/ServerSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/ServerSettings.scala @@ -26,6 +26,7 @@ final case class ServerSettings( serverHeader: Option[Server], timeouts: ServerSettings.Timeouts, maxConnections: Int, + pipeliningLimit: Int, remoteAddressHeader: Boolean, rawRequestUriHeader: Boolean, transparentHeadRequests: Boolean, @@ -38,6 +39,7 @@ final case class ServerSettings( parserSettings: ParserSettings) { require(0 < maxConnections, "max-connections must be > 0") + require(0 < pipeliningLimit && pipeliningLimit <= 1024, "pipelining-limit must be > 0 and <= 1024") require(0 < responseHeaderSizeHint, "response-size-hint must be > 0") require(0 < backlog, "backlog must be > 0") } @@ -55,6 +57,7 @@ object ServerSettings extends SettingsCompanion[ServerSettings]("akka.http.serve c getPotentiallyInfiniteDuration "idle-timeout", c getFiniteDuration "bind-timeout"), c getInt "max-connections", + c getInt "pipelining-limit", c getBoolean "remote-address-header", c getBoolean "raw-request-uri-header", c getBoolean "transparent-head-requests", diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala index 1947f58c7a..af1a3830bd 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala @@ -4,16 +4,13 @@ package akka.http.impl.engine.parsing -import akka.http.ParserSettings - import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import akka.parboiled2.CharUtils import akka.util.ByteString -import akka.stream.scaladsl.Source import akka.stream.stage._ import akka.http.impl.model.parser.CharacterClasses -import akka.http.impl.util._ +import akka.http.ParserSettings import akka.http.scaladsl.model._ import headers._ import HttpProtocols._ @@ -300,31 +297,30 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser case None ⇒ ContentTypes.`application/octet-stream` } - def emptyEntity(cth: Option[`Content-Type`])(entityParts: Any): UniversalEntity = - if (cth.isDefined) HttpEntity.empty(cth.get.contentType) else HttpEntity.Empty + def emptyEntity(cth: Option[`Content-Type`]) = + StrictEntityCreator(if (cth.isDefined) HttpEntity.empty(cth.get.contentType) else HttpEntity.Empty) def strictEntity(cth: Option[`Content-Type`], input: ByteString, bodyStart: Int, - contentLength: Int)(entityParts: Any): UniversalEntity = - HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength)) + contentLength: Int) = + StrictEntityCreator(HttpEntity.Strict(contentType(cth), input.slice(bodyStart, bodyStart + contentLength))) - def defaultEntity(cth: Option[`Content-Type`], - contentLength: Long, - transformData: Source[ByteString, Unit] ⇒ Source[ByteString, Unit] = identityFunc)(entityParts: Source[_ <: ParserOutput, Unit]): UniversalEntity = { - val data = entityParts.collect { - case EntityPart(bytes) ⇒ bytes - case EntityStreamError(info) ⇒ throw EntityStreamException(info) + def defaultEntity[A <: ParserOutput](cth: Option[`Content-Type`], contentLength: Long) = + StreamedEntityCreator[A, UniversalEntity] { entityParts ⇒ + val data = entityParts.collect { + case EntityPart(bytes) ⇒ bytes + case EntityStreamError(info) ⇒ throw EntityStreamException(info) + } + HttpEntity.Default(contentType(cth), contentLength, HttpEntity.limitableByteSource(data)) } - HttpEntity.Default(contentType(cth), contentLength, HttpEntity.limitableByteSource(transformData(data))) - } - def chunkedEntity(cth: Option[`Content-Type`], - transformChunks: Source[HttpEntity.ChunkStreamPart, Unit] ⇒ Source[HttpEntity.ChunkStreamPart, Unit] = identityFunc)(entityChunks: Source[_ <: ParserOutput, Unit]): RequestEntity = { - val chunks = entityChunks.collect { - case EntityChunk(chunk) ⇒ chunk - case EntityStreamError(info) ⇒ throw EntityStreamException(info) + def chunkedEntity[A <: ParserOutput](cth: Option[`Content-Type`]) = + StreamedEntityCreator[A, RequestEntity] { entityChunks ⇒ + val chunks = entityChunks.collect { + case EntityChunk(chunk) ⇒ chunk + case EntityStreamError(info) ⇒ throw EntityStreamException(info) + } + HttpEntity.Chunked(contentType(cth), HttpEntity.limitableChunkSource(chunks)) } - HttpEntity.Chunked(contentType(cth), HttpEntity.limitableChunkSource(transformChunks(chunks))) - } def addTransferEncodingWithChunkedPeeled(headers: List[HttpHeader], teh: `Transfer-Encoding`): List[HttpHeader] = teh.withChunkedPeeled match { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index 3301b13a68..951f010bec 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -5,18 +5,11 @@ package akka.http.impl.engine.parsing import java.lang.{ StringBuilder ⇒ JStringBuilder } -import akka.http.ParserSettings - import scala.annotation.tailrec -import akka.actor.ActorRef -import akka.stream.stage.{ Context, PushPullStage } -import akka.stream.scaladsl.Flow -import akka.stream.scaladsl.Source +import akka.http.ParserSettings import akka.util.ByteString import akka.http.impl.engine.ws.Handshake import akka.http.impl.model.parser.CharacterClasses -import akka.http.impl.util.identityFunc -import akka.http.impl.engine.TokenSourceActor import akka.http.scaladsl.model._ import headers._ import StatusCodes._ @@ -27,8 +20,7 @@ import ParserOutput._ */ private[http] class HttpRequestParser(_settings: ParserSettings, rawRequestUriHeader: Boolean, - _headerParser: HttpHeaderParser, - oneHundredContinueRef: () ⇒ Option[ActorRef] = () ⇒ None) + _headerParser: HttpHeaderParser) extends HttpMessageParser[RequestOutput](_settings, _headerParser) { import HttpMessageParser._ import settings._ @@ -37,8 +29,8 @@ private[http] class HttpRequestParser(_settings: ParserSettings, private[this] var uri: Uri = _ private[this] var uriBytes: Array[Byte] = _ - def createShallowCopy(oneHundredContinueRef: () ⇒ Option[ActorRef]): HttpRequestParser = - new HttpRequestParser(settings, rawRequestUriHeader, headerParser.createShallowCopy(), oneHundredContinueRef) + def createShallowCopy(): HttpRequestParser = + new HttpRequestParser(settings, rawRequestUriHeader, headerParser.createShallowCopy()) def parseMessage(input: ByteString, offset: Int): StateResult = { var cursor = parseMethod(input, offset) @@ -121,7 +113,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { - def emitRequestStart(createEntity: Source[RequestOutput, Unit] ⇒ RequestEntity, + def emitRequestStart(createEntity: EntityCreator[RequestOutput, RequestEntity], headers: List[HttpHeader] = headers) = { val allHeaders0 = if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers @@ -138,22 +130,6 @@ private[http] class HttpRequestParser(_settings: ParserSettings, emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion)) } - def expect100continueHandling[T, Mat]: Source[T, Mat] ⇒ Source[T, Mat] = - if (expect100continue) { - _.via(Flow[T].transform(() ⇒ new PushPullStage[T, T] { - private var oneHundredContinueSent = false - def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - def onPull(ctx: Context[T]) = { - if (!oneHundredContinueSent) { - val ref = oneHundredContinueRef().getOrElse(throw new IllegalStateException("oneHundredContinueRef unavailable")) - ref ! TokenSourceActor.Trigger - oneHundredContinueSent = true - } - ctx.pull() - } - }).named("expect100continueTrigger")) - } else identityFunc - teh match { case None ⇒ val contentLength = clh match { @@ -172,7 +148,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, setCompletionHandling(HttpMessageParser.CompletionOk) startNewMessage(input, bodyStart + cl) } else { - emitRequestStart(defaultEntity(cth, contentLength, expect100continueHandling)) + emitRequestStart(defaultEntity(cth, contentLength)) parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) } @@ -183,7 +159,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, val completedHeaders = addTransferEncodingWithChunkedPeeled(headers, te) if (te.isChunked) { if (clh.isEmpty) { - emitRequestStart(chunkedEntity(cth, expect100continueHandling), completedHeaders) + emitRequestStart(chunkedEntity(cth), completedHeaders) parseChunk(input, bodyStart, closeAfterResponseCompletion, totalBytesRead = 0L) } else failMessageStart("A chunked request must not contain a Content-Length header.") } else parseEntity(completedHeaders, protocol, input, bodyStart, clh, cth, teh = None, diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala index bbdf4ce020..353e53cb32 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpResponseParser.scala @@ -4,11 +4,9 @@ package akka.http.impl.engine.parsing -import akka.http.ParserSettings - import scala.annotation.tailrec +import akka.http.ParserSettings import akka.http.impl.model.parser.CharacterClasses -import akka.stream.scaladsl.Source import akka.util.ByteString import akka.http.scaladsl.model._ import headers._ @@ -84,7 +82,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int, clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`], expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = { - def emitResponseStart(createEntity: Source[ResponseOutput, Unit] ⇒ ResponseEntity, + def emitResponseStart(createEntity: EntityCreator[ResponseOutput, ResponseEntity], headers: List[HttpHeader] = headers) = emit(ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion)) def finishEmptyResponse() = { @@ -110,9 +108,11 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: parseFixedLengthBody(contentLength, closeAfterResponseCompletion)(input, bodyStart) } case None ⇒ - emitResponseStart { entityParts ⇒ - val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes } - HttpEntity.CloseDelimited(contentType(cth), HttpEntity.limitableByteSource(data)) + emitResponseStart { + StreamedEntityCreator { entityParts ⇒ + val data = entityParts.collect { case EntityPart(bytes) ⇒ bytes } + HttpEntity.CloseDelimited(contentType(cth), HttpEntity.limitableByteSource(data)) + } } setCompletionHandling(HttpMessageParser.CompletionOk) parseToCloseBody(input, bodyStart, totalBytesRead = 0) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala index bafc7e9ffe..fe6a813a8b 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala @@ -28,15 +28,15 @@ private[http] object ParserOutput { uri: Uri, protocol: HttpProtocol, headers: List[HttpHeader], - createEntity: Source[RequestOutput, Unit] ⇒ RequestEntity, - expect100ContinueResponsePending: Boolean, + createEntity: EntityCreator[RequestOutput, RequestEntity], + expect100Continue: Boolean, closeRequested: Boolean) extends MessageStart with RequestOutput final case class ResponseStart( statusCode: StatusCode, protocol: HttpProtocol, headers: List[HttpHeader], - createEntity: Source[ResponseOutput, Unit] ⇒ ResponseEntity, + createEntity: EntityCreator[ResponseOutput, ResponseEntity], closeRequested: Boolean) extends MessageStart with ResponseOutput case object MessageEnd extends MessageOutput @@ -58,4 +58,16 @@ private[http] object ParserOutput { case object NeedNextRequestMethod extends ResponseOutput final case class RemainingBytes(bytes: ByteString) extends ResponseOutput + + ////////////////////////////////////// + + sealed abstract class EntityCreator[-A <: ParserOutput, +B >: HttpEntity.Strict <: HttpEntity] extends (Source[A, Unit] ⇒ B) + + final case class StrictEntityCreator(entity: HttpEntity.Strict) extends EntityCreator[ParserOutput, HttpEntity.Strict] { + def apply(parts: Source[ParserOutput, Unit]) = entity + } + final case class StreamedEntityCreator[-A <: ParserOutput, +B >: HttpEntity.Strict <: HttpEntity](creator: Source[A, Unit] ⇒ B) + extends EntityCreator[A, B] { + def apply(parts: Source[A, Unit]) = creator(parts) + } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index cf1f36a537..90ff251ab4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -6,12 +6,12 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress import java.util.Random +import scala.collection.immutable +import akka.stream.scaladsl.One2OneBidiFlow.One2OneBidi import org.reactivestreams.{ Publisher, Subscriber } import scala.util.control.NonFatal -import akka.actor.{ ActorRef, Deploy, Props } import akka.event.LoggingAdapter import akka.http.ServerSettings -import akka.http.impl.engine.TokenSourceActor import akka.http.impl.engine.parsing.ParserOutput._ import akka.http.impl.engine.parsing._ import akka.http.impl.engine.rendering.{ HttpResponseRendererFactory, ResponseRenderingContext, ResponseRenderingOutput } @@ -29,6 +29,24 @@ import akka.util.ByteString /** * INTERNAL API + * + * + * HTTP pipeline setup (without the underlying SSL/TLS (un)wrapping and the websocket switch): + * + * +----------+ +-------------+ +-------------+ +-----------+ + * HttpRequest | | Http- | request- | Request- | | Request- | request- | ByteString + * | <------------+ <----------+ Preparation <----------+ <-------------+ Parsing <----------- + * | | | Request | | Output | | Output | | + * | | | +-------------+ | | +-----------+ + * | | | | | + * | Application- | One2One- | | controller- | + * | Flow | Bidi | | Stage | + * | | | | | + * | | | | | +-----------+ + * | HttpResponse | | HttpResponse | | Response- | renderer- | ByteString + * v -------------> +-----------------------------------> +-------------> Pipeline +----------> + * | | | | Rendering- | | + * +----------+ +-------------+ Context +-----------+ */ private[http] object HttpServerBluePrint { def apply(settings: ServerSettings, remoteAddress: Option[InetSocketAddress], log: LoggingAdapter)(implicit mat: Materializer): Http.ServerLayer = { @@ -45,15 +63,6 @@ private[http] object HttpServerBluePrint { val ws = websocketSetup val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log) - @volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168 - val oneHundredContinueSource = StreamUtils.oneTimeSource(Source.actorPublisher[OneHundredContinue.type] { - Props { - val actor = new TokenSourceActor(OneHundredContinue) - oneHundredContinueRef = Some(actor.context.self) - actor - }.withDeploy(Deploy.local) - }, errorMsg = "Http.serverLayer is currently not reusable. You need to create a new instance for each materialization.") - def establishAbsoluteUri(requestOutput: RequestOutput): RequestOutput = requestOutput match { case start: RequestStart ⇒ try { @@ -69,10 +78,10 @@ private[http] object HttpServerBluePrint { val requestParsingFlow = Flow[ByteString].transform(() ⇒ // each connection uses a single (private) request parser instance for all its requests // which builds a cache of all header instances seen on that connection - rootParser.createShallowCopy(() ⇒ oneHundredContinueRef).stage).named("rootParser") + rootParser.createShallowCopy().stage).named("rootParser") .map(establishAbsoluteUri) - val requestPreparation = + val requestPreparationFlow = Flow[RequestOutput] .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .via(headAndTailFlow) @@ -96,208 +105,232 @@ private[http] object HttpServerBluePrint { // `buffer` will ensure demand and therefore make sure that completion is reported eagerly. .buffer(1, OverflowStrategy.backpressure) - // we need to make sure that only one element per incoming request is queueing up in front of - // the bypassMerge.bypassInput. Otherwise the rising backpressure against the bypassFanout - // would eventually prevent us from reading the remaining request chunks from the transportIn - val bypass = Flow[RequestOutput].collect { - case r: RequestStart ⇒ r - case m: MessageStartError ⇒ m - } - - val rendererPipeline = + val responseRenderingFlow = Flow[ResponseRenderingContext] - .recover(errorResponseRecovery(log, settings)) .via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer")) .flatMapConcat(ConstantFun.scalaIdentityFunction) .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger")) - BidiFlow.fromGraph(FlowGraph.create(requestParsingFlow, rendererPipeline, oneHundredContinueSource)((_, _, _) ⇒ ()) { implicit b ⇒ - (requestParsing, renderer, oneHundreds) ⇒ - import FlowGraph.Implicits._ + BidiFlow.fromGraph(FlowGraph.create() { implicit b ⇒ + import FlowGraph.Implicits._ - val bypassFanout = b.add(Broadcast[RequestOutput](2).named("bypassFanout")) - val bypassMerge = b.add(new BypassMerge(settings, log)) - val bypassInput = bypassMerge.in0 - val bypassOneHundredContinueInput = bypassMerge.in1 - val bypassApplicationInput = bypassMerge.in2 + // HTTP + val requestParsing = b.add(requestParsingFlow) + val requestPreparation = b.add(requestPreparationFlow) + val responseRendering = b.add(responseRenderingFlow) + val controllerStage = b.add(new ControllerStage(settings, log)) + val csRequestParsingIn = controllerStage.in1 + val csRequestPrepOut = controllerStage.out1 + val csHttpResponseIn = controllerStage.in2 + val csResponseCtxOut = controllerStage.out2 + requestParsing.outlet ~> csRequestParsingIn + csResponseCtxOut ~> responseRendering.inlet + csRequestPrepOut ~> requestPreparation - // HTTP pipeline - requestParsing.outlet ~> bypassFanout.in - bypassMerge.out ~> renderer.inlet - val requestsIn = (bypassFanout.out(0) ~> requestPreparation).outlet + // One2OneBidi + val one2one = b.add(new One2OneBidi[HttpRequest, HttpResponse](settings.pipeliningLimit)) + requestPreparation.outlet ~> one2one.in1 + one2one.out2 ~> csHttpResponseIn - bypassFanout.out(1) ~> bypass ~> bypassInput - oneHundreds ~> bypassOneHundredContinueInput + // Websocket + val http = FlowShape(requestParsing.inlet, responseRendering.outlet) + val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2)) + val switchToWebsocket = b.add(Flow[ResponseRenderingOutput] + .collect { case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken }) + val websocket = b.add(ws.websocketFlow) + val protocolRouter = b.add(WebsocketSwitchRouter) + val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory, log)) + val wsSwitchTokenMerge = b.add(WsSwitchTokenMerge) + switchTokenBroadcast ~> switchToWebsocket ~> wsSwitchTokenMerge.in1 + protocolRouter.out0 ~> http ~> switchTokenBroadcast ~> protocolMerge.in0 + protocolRouter.out1 ~> websocket ~> protocolMerge.in1 + wsSwitchTokenMerge.out ~> protocolRouter.in - val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2)) - renderer.outlet ~> switchTokenBroadcast - val switchSource: Outlet[SwitchToWebsocketToken.type] = - (switchTokenBroadcast ~> - Flow[ResponseRenderingOutput] - .collect { - case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken - }).outlet + // SSL/TLS + val unwrapTls = b.add(Flow[SslTlsInbound] collect { case x: SessionBytes ⇒ x.bytes }) + val wrapTls = b.add(Flow[ByteString].map[SslTlsOutbound](SendBytes)) + unwrapTls ~> wsSwitchTokenMerge.in0 + protocolMerge.out ~> wrapTls - val http = FlowShape(requestParsing.inlet, switchTokenBroadcast.outlet) - - // Websocket pipeline - val websocket = b.add(ws.websocketFlow) - - // protocol routing - val protocolRouter = b.add(WebsocketSwitchRouter) - val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory, log)) - - protocolRouter.out0 ~> http ~> protocolMerge.in0 - protocolRouter.out1 ~> websocket ~> protocolMerge.in1 - - // protocol switching - val wsSwitchTokenMerge = b.add(WsSwitchTokenMerge) - // feed back switch signal to the protocol router - switchSource ~> wsSwitchTokenMerge.in1 - wsSwitchTokenMerge.out ~> protocolRouter.in - - val unwrapTls = b.add(Flow[SslTlsInbound] collect { case x: SessionBytes ⇒ x.bytes }) - val wrapTls = b.add(Flow[ByteString].map[SslTlsOutbound](SendBytes)) - - unwrapTls ~> wsSwitchTokenMerge.in0 - protocolMerge.out ~> wrapTls - - BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest]( - bypassApplicationInput, - wrapTls.outlet, - unwrapTls.inlet, - requestsIn) + BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest]( + one2one.in2, + wrapTls.outlet, + unwrapTls.inlet, + one2one.out1) }) } - class BypassMerge(settings: ServerSettings, log: LoggingAdapter) - extends GraphStage[FanInShape3[MessageStart with RequestOutput, OneHundredContinue.type, HttpResponse, ResponseRenderingContext]] { - private val bypassInput = Inlet[MessageStart with RequestOutput]("bypassInput") - private val oneHundredContinue = Inlet[OneHundredContinue.type]("100continue") - private val applicationInput = Inlet[HttpResponse]("applicationInput") - private val out = Outlet[ResponseRenderingContext]("bypassOut") + class ControllerStage(settings: ServerSettings, log: LoggingAdapter) + extends GraphStage[BidiShape[RequestOutput, RequestOutput, HttpResponse, ResponseRenderingContext]] { + private val requestParsingIn = Inlet[RequestOutput]("requestParsingIn") + private val requestPrepOut = Outlet[RequestOutput]("requestPrepOut") + private val httpResponseIn = Inlet[HttpResponse]("httpResponseIn") + private val responseCtxOut = Outlet[ResponseRenderingContext]("responseCtxOut") + val shape = new BidiShape(requestParsingIn, requestPrepOut, httpResponseIn, responseCtxOut) - override val shape = new FanInShape3(bypassInput, oneHundredContinue, applicationInput, out) + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { + val pullHttpResponseIn = () ⇒ pull(httpResponseIn) + var openRequests = immutable.Queue[RequestStart]() + var oneHundredContinueResponsePending = false + var pullSuppressed = false + var messageEndPending = false - override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { - var requestStart: RequestStart = _ - - setHandler(bypassInput, new InHandler { - override def onPush(): Unit = { - grab(bypassInput) match { + setHandler(requestParsingIn, new InHandler { + def onPush(): Unit = + grab(requestParsingIn) match { case r: RequestStart ⇒ - requestStart = r - pull(applicationInput) - if (r.expect100ContinueResponsePending) - read(oneHundredContinue) { cont ⇒ - emit(out, ResponseRenderingContext(HttpResponse(StatusCodes.Continue))) - requestStart = requestStart.copy(expect100ContinueResponsePending = false) - } - case MessageStartError(status, info) ⇒ finishWithError(status, info) - } - } - override def onUpstreamFinish(): Unit = - requestStart match { - case null ⇒ completeStage() - case r ⇒ requestStart = r.copy(closeRequested = true) + openRequests = openRequests.enqueue(r) + messageEndPending = r.createEntity.isInstanceOf[StreamedEntityCreator[_, _]] + val rs = if (r.expect100Continue) { + oneHundredContinueResponsePending = true + r.copy(createEntity = with100ContinueTrigger(r.createEntity)) + } else r + push(requestPrepOut, rs) + case MessageEnd ⇒ + messageEndPending = false + push(requestPrepOut, MessageEnd) + case MessageStartError(status, info) ⇒ finishWithIllegalRequestError(status, info) + case x ⇒ push(requestPrepOut, x) } + override def onUpstreamFinish() = + if (openRequests.isEmpty) completeStage() + else complete(requestPrepOut) }) - setHandler(applicationInput, new InHandler { - override def onPush(): Unit = { - val response = grab(applicationInput) - // see the comment on [[OneHundredContinue]] for an explanation of the closing logic here (and more) - val close = requestStart.closeRequested || requestStart.expect100ContinueResponsePending - abortReading(oneHundredContinue) - emit(out, ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close), - if (close) () ⇒ completeStage() else pullBypass) + setHandler(requestPrepOut, new OutHandler { + def onPull(): Unit = + if (oneHundredContinueResponsePending) pullSuppressed = true + else pull(requestParsingIn) + override def onDownstreamFinish() = cancel(requestParsingIn) + }) + + setHandler(httpResponseIn, new InHandler { + def onPush(): Unit = { + val response = grab(httpResponseIn) + val requestStart = openRequests.head + openRequests = openRequests.tail + val isEarlyResponse = messageEndPending && openRequests.isEmpty + if (isEarlyResponse && response.status.isSuccess) + log.warning( + """Sending 2xx response before end of request was received... + |Note that the connection will be closed after this response. Also, many clients will not read early responses! + |Consider waiting for the request end before dispatching this response!""".stripMargin) + val close = requestStart.closeRequested || + requestStart.expect100Continue && oneHundredContinueResponsePending || + isClosed(requestParsingIn) && openRequests.isEmpty || + isEarlyResponse + emit(responseCtxOut, ResponseRenderingContext(response, requestStart.method, requestStart.protocol, close), + pullHttpResponseIn) + if (close) complete(responseCtxOut) } + override def onUpstreamFinish() = + if (openRequests.isEmpty && isClosed(requestParsingIn)) completeStage() + else complete(responseCtxOut) override def onUpstreamFailure(ex: Throwable): Unit = ex match { case EntityStreamException(errorInfo) ⇒ // the application has forwarded a request entity stream error to the response stream - finishWithError(StatusCodes.BadRequest, errorInfo) - case _ ⇒ failStage(ex) + finishWithIllegalRequestError(StatusCodes.BadRequest, errorInfo) + + case EntityStreamSizeException(limit, contentLength) ⇒ + val summary = contentLength match { + case Some(cl) ⇒ s"Request Content-Length of $cl bytes exceeds the configured limit of $limit bytes" + case None ⇒ s"Aggregated data length of request entity exceeds the configured limit of $limit bytes" + } + val info = ErrorInfo(summary, "Consider increasing the value of akka.http.server.parsing.max-content-length") + finishWithIllegalRequestError(StatusCodes.RequestEntityTooLarge, info) + + case NonFatal(e) ⇒ + log.error(e, "Internal server error, sending 500 response") + emitErrorResponse(HttpResponse(StatusCodes.InternalServerError)) } }) - def finishWithError(status: StatusCode, info: ErrorInfo): Unit = { + class ResponseCtxOutHandler extends OutHandler { + override def onPull() = {} + override def onDownstreamFinish() = + cancel(httpResponseIn) // we cannot fully completeState() here as the websocket pipeline would not complete properly + } + setHandler(responseCtxOut, new ResponseCtxOutHandler { + override def onPull() = { + pull(httpResponseIn) + // after the initial pull here we only ever pull after having emitted in `onPush` of `httpResponseIn` + setHandler(responseCtxOut, new ResponseCtxOutHandler) + } + }) + + def finishWithIllegalRequestError(status: StatusCode, info: ErrorInfo): Unit = { logParsingError(info withSummaryPrepended s"Illegal request, responding with status '$status'", log, settings.parserSettings.errorLoggingVerbosity) val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary - emit(out, ResponseRenderingContext(HttpResponse(status, entity = msg), closeRequested = true), () ⇒ completeStage()) + emitErrorResponse(HttpResponse(status, entity = msg)) } - setHandler(oneHundredContinue, ignoreTerminateInput) // RK: not sure if this is always correct - setHandler(out, eagerTerminateOutput) + def emitErrorResponse(response: HttpResponse): Unit = + emit(responseCtxOut, ResponseRenderingContext(response, closeRequested = true), () ⇒ complete(responseCtxOut)) - val pullBypass = () ⇒ - if (isClosed(bypassInput)) completeStage() - else { - pull(bypassInput) - requestStart = null + /** + * The `Expect: 100-continue` header has a special status in HTTP. + * It allows the client to send an `Expect: 100-continue` header with the request and then pause request sending + * (i.e. hold back sending the request entity). The server reads the request headers, determines whether it wants to + * accept the request and responds with + * + * - `417 Expectation Failed`, if it doesn't support the `100-continue` expectation + * (or if the `Expect` header contains other, unsupported expectations). + * - a `100 Continue` response, + * if it is ready to accept the request entity and the client should go ahead with sending it + * - a final response (like a 4xx to signal some client-side error + * (e.g. if the request entity length is beyond the configured limit) or a 3xx redirect) + * + * Only if the client receives a `100 Continue` response from the server is it allowed to continue sending the request + * entity. In this case it will receive another response after having completed request sending. + * So this special feature breaks the normal "one request - one response" logic of HTTP! + * It therefore requires special handling in all HTTP stacks (client- and server-side). + * + * For us this means: + * + * - on the server-side: + * After having read a `Expect: 100-continue` header with the request we package up an `HttpRequest` instance and send + * it through to the application. Only when (and if) the application then requests data from the entity stream do we + * send out a `100 Continue` response and continue reading the request entity. + * The application can therefore determine itself whether it wants the client to send the request entity + * by deciding whether to look at the request entity data stream or not. + * If the application sends a response *without* having looked at the request entity the client receives this + * response *instead of* the `100 Continue` response and the server closes the connection afterwards. + * + * - on the client-side: + * If the user adds a `Expect: 100-continue` header to the request we need to hold back sending the entity until + * we've received a `100 Continue` response. + */ + val emit100ContinueResponse = + getAsyncCallback[Unit] { _ ⇒ + oneHundredContinueResponsePending = false + emit(responseCtxOut, ResponseRenderingContext(HttpResponse(StatusCodes.Continue))) + if (pullSuppressed) { + pullSuppressed = false + pull(requestParsingIn) + } } - override def preStart(): Unit = { - pull(bypassInput) - } + def with100ContinueTrigger[T <: ParserOutput](createEntity: EntityCreator[T, RequestEntity]) = + StreamedEntityCreator { + createEntity.compose[Source[T, Unit]] { + _.via(Flow[T].transform(() ⇒ new PushPullStage[T, T] { + private var oneHundredContinueSent = false + def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + def onPull(ctx: Context[T]) = { + if (!oneHundredContinueSent) { + oneHundredContinueSent = true + emit100ContinueResponse.invoke(()) + } + ctx.pull() + } + }).named("expect100continueTrigger")) + } + } } } - def errorResponseRecovery(log: LoggingAdapter, - settings: ServerSettings): PartialFunction[Throwable, ResponseRenderingContext] = { - case EntityStreamSizeException(limit, contentLength) ⇒ - val status = StatusCodes.RequestEntityTooLarge - val summary = contentLength match { - case Some(cl) ⇒ s"Request Content-Length of $cl bytes exceeds the configured limit of $limit bytes" - case None ⇒ s"Aggregated data length of request entity exceeds the configured limit of $limit bytes" - } - val info = ErrorInfo(summary, "Consider increasing the value of akka.http.server.parsing.max-content-length") - logParsingError(info withSummaryPrepended s"Illegal request, responding with status '$status'", - log, settings.parserSettings.errorLoggingVerbosity) - val msg = if (settings.verboseErrorMessages) info.formatPretty else info.summary - ResponseRenderingContext(HttpResponse(status, entity = msg), closeRequested = true) - - case NonFatal(e) ⇒ - log.error(e, "Internal server error, sending 500 response") - ResponseRenderingContext(HttpResponse(StatusCodes.InternalServerError), closeRequested = true) - } - - /** - * The `Expect: 100-continue` header has a special status in HTTP. - * It allows the client to send an `Expect: 100-continue` header with the request and then pause request sending - * (i.e. hold back sending the request entity). The server reads the request headers, determines whether it wants to - * accept the request and responds with - * - * - `417 Expectation Failed`, if it doesn't support the `100-continue` expectation - * (or if the `Expect` header contains other, unsupported expectations). - * - a `100 Continue` response, - * if it is ready to accept the request entity and the client should go ahead with sending it - * - a final response (like a 4xx to signal some client-side error - * (e.g. if the request entity length is beyond the configured limit) or a 3xx redirect) - * - * Only if the client receives a `100 Continue` response from the server is it allowed to continue sending the request - * entity. In this case it will receive another response after having completed request sending. - * So this special feature breaks the normal "one request - one response" logic of HTTP! - * It therefore requires special handling in all HTTP stacks (client- and server-side). - * - * For us this means: - * - * - on the server-side: - * After having read a `Expect: 100-continue` header with the request we package up an `HttpRequest` instance and send - * it through to the application. Only when (and if) the application then requests data from the entity stream do we - * send out a `100 Continue` response and continue reading the request entity. - * The application can therefore determine itself whether it wants the client to send the request entity - * by deciding whether to look at the request entity data stream or not. - * If the application sends a response *without* having looked at the request entity the client receives this - * response *instead of* the `100 Continue` response and the server closes the connection afterwards. - * - * - on the client-side: - * If the user adds a `Expect: 100-continue` header to the request we need to hold back sending the entity until - * we've received a `100 Continue` response. - */ - case object OneHundredContinue - private trait WebsocketSetup { def websocketFlow: Flow[ByteString, ByteString, Any] def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: Materializer): Unit diff --git a/akka-http-core/src/test/resources/reference.conf b/akka-http-core/src/test/resources/reference.conf index 1660c0e30d..4dc03e5ed7 100644 --- a/akka-http-core/src/test/resources/reference.conf +++ b/akka-http-core/src/test/resources/reference.conf @@ -3,5 +3,7 @@ akka { actor { serialize-creators = on serialize-messages = on + default-dispatcher.throughput = 1 } + stream.materializer.debug.fuzzing-mode=off } \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index 4e9d37f71c..dc4ce06105 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -616,6 +616,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") """HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX + |Connection: close |Content-Type: application/octet-stream |Content-Length: 100000 | @@ -690,6 +691,30 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") request.headers should contain(`Remote-Address`(RemoteAddress(theAddress, Some(8080)))) } + "add `Connection: close` to early responses" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 100000 + | + |""") + + val HttpRequest(POST, _, _, entity, _) = expectRequest() + responses.sendNext(HttpResponse(status = StatusCodes.InsufficientStorage)) + + expectResponseWithWipedDate( + """HTTP/1.1 507 Insufficient Storage + |Server: akka-http/test + |Date: XXXX + |Connection: close + |Content-Length: 0 + | + |""") + + netIn.sendComplete() + requests.expectComplete() + netOut.expectComplete() + } + def isDefinedVia = afterWord("is defined via") "support request length verification" which isDefinedVia { diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala index 80f5be196e..540d766f33 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala @@ -7,9 +7,11 @@ package akka.http.scaladsl import java.io.{ BufferedReader, BufferedWriter, InputStreamReader, OutputStreamWriter } import java.net.{ BindException, Socket } import java.util.concurrent.TimeoutException - +import scala.annotation.tailrec +import scala.concurrent.duration._ +import scala.concurrent.{ Await, Future, Promise } +import scala.util.{ Try, Success } import akka.actor.ActorSystem -import akka.event.NoLogging import akka.http.impl.util._ import akka.http.scaladsl.Http.ServerBinding import akka.http.scaladsl.model.HttpEntity._ @@ -25,11 +27,6 @@ import akka.util.ByteString import com.typesafe.config.{ Config, ConfigFactory } import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpec } -import scala.annotation.tailrec -import scala.concurrent.duration._ -import scala.concurrent.{ Await, Future, Promise } -import scala.util.{ Try, Success } - class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" akka.loggers = ["akka.testkit.TestEventListener"] @@ -357,10 +354,8 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { toStrict(response.entity) shouldEqual HttpEntity("yeah") clientOutSub.sendComplete() - serverInSub.request(1) // work-around for #16552 serverIn.expectComplete() serverOutSub.expectCancellation() - clientInSub.request(1) // work-around for #16552 clientIn.expectComplete() binding.foreach(_.unbind()) @@ -397,10 +392,8 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { Await.result(chunkStream2.grouped(1000).runWith(Sink.head), 100.millis) shouldEqual chunks clientOutSub.sendComplete() - serverInSub.request(1) // work-around for #16552 serverIn.expectComplete() serverOutSub.expectCancellation() - clientInSub.request(1) // work-around for #16552 clientIn.expectComplete() connSourceSub.cancel() @@ -430,10 +423,8 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val response = clientIn.expectNext() toStrict(response.entity) shouldEqual HttpEntity("yeah") - serverInSub.request(1) // work-around for #16552 serverIn.expectComplete() serverOutSub.expectCancellation() - clientInSub.request(1) // work-around for #16552 clientIn.expectComplete() connSourceSub.cancel() diff --git a/akka-http-tests/src/test/resources/reference.conf b/akka-http-tests/src/test/resources/reference.conf index 1660c0e30d..33e27263ec 100644 --- a/akka-http-tests/src/test/resources/reference.conf +++ b/akka-http-tests/src/test/resources/reference.conf @@ -3,5 +3,7 @@ akka { actor { serialize-creators = on serialize-messages = on + default-dispatcher.throughput = 1 } + stream.materializer.debug.fuzzing-mode=on } \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala index 17557683c5..44b397b365 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala @@ -46,7 +46,7 @@ class One2OneBidiFlowSpec extends AkkaSpec with ConversionCheckedTripleEquals { outOut.expectError(new One2OneBidiFlow.UnexpectedOutputException(3)) } - "drop surplus output elements" in new Test() { + "fully propagate cancellation" in new Test() { inIn.sendNext(1) inOut.requestNext() should ===(1) @@ -55,6 +55,9 @@ class One2OneBidiFlowSpec extends AkkaSpec with ConversionCheckedTripleEquals { outOut.cancel() outIn.expectCancellation() + + inOut.cancel() + inIn.expectCancellation() } "backpressure the input side if the maximum number of pending output elements has been reached" in { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index b923ebc9d9..61b8824e3f 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -799,7 +799,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends * '''Completes when''' upstream completes * * '''Cancels when''' downstream cancels and substreams cancel - * + * * See also [[Flow.splitAfter]]. */ def splitWhen(p: function.Predicate[Out]): javadsl.Flow[In, Source[Out, Unit], Mat] = diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala index f32b7a4b07..61cbb3df0c 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala @@ -3,15 +3,13 @@ */ package akka.stream.scaladsl +import scala.util.control.NoStackTrace import akka.stream._ import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage } -import scala.concurrent.duration.Deadline -import scala.util.control.NoStackTrace - object One2OneBidiFlow { - case class UnexpectedOutputException(element: Any) extends RuntimeException with NoStackTrace + case class UnexpectedOutputException(element: Any) extends RuntimeException(element.toString) with NoStackTrace case object OutputTruncationException extends RuntimeException with NoStackTrace /** @@ -24,8 +22,6 @@ object One2OneBidiFlow { * for every input element. * 3. Backpressures the input side if the maximum number of pending output elements has been reached, * which is given via the ``maxPending`` parameter. You can use -1 to disable this feature. - * 4. Drops surplus output elements, i.e. ones that the inner flow tries to produce after the input stream - * has signalled completion. Note that no error is triggered in this case! */ def apply[I, O](maxPending: Int): BidiFlow[I, I, O, O, Unit] = BidiFlow.fromGraph(new One2OneBidi[I, O](maxPending)) @@ -41,7 +37,7 @@ object One2OneBidiFlow { override def createLogic(effectiveAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pending = 0 - private var pullsSuppressed = 0 + private var pullSuppressed = false setHandler(inIn, new InHandler { override def onPush(): Unit = { @@ -54,7 +50,7 @@ object One2OneBidiFlow { setHandler(inOut, new OutHandler { override def onPull(): Unit = if (pending < maxPending || maxPending == -1) pull(inIn) - else pullsSuppressed += 1 + else pullSuppressed = true override def onDownstreamFinish(): Unit = cancel(inIn) }) @@ -64,8 +60,8 @@ object One2OneBidiFlow { if (pending > 0) { pending -= 1 push(outOut, element) - if (pullsSuppressed > 0) { - pullsSuppressed -= 1 + if (pullSuppressed) { + pullSuppressed = false pull(inIn) } } else throw new UnexpectedOutputException(element) @@ -77,10 +73,7 @@ object One2OneBidiFlow { setHandler(outOut, new OutHandler { override def onPull(): Unit = pull(outIn) - override def onDownstreamFinish(): Unit = { - cancel(outIn) - cancel(inIn) // short-cut to speed up cleanup of upstream - } + override def onDownstreamFinish(): Unit = cancel(outIn) }) } } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index e5cdef8e17..2a0ae4c0c0 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -373,7 +373,11 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Signals that there will be no more elements emitted on the given port. */ - final protected def complete[T](out: Outlet[T]): Unit = interpreter.complete(conn(out)) + final protected def complete[T](out: Outlet[T]): Unit = + getHandler(out) match { + case e: Emitting[_] ⇒ e.addFollowUp(new EmittingCompletion(e.out, e.previous)) + case _ ⇒ interpreter.complete(conn(out)) + } /** * Signals failure through the given port.