diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index 144e8942d3..04f29d4cbd 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -15,6 +15,8 @@ import akka.stream.testkit._ import akka.NotUsed import akka.testkit.EventFilter import scala.collection.immutable +import java.util +import java.util.stream.BaseStream class SourceSpec extends StreamSpec with DefaultTimeout { @@ -365,6 +367,55 @@ class SourceSpec extends StreamSpec with DefaultTimeout { } } + "close the underlying stream when completed" in { + @volatile var closed = false + + final class EmptyStream[A] extends BaseStream[A, EmptyStream[A]] { + override def unordered(): EmptyStream[A] = this + override def sequential(): EmptyStream[A] = this + override def parallel(): EmptyStream[A] = this + override def isParallel: Boolean = false + + override def spliterator(): util.Spliterator[A] = ??? + override def onClose(closeHandler: Runnable): EmptyStream[A] = ??? + + override def iterator(): util.Iterator[A] = new util.Iterator[A] { + override def next(): A = ??? + override def hasNext: Boolean = false + } + + override def close(): Unit = closed = true + } + + Await.ready(StreamConverters.fromJavaStream(() ⇒ new EmptyStream[Unit]).runWith(Sink.ignore), 3.seconds) + + closed should ===(true) + } + + "close the underlying stream when failed" in { + @volatile var closed = false + + final class FailingStream[A] extends BaseStream[A, FailingStream[A]] { + override def unordered(): FailingStream[A] = this + override def sequential(): FailingStream[A] = this + override def parallel(): FailingStream[A] = this + override def isParallel: Boolean = false + + override def spliterator(): util.Spliterator[A] = ??? + override def onClose(closeHandler: Runnable): FailingStream[A] = ??? + + override def iterator(): util.Iterator[A] = new util.Iterator[A] { + override def next(): A = throw new RuntimeException("ouch") + override def hasNext: Boolean = true + } + + override def close(): Unit = closed = true + } + + Await.ready(StreamConverters.fromJavaStream(() ⇒ new FailingStream[Unit]).runWith(Sink.ignore), 3.seconds) + + closed should ===(true) + } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala b/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala new file mode 100644 index 0000000000..6293828e18 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala @@ -0,0 +1,40 @@ +package akka.stream.impl + +import akka.stream._ +import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } +import akka.annotation.InternalApi + +/** Internal API */ +@InternalApi +private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]](open: () ⇒ java.util.stream.BaseStream[T, S]) + extends GraphStage[SourceShape[T]] { + + val out: Outlet[T] = Outlet("JavaStreamSource") + override val shape: SourceShape[T] = SourceShape(out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler { + private[this] var stream: java.util.stream.BaseStream[T, S] = _ + private[this] var iter: java.util.Iterator[T] = _ + + setHandler(out, this) + + override def preStart(): Unit = { + stream = open() + iter = stream.iterator() + } + + override def postStop(): Unit = { + if (stream ne null) + stream.close() + } + + override def onPull(): Unit = { + if (iter.hasNext) { + push(out, iter.next()) + } else { + complete(out) + } + } + } +} diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala index 4cad4229c9..478eeb5c10 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala @@ -191,9 +191,6 @@ object StreamConverters { * You can use [[Source.async]] to create asynchronous boundaries between synchronous Java ``Stream`` * and the rest of flow. */ - def fromJavaStream[T, S <: java.util.stream.BaseStream[T, S]](stream: () ⇒ java.util.stream.BaseStream[T, S]): Source[T, NotUsed] = { - import scala.collection.JavaConverters._ - Source.fromIterator(() ⇒ stream().iterator().asScala).withAttributes(DefaultAttributes.fromJavaStream) - } - + def fromJavaStream[T, S <: java.util.stream.BaseStream[T, S]](stream: () ⇒ java.util.stream.BaseStream[T, S]): Source[T, NotUsed] = + Source.fromGraph(new JavaStreamSource[T, S](stream)).withAttributes(DefaultAttributes.fromJavaStream) }