diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/ByteStringParserSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/ByteStringParserSpec.scala index 7018cb2fdf..2e78189654 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/ByteStringParserSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/ByteStringParserSpec.scala @@ -7,14 +7,14 @@ import akka.stream.impl.io.ByteStringParser import akka.stream.impl.io.ByteStringParser.{ ByteReader, ParseResult, ParseStep } import akka.stream.scaladsl.{ Sink, Source } import akka.stream.stage.GraphStageLogic -import akka.stream.testkit.StreamSpec +import akka.stream.testkit.{ StreamSpec, TestPublisher, TestSubscriber } import akka.stream.{ ActorMaterializer, Attributes, ThrottleMode } import akka.util.ByteString import scala.concurrent.Await import scala.concurrent.duration._ -class ByteStringParserSpec extends StreamSpec() { +class ByteStringParserSpec extends StreamSpec { implicit val materializer = ActorMaterializer() "ByteStringParser" must { @@ -109,6 +109,51 @@ class ByteStringParserSpec extends StreamSpec() { "Aborting processing to avoid infinite cycles. In the unlikely case that the parsing logic needs more recursion, " + "override ParsingLogic.recursionLimit." } + + "complete eagerly" in { + object DummyParser extends ByteStringParser[ByteString] { + def createLogic(inheritedAttributes: Attributes) = new ParsingLogic { + startWith(new ParseStep[ByteString] { + def parse(reader: ByteReader) = ParseResult(Some(reader.takeAll()), this) + }) + } + } + + val in = TestPublisher.probe[ByteString]() + val out = TestSubscriber.probe[ByteString]() + Source.fromPublisher(in).via(DummyParser).runWith(Sink.fromSubscriber(out)) + + out.request(1L) + in.expectRequest() + in.sendNext(ByteString("aha!")) + out.expectNext() + // no new pull + in.sendComplete() + out.expectComplete() + } + + "fail eagerly on truncation" in { + object DummyParser extends ByteStringParser[ByteString] { + def createLogic(inheritedAttributes: Attributes) = new ParsingLogic { + startWith(new ParseStep[ByteString] { + // take more data than there is in first chunk + def parse(reader: ByteReader) = ParseResult(Some(reader.take(5)), this, false) + }) + } + } + + val in = TestPublisher.probe[ByteString]() + val out = TestSubscriber.probe[ByteString]() + Source.fromPublisher(in).via(DummyParser).runWith(Sink.fromSubscriber(out)) + + out.request(1L) + in.expectRequest() + in.sendNext(ByteString("aha!")) + out.expectNoMsg(100.millis) + // no new pull + in.sendComplete() + out.expectError() shouldBe an[IllegalStateException] + } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/ByteStringParser.scala b/akka-stream/src/main/scala/akka/stream/impl/io/ByteStringParser.scala index 1c9730626f..52c7b33397 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/ByteStringParser.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/ByteStringParser.scala @@ -127,10 +127,14 @@ import scala.util.control.{ NoStackTrace, NonFatal } } override def onUpstreamFinish(): Unit = { - // If we have no pending pull from downstream, attempt to invoke the parser again. This will handle + // If we have no a pending pull from downstream, attempt to invoke the parser again. This will handle // truncation if necessary, or complete the stage (and maybe a final emit). if (isAvailable(objOut)) doParse() - // Otherwise the pending pull will kick of doParse() + // if we do not have a pending pull, + else if (buffer.isEmpty) { + if (acceptUpstreamFinish) completeStage() + else current.onTruncation() + } } setHandlers(bytesIn, objOut, this)