diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala index 3025589c2d..652c018565 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala @@ -153,7 +153,7 @@ class InputStreamSinkSpec extends StreamSpec(UnboundedMailboxConfig) { an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, -1, 2)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, 5)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(new Array[Byte](0), 0, 1)) - an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, 0)) + an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, -1)) inputStream.close() } @@ -275,6 +275,18 @@ class InputStreamSinkSpec extends StreamSpec(UnboundedMailboxConfig) { } thrown.getCause should ===(error) } + + "a read of length 0 should not request bytes from upstream" in assertAllStagesStopped { + val (probe, inputStream) = TestSource.probe[ByteString].toMat(StreamConverters.asInputStream())(Keep.both).run() + probe.ensureSubscription() + probe.expectRequest() + + inputStream.read(new Array[Byte](byteString.size), 0, 0) should ===(0) + probe.expectNoMessage() + + inputStream.close() + probe.expectCancellation() + } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala index 2651e45528..211b18a92a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala @@ -148,34 +148,36 @@ private[stream] object InputStreamSinkStage { override def read(a: Array[Byte], begin: Int, length: Int): Int = { require(a.length > 0, "array size must be >= 0") require(begin >= 0, "begin must be >= 0") - require(length > 0, "length must be > 0") + require(length >= 0, "length must be >= 0") require(begin + length <= a.length, "begin + length must be smaller or equal to the array length") - executeIfNotClosed(() => - if (isStageAlive) { - detachedChunk match { - case None => - try { - sharedBuffer.poll(readTimeout.toMillis, TimeUnit.MILLISECONDS) match { - case Data(data) => - detachedChunk = Some(data) - readBytes(a, begin, length) - case Finished => - isStageAlive = false - -1 - case Failed(ex) => - isStageAlive = false - throw new IOException(ex) - case null => throw new IOException("Timeout on waiting for new data") - case Initialized => throw new IllegalStateException("message 'Initialized' must come first") + if (length == 0) 0 + else + executeIfNotClosed(() => + if (isStageAlive) { + detachedChunk match { + case None => + try { + sharedBuffer.poll(readTimeout.toMillis, TimeUnit.MILLISECONDS) match { + case Data(data) => + detachedChunk = Some(data) + readBytes(a, begin, length) + case Finished => + isStageAlive = false + -1 + case Failed(ex) => + isStageAlive = false + throw new IOException(ex) + case null => throw new IOException("Timeout on waiting for new data") + case Initialized => throw new IllegalStateException("message 'Initialized' must come first") + } + } catch { + case ex: InterruptedException => throw new IOException(ex) } - } catch { - case ex: InterruptedException => throw new IOException(ex) - } - case Some(_) => - readBytes(a, begin, length) - } - } else -1) + case Some(_) => + readBytes(a, begin, length) + } + } else -1) } private[this] def readBytes(a: Array[Byte], begin: Int, length: Int): Int = {