Allow a read of length 0 in InputStreamAdapter #28751 (#28759)

This commit is contained in:
Eike Wacker 2020-03-20 11:33:00 +01:00 committed by GitHub
parent 35b4da09dd
commit 630e712b9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 26 deletions

View file

@ -153,7 +153,7 @@ class InputStreamSinkSpec extends StreamSpec(UnboundedMailboxConfig) {
an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, -1, 2)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, -1, 2))
an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, 5)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, 5))
an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(new Array[Byte](0), 0, 1)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(new Array[Byte](0), 0, 1))
an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, 0)) an[IllegalArgumentException] shouldBe thrownBy(inputStream.read(buf, 0, -1))
inputStream.close() inputStream.close()
} }
@ -275,6 +275,18 @@ class InputStreamSinkSpec extends StreamSpec(UnboundedMailboxConfig) {
} }
thrown.getCause should ===(error) thrown.getCause should ===(error)
} }
"a read of length 0 should not request bytes from upstream" in assertAllStagesStopped {
val (probe, inputStream) = TestSource.probe[ByteString].toMat(StreamConverters.asInputStream())(Keep.both).run()
probe.ensureSubscription()
probe.expectRequest()
inputStream.read(new Array[Byte](byteString.size), 0, 0) should ===(0)
probe.expectNoMessage()
inputStream.close()
probe.expectCancellation()
}
} }
} }

View file

@ -148,34 +148,36 @@ private[stream] object InputStreamSinkStage {
override def read(a: Array[Byte], begin: Int, length: Int): Int = { override def read(a: Array[Byte], begin: Int, length: Int): Int = {
require(a.length > 0, "array size must be >= 0") require(a.length > 0, "array size must be >= 0")
require(begin >= 0, "begin must be >= 0") require(begin >= 0, "begin must be >= 0")
require(length > 0, "length must be > 0") require(length >= 0, "length must be >= 0")
require(begin + length <= a.length, "begin + length must be smaller or equal to the array length") require(begin + length <= a.length, "begin + length must be smaller or equal to the array length")
executeIfNotClosed(() => if (length == 0) 0
if (isStageAlive) { else
detachedChunk match { executeIfNotClosed(() =>
case None => if (isStageAlive) {
try { detachedChunk match {
sharedBuffer.poll(readTimeout.toMillis, TimeUnit.MILLISECONDS) match { case None =>
case Data(data) => try {
detachedChunk = Some(data) sharedBuffer.poll(readTimeout.toMillis, TimeUnit.MILLISECONDS) match {
readBytes(a, begin, length) case Data(data) =>
case Finished => detachedChunk = Some(data)
isStageAlive = false readBytes(a, begin, length)
-1 case Finished =>
case Failed(ex) => isStageAlive = false
isStageAlive = false -1
throw new IOException(ex) case Failed(ex) =>
case null => throw new IOException("Timeout on waiting for new data") isStageAlive = false
case Initialized => throw new IllegalStateException("message 'Initialized' must come first") throw new IOException(ex)
case null => throw new IOException("Timeout on waiting for new data")
case Initialized => throw new IllegalStateException("message 'Initialized' must come first")
}
} catch {
case ex: InterruptedException => throw new IOException(ex)
} }
} catch { case Some(_) =>
case ex: InterruptedException => throw new IOException(ex) readBytes(a, begin, length)
} }
case Some(_) => } else -1)
readBytes(a, begin, length)
}
} else -1)
} }
private[this] def readBytes(a: Array[Byte], begin: Int, length: Int): Int = { private[this] def readBytes(a: Array[Byte], begin: Int, length: Int): Int = {