diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 4b47319a79..0998f47d06 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -89,6 +89,12 @@ private[http] object HttpServerBluePrint { case _ ⇒ BidiFlow.identity } + /** + * Two state stage, either transforms an incoming RequestOutput into a HttpRequest with strict entity and then pushes + * that (the "idle" inHandler) or creates a HttpRequest with a streamed entity and switch to a state which will push + * incoming chunks into the streaming entity until end of request is reached (the StreamedEntityCreator case in create + * entity). + */ final class PrepareRequests(settings: ServerSettings) extends GraphStage[FlowShape[RequestOutput, HttpRequest]] { val in = Inlet[RequestOutput]("RequestStartThenRunIgnore.in") val out = Outlet[HttpRequest]("RequestStartThenRunIgnore.out") @@ -96,6 +102,7 @@ private[http] object HttpServerBluePrint { override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { val remoteAddress = inheritedAttributes.get[HttpAttributes.RemoteAddress].flatMap(_.address) + var upstreamPullWaiting = false val idle = new InHandler { def onPush(): Unit = grab(in) match { @@ -112,31 +119,59 @@ private[http] object HttpServerBluePrint { throw new IllegalStateException(s"unexpected element of type ${other.getClass}") } } - setHandler(in, idle) + + setIdleHandlers() + + def setIdleHandlers() { + setHandler(in, idle) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + pull(in) + } + }) + if (upstreamPullWaiting) { + upstreamPullWaiting = false + pull(in) + } + } + def createEntity(creator: EntityCreator[RequestOutput, RequestEntity]): RequestEntity = creator match { case StrictEntityCreator(entity) ⇒ entity - case StreamedEntityCreator(creator) ⇒ - val entitySource = new SubSourceOutlet[RequestOutput]("EntitySource") - entitySource.setHandler(new OutHandler { - def onPull(): Unit = pull(in) - }) - setHandler(in, new InHandler { - def onPush(): Unit = grab(in) match { - case MessageEnd ⇒ - entitySource.complete() - setHandler(in, idle) - case x ⇒ entitySource.push(x) - } - override def onUpstreamFinish(): Unit = completeStage() - }) - creator(Source.fromGraph(entitySource.source)) + case StreamedEntityCreator(creator) ⇒ streamRequestEntity(creator) } - setHandler(out, new OutHandler { - override def onPull(): Unit = pull(in) - }) + def streamRequestEntity(creator: (Source[ParserOutput.RequestOutput, NotUsed]) => RequestEntity): RequestEntity = { + // stream the request entity until we reach the end of it + val entitySource = new SubSourceOutlet[RequestOutput]("EntitySource") + entitySource.setHandler(new OutHandler { + def onPull(): Unit = { + pull(in) + } + }) + setHandler(in, new InHandler { + def onPush(): Unit = { + grab(in) match { + case MessageEnd ⇒ + entitySource.complete() + setIdleHandlers() + + case x ⇒ entitySource.push(x) + } + } + override def onUpstreamFinish(): Unit = completeStage() + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + // remember this until we are done with the entity + // so we can pass it downstream at that point + upstreamPullWaiting = true + } + }) + creator(Source.fromGraph(entitySource.source)) + } + } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/PrepareRequestsSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/PrepareRequestsSpec.scala new file mode 100644 index 0000000000..a7d10b9e71 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/PrepareRequestsSpec.scala @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2016 Typesafe Inc. + */ +package akka.http.impl.engine.server + +import akka.http.impl.engine.parsing.ParserOutput +import akka.http.impl.engine.parsing.ParserOutput.{ StrictEntityCreator, EntityStreamError, EntityChunk, StreamedEntityCreator } +import akka.http.impl.engine.server.HttpServerBluePrint.PrepareRequests +import akka.http.scaladsl.model._ +import akka.http.scaladsl.settings.ServerSettings +import akka.stream.{ Attributes, ActorMaterializer } +import akka.stream.scaladsl.{ Sink, Source, Flow } +import akka.stream.testkit.{ TestSubscriber, TestPublisher } +import akka.testkit.AkkaSpec +import akka.util.ByteString +import org.scalatest.{ Matchers, WordSpec } +import scala.concurrent.duration._ + +class PrepareRequestsSpec extends AkkaSpec { + + "The PrepareRequest stage" should { + + "not fail when there is demand from both streamed entity consumption and regular flow" in { + implicit val materializer = ActorMaterializer() + // covers bug #19623 where a reply before the streamed + // body has been consumed causes pull/push twice + val inProbe = TestPublisher.manualProbe[ParserOutput.RequestOutput]() + val upstreamProbe = TestSubscriber.manualProbe[HttpRequest]() + + val stage = Flow.fromGraph(new PrepareRequests(ServerSettings(system))) + + Source.fromPublisher(inProbe) + .via(stage) + .to(Sink.fromSubscriber(upstreamProbe)) + .withAttributes(Attributes.inputBuffer(1, 1)) + .run() + + val upstreamSub = upstreamProbe.expectSubscription() + val inSub = inProbe.expectSubscription() + + // let request with streamed entity through + upstreamSub.request(1) + inSub.expectRequest(1) + inSub.sendNext(ParserOutput.RequestStart( + HttpMethods.GET, + Uri("http://example.com/"), + HttpProtocols.`HTTP/1.1`, + List(), + StreamedEntityCreator[ParserOutput, RequestEntity] { entityChunks ⇒ + val chunks = entityChunks.collect { + case EntityChunk(chunk) ⇒ chunk + case EntityStreamError(info) ⇒ throw EntityStreamException(info) + } + HttpEntity.Chunked(ContentTypes.`application/octet-stream`, HttpEntity.limitableChunkSource(chunks)) + }, + true, + false)) + + val request = upstreamProbe.expectNext() + + // and subscribe to it's streamed entity + val entityProbe = TestSubscriber.manualProbe[ByteString]() + request.entity.dataBytes.to(Sink.fromSubscriber(entityProbe)) + .withAttributes(Attributes.inputBuffer(1, 1)) + .run() + + val entitySub = entityProbe.expectSubscription() + + // the bug happens when both the client has signalled demand + // and the the streamed entity has + upstreamSub.request(1) + entitySub.request(1) + + // then comes the next chunk from the actual request + inSub.expectRequest(1) + + // bug would fail stream here with exception + upstreamProbe.expectNoMsg(100.millis) + + inSub.sendNext(ParserOutput.EntityChunk(HttpEntity.ChunkStreamPart(ByteString("abc")))) + entityProbe.expectNext() + entitySub.request(1) + inSub.sendNext(ParserOutput.MessageEnd) + entityProbe.expectComplete() + + // the rest of the test covers the saved pull + // that should go downstream when the streamed entity + // has reached it's end + inSub.expectRequest(1) + inSub.sendNext(ParserOutput.RequestStart( + HttpMethods.GET, + Uri("http://example.com/"), + HttpProtocols.`HTTP/1.1`, + List(), + StrictEntityCreator(HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("body"))), + true, + false)) + + upstreamProbe.expectNext() + + } + } + +}