diff --git a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala index 185f895c80..d21da058ec 100644 --- a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala @@ -167,7 +167,7 @@ object HttpEntity { */ def apply(contentType: ContentType, chunks: Publisher[ByteString], materializer: FlowMaterializer): Chunked = Chunked(contentType, Flow(chunks).collect[ChunkStreamPart] { - case b: ByteString if b.nonEmpty => Chunk(b) + case b: ByteString if b.nonEmpty ⇒ Chunk(b) }.toPublisher(materializer)) } diff --git a/akka-http-core/src/main/scala/akka/http/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/parsing/HttpMessageParser.scala index 86e499a7f6..35427ccb17 100644 --- a/akka-http-core/src/main/scala/akka/http/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/parsing/HttpMessageParser.scala @@ -127,6 +127,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut } else { val offset = bodyStart + remainingBodyBytes.toInt emit(ParserOutput.EntityPart(input.slice(bodyStart, offset))) + emit(ParserOutput.MessageEnd) if (isLastMessage) terminate() else startNewMessage(input, offset) } @@ -142,6 +143,7 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut val lastChunk = if (extension.isEmpty && headers.isEmpty) HttpEntity.LastChunk else HttpEntity.LastChunk(extension, headers) emit(ParserOutput.EntityChunk(lastChunk)) + emit(ParserOutput.MessageEnd) if (isLastMessage) terminate() else startNewMessage(input, lineEnd) case header if headerCount < settings.maxHeaderCount ⇒ diff --git a/akka-http-core/src/main/scala/akka/http/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/parsing/ParserOutput.scala index 8e62e5a8c0..c6921fc745 100644 --- a/akka-http-core/src/main/scala/akka/http/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/parsing/ParserOutput.scala @@ -37,6 +37,8 @@ private[http] object ParserOutput { createEntity: Publisher[ResponseOutput] ⇒ HttpEntity, closeAfterResponseCompletion: Boolean) extends MessageStart with ResponseOutput + case object MessageEnd extends MessageOutput + final case class EntityPart(data: ByteString) extends MessageOutput final case class EntityChunk(chunk: HttpEntity.ChunkStreamPart) extends MessageOutput diff --git a/akka-http-core/src/main/scala/akka/http/server/HttpServerPipeline.scala b/akka-http-core/src/main/scala/akka/http/server/HttpServerPipeline.scala index 0ba396f560..2c1750131c 100644 --- a/akka-http-core/src/main/scala/akka/http/server/HttpServerPipeline.scala +++ b/akka-http-core/src/main/scala/akka/http/server/HttpServerPipeline.scala @@ -42,7 +42,11 @@ private[http] class HttpServerPipeline(settings: ServerSettings, val requestPublisher = Flow(tcpConn.inputStream) .transform(rootParser.copyWith(warnOnIllegalHeader)) - .splitWhen(_.isInstanceOf[MessageStart]) + // this will create extra single element `[MessageEnd]` substreams, that will + // be filtered out by the above `collect` for the applicationBypass and the + // below `collect` for the actual request handling + // TODO: replace by better combinator, maybe `splitAfter(_ == MessageEnd)`? + .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x == MessageEnd) .headAndTail(materializer) .tee(applicationBypassSubscriber) .collect { diff --git a/akka-http-core/src/main/scala/akka/http/util/package.scala b/akka-http-core/src/main/scala/akka/http/util/package.scala index e0cf456e63..79a8f4696d 100644 --- a/akka-http-core/src/main/scala/akka/http/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/util/package.scala @@ -5,8 +5,6 @@ package akka.http import language.implicitConversions -import java.net.InetSocketAddress -import java.nio.channels.ServerSocketChannel import java.nio.charset.Charset import com.typesafe.config.Config import org.reactivestreams.Publisher diff --git a/akka-http-core/src/test/scala/akka/http/server/HttpServerPipelineSpec.scala b/akka-http-core/src/test/scala/akka/http/server/HttpServerPipelineSpec.scala new file mode 100644 index 0000000000..3440e61c87 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/server/HttpServerPipelineSpec.scala @@ -0,0 +1,364 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http +package server + +import akka.http.model.HttpEntity.{ LastChunk, Chunk, ChunkStreamPart } + +import scala.concurrent.duration._ + +import akka.event.NoLogging +import akka.http.model.headers.Host +import akka.http.model._ +import akka.http.util._ +import akka.stream.io.StreamTcp +import akka.stream.testkit.{ AkkaSpec, StreamTestKit } +import akka.stream.{ FlowMaterializer, MaterializerSettings } +import akka.util.ByteString +import org.scalatest._ + +import scala.concurrent.duration.FiniteDuration + +class HttpServerPipelineSpec extends AkkaSpec with Matchers with BeforeAndAfterAll with Inside { + val materializerSettings = MaterializerSettings(dispatcher = "akka.test.stream-dispatcher") + val materializer = FlowMaterializer(materializerSettings) + + "The server implementation should" should { + "deliver an empty request as soon as all headers are received" in new TestSetup { + send("""GET / HTTP/1.1 + |Host: example.com + | + |""".stripMarginWithNewline("\r\n")) + + expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) + } + "deliver a request as soon as all headers are received" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNoMsg(50.millis) + + send("abcdef") + dataProbe.expectNext(ByteString("abcdef")) + + send("ghijk") + dataProbe.expectNext(ByteString("ghijk")) + dataProbe.expectNoMsg(50.millis) + } + } + "deliver an error as soon as a parsing error occurred" in pendingUntilFixed(new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + | + |""".stripMarginWithNewline("\r\n")) + + requests.expectError() + }) + "report a invalid Chunked stream" in pendingUntilFixed(new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |6 + |abcdef + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(Chunk(ByteString("abcdef"))) + dataProbe.expectNoMsg(50.millis) + + send("3ghi\r\n") // missing "\r\n" after the number of bytes + dataProbe.expectError() + requests.expectError() + } + }) + + "deliver the request entity as it comes in strictly for an immediately completed Strict entity" in pendingUntilFixed(new TestSetup { // broken because of #15686 + send("""POST /strict HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdefghijkl""".stripMarginWithNewline("\r\n")) + + expectRequest shouldEqual + HttpRequest( + method = HttpMethods.POST, + uri = "http://example.com/strict", + headers = List(Host("example.com")), + entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) + }) + "deliver the request entity as it comes in for a Default entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdef""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(ByteString("abcdef")) + + send("ghijk") + dataProbe.expectNext(ByteString("ghijk")) + dataProbe.expectNoMsg(50.millis) + } + } + "deliver the request entity as it comes in for a chunked entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |6 + |abcdef + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(Chunk(ByteString("abcdef"))) + + send("3\r\nghi\r\n") + dataProbe.expectNext(Chunk(ByteString("ghi"))) + dataProbe.expectNoMsg(50.millis) + } + } + + "deliver the second message properly after a Strict entity" in pendingUntilFixed(new TestSetup { // broken because of #15686 + send("""POST /strict HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdefghijkl""".stripMarginWithNewline("\r\n")) + + expectRequest shouldEqual + HttpRequest( + method = HttpMethods.POST, + uri = "http://example.com/strict", + headers = List(Host("example.com")), + entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("abcdefghijkl"))) + + send("""POST /next-strict HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |mnopqrstuvwx""".stripMarginWithNewline("\r\n")) + + expectRequest shouldEqual + HttpRequest( + method = HttpMethods.POST, + uri = "http://example.com/next-strict", + headers = List(Host("example.com")), + entity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, ByteString("mnopqrstuvwx"))) + }) + "deliver the second message properly after a Default entity" in pendingUntilFixed(new TestSetup { // broken because of #15686 + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdef""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(ByteString("abcdef")) + + send("ghij") + dataProbe.expectNext(ByteString("ghij")) + + send("kl") + dataProbe.expectNext(ByteString("kl")) + dataProbe.expectComplete() + } + + send("""POST /next-strict HTTP/1.1 + |Host: example.com + |Content-Length: 5 + | + |abcde""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) ⇒ + data shouldEqual (ByteString("abcde")) + } + }) + "deliver the second message properly after a Chunked entity" in pendingUntilFixed(new TestSetup { // broken because of #15686 + send("""POST /chunked HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |6 + |abcdef + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(Chunk(ByteString("abcdef"))) + + send("3\r\nghi\r\n") + dataProbe.expectNext(ByteString("ghi")) + dataProbe.expectNoMsg(50.millis) + + send("0\r\n\r\n") + dataProbe.expectNext(LastChunk) + dataProbe.expectComplete() + } + + send("""POST /next-strict HTTP/1.1 + |Host: example.com + |Content-Length: 5 + | + |abcde""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Strict(_, data), _) ⇒ + data shouldEqual (ByteString("abcde")) + } + }) + + "close the request entity stream when the entity is complete for a Default entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdef""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(ByteString("abcdef")) + + send("ghijkl") + dataProbe.expectNext(ByteString("ghijkl")) + dataProbe.expectComplete() + } + } + "close the request entity stream when the entity is complete for a Chunked entity" in new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |6 + |abcdef + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(Chunk(ByteString("abcdef"))) + dataProbe.expectNoMsg(50.millis) + + send("0\r\n\r\n") + dataProbe.expectNext(LastChunk) + dataProbe.expectComplete() + } + } + + "report a truncated entity stream on the entity data stream and the main stream for a Default entity" in pendingUntilFixed(new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Content-Length: 12 + | + |abcdef""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ByteString] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(ByteString("abcdef")) + dataProbe.expectNoMsg(50.millis) + + closeNetworkInput() + dataProbe.expectError() + } + }) + "report a truncated entity stream on the entity data stream and the main stream for a Chunked entity" in pendingUntilFixed(new TestSetup { + send("""POST / HTTP/1.1 + |Host: example.com + |Transfer-Encoding: chunked + | + |6 + |abcdef + |""".stripMarginWithNewline("\r\n")) + + inside(expectRequest) { + case HttpRequest(HttpMethods.POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ + val dataProbe = StreamTestKit.SubscriberProbe[ChunkStreamPart] + data.subscribe(dataProbe) + val sub = dataProbe.expectSubscription() + sub.request(10) + dataProbe.expectNext(Chunk(ByteString("abcdef"))) + dataProbe.expectNoMsg(50.millis) + + closeNetworkInput() + dataProbe.expectError() + } + }) + } + + class TestSetup { + val netIn = StreamTestKit.PublisherProbe[ByteString] + val netOut = StreamTestKit.SubscriberProbe[ByteString] + val tcpConnection = StreamTcp.IncomingTcpConnection(null, netIn, netOut) + + val pipeline = new HttpServerPipeline(ServerSettings(system), materializer, NoLogging) + val Http.IncomingConnection(_, requestsIn, responsesOut) = pipeline(tcpConnection) + + val netInSub = netIn.expectSubscription() + + val requests = StreamTestKit.SubscriberProbe[HttpRequest] + val responses = StreamTestKit.PublisherProbe[HttpResponse] + requestsIn.subscribe(requests) + val requestsSub = requests.expectSubscription() + responses.subscribe(responsesOut) + val responsesSub = responses.expectSubscription() + + def expectRequest: HttpRequest = { + requestsSub.request(1) + requests.expectNext() + } + def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max) + + def send(data: ByteString): Unit = netInSub.sendNext(data) + def send(data: String): Unit = send(ByteString(data, "ASCII")) + + def closeNetworkInput(): Unit = netInSub.sendComplete() + } +}