diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java index fdb5cc3437..56c820e32b 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java @@ -1200,4 +1200,39 @@ public class SourceTest extends StreamTest { .join(); assertEquals(Done.getInstance(), completion); } + + @Test + public void mustGenerateAFiniteFibonacciSequenceAsynchronously() { + final List resultList = + Source.unfoldAsync( + Pair.create(0, 1), + (pair) -> { + if (pair.first() > 10000000) { + return CompletableFuture.completedFuture(Optional.empty()); + } else { + return CompletableFuture.supplyAsync( + () -> + Optional.of( + Pair.create( + Pair.create(pair.second(), pair.first() + pair.second()), + pair.first())), + system.dispatcher()); + } + }) + .runFold( + new ArrayList(), + (list, next) -> { + list.add(next); + return list; + }, + system) + .toCompletableFuture() + .join(); + assertEquals( + Arrays.asList( + 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, + 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, + 1346269, 2178309, 3524578, 5702887, 9227465), + resultList); + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala b/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala index f10bef834b..cf8c510ef7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala @@ -4,14 +4,17 @@ package akka.stream.impl -import scala.concurrent.Future -import scala.util.{ Failure, Success, Try } - import akka.annotation.InternalApi +import akka.japi.{ function, Pair } import akka.stream._ import akka.stream.impl.Stages.DefaultAttributes import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } +import java.util.Optional +import java.util.concurrent.CompletionStage +import scala.concurrent.Future +import scala.util.{ Failure, Success, Try } + /** * INTERNAL API */ @@ -24,11 +27,11 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } private[this] var state = s def onPull(): Unit = f(state) match { - case None => complete(out) case Some((newState, v)) => { push(out, v) state = newState } + case None => complete(out) } setHandler(out, this) @@ -46,15 +49,15 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler { private[this] var state = s - private[this] var asyncHandler: Function1[Try[Option[(S, E)]], Unit] = _ + private[this] var asyncHandler: Try[Option[(S, E)]] => Unit = _ - override def preStart() = { + override def preStart(): Unit = { val ac = getAsyncCallback[Try[Option[(S, E)]]] { - case Failure(ex) => fail(out, ex) - case Success(None) => complete(out) case Success(Some((newS, elem))) => push(out, elem) state = newS + case Success(None) => complete(out) + case Failure(ex) => fail(out, ex) } asyncHandler = ac.invoke } @@ -64,3 +67,48 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } setHandler(out, this) } } + +/** + * [[UnfoldAsync]] optimized specifically for Java API and `CompletionStage` + * + * INTERNAL API + */ +@InternalApi private[akka] final class UnfoldAsyncJava[S, E]( + s: S, + f: function.Function[S, CompletionStage[Optional[Pair[S, E]]]]) + extends GraphStage[SourceShape[E]] { + val out: Outlet[E] = Outlet("UnfoldAsync.out") + override val shape: SourceShape[E] = SourceShape(out) + override def initialAttributes: Attributes = DefaultAttributes.unfoldAsync + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler { + private[this] var state = s + private[this] var asyncHandler: Try[Optional[Pair[S, E]]] => Unit = _ + + override def preStart(): Unit = { + val ac = getAsyncCallback[Try[Optional[Pair[S, E]]]] { + case Success(maybeValue) => + if (maybeValue.isPresent) { + val pair = maybeValue.get() + push(out, pair.second) + state = pair.first + } else { + complete(out) + } + case Failure(ex) => fail(out, ex) + } + asyncHandler = ac.invoke + } + + def onPull(): Unit = + f(state).handle((r, ex) => { + if (ex != null) { + asyncHandler(Failure(ex)) + } else { + asyncHandler(Success(r)) + } + null + }) + setHandler(out, this) + } +} diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index b9f5ca936f..65a2f174e9 100755 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -8,7 +8,6 @@ import java.util import java.util.Optional import java.util.concurrent.{ CompletableFuture, CompletionStage } import java.util.function.{ BiFunction, Supplier } - import scala.annotation.unchecked.uncheckedVariance import scala.collection.immutable import scala.compat.java8.FutureConverters._ @@ -26,7 +25,7 @@ import akka.event.{ LogMarker, LoggingAdapter, MarkerLoggingAdapter } import akka.japi.{ function, JavaPartialFunction, Pair, Util } import akka.japi.function.Creator import akka.stream._ -import akka.stream.impl.LinearTraversalBuilder +import akka.stream.impl.{ LinearTraversalBuilder, UnfoldAsyncJava } import akka.util.{ unused, _ } import akka.util.JavaDurationConverters._ import akka.util.ccompat.JavaConverters._ @@ -265,8 +264,7 @@ object Source { * Same as [[unfold]], but uses an async function to generate the next state-element tuple. */ def unfoldAsync[S, E](s: S, f: function.Function[S, CompletionStage[Optional[Pair[S, E]]]]): Source[E, NotUsed] = - new Source(scaladsl.Source.unfoldAsync(s)((s: S) => - f.apply(s).toScala.map(_.asScala.map(_.toScala))(akka.dispatch.ExecutionContexts.parasitic))) + new Source(scaladsl.Source.fromGraph(new UnfoldAsyncJava[S, E](s, f))) /** * Create a `Source` that immediately ends the stream with the `cause` failure to every connected `Sink`.