=str Add specialized version of unfoldAsync for Java. (#31283)

This commit is contained in:
kerr 2022-04-14 17:12:45 +08:00 committed by GitHub
parent 934b5a055a
commit 3a9126be8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 12 deletions

View file

@ -1200,4 +1200,39 @@ public class SourceTest extends StreamTest {
.join();
assertEquals(Done.getInstance(), completion);
}
@Test
public void mustGenerateAFiniteFibonacciSequenceAsynchronously() {
final List<Integer> resultList =
Source.unfoldAsync(
Pair.create(0, 1),
(pair) -> {
if (pair.first() > 10000000) {
return CompletableFuture.completedFuture(Optional.empty());
} else {
return CompletableFuture.supplyAsync(
() ->
Optional.of(
Pair.create(
Pair.create(pair.second(), pair.first() + pair.second()),
pair.first())),
system.dispatcher());
}
})
.runFold(
new ArrayList<Integer>(),
(list, next) -> {
list.add(next);
return list;
},
system)
.toCompletableFuture()
.join();
assertEquals(
Arrays.asList(
0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181,
6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040,
1346269, 2178309, 3524578, 5702887, 9227465),
resultList);
}
}

View file

@ -4,14 +4,17 @@
package akka.stream.impl
import scala.concurrent.Future
import scala.util.{ Failure, Success, Try }
import akka.annotation.InternalApi
import akka.japi.{ function, Pair }
import akka.stream._
import akka.stream.impl.Stages.DefaultAttributes
import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
import java.util.Optional
import java.util.concurrent.CompletionStage
import scala.concurrent.Future
import scala.util.{ Failure, Success, Try }
/**
* INTERNAL API
*/
@ -24,11 +27,11 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
private[this] var state = s
def onPull(): Unit = f(state) match {
case None => complete(out)
case Some((newState, v)) => {
push(out, v)
state = newState
}
case None => complete(out)
}
setHandler(out, this)
@ -46,15 +49,15 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with OutHandler {
private[this] var state = s
private[this] var asyncHandler: Function1[Try[Option[(S, E)]], Unit] = _
private[this] var asyncHandler: Try[Option[(S, E)]] => Unit = _
override def preStart() = {
override def preStart(): Unit = {
val ac = getAsyncCallback[Try[Option[(S, E)]]] {
case Failure(ex) => fail(out, ex)
case Success(None) => complete(out)
case Success(Some((newS, elem))) =>
push(out, elem)
state = newS
case Success(None) => complete(out)
case Failure(ex) => fail(out, ex)
}
asyncHandler = ac.invoke
}
@ -64,3 +67,48 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
setHandler(out, this)
}
}
/**
* [[UnfoldAsync]] optimized specifically for Java API and `CompletionStage`
*
* INTERNAL API
*/
@InternalApi private[akka] final class UnfoldAsyncJava[S, E](
s: S,
f: function.Function[S, CompletionStage[Optional[Pair[S, E]]]])
extends GraphStage[SourceShape[E]] {
val out: Outlet[E] = Outlet("UnfoldAsync.out")
override val shape: SourceShape[E] = SourceShape(out)
override def initialAttributes: Attributes = DefaultAttributes.unfoldAsync
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with OutHandler {
private[this] var state = s
private[this] var asyncHandler: Try[Optional[Pair[S, E]]] => Unit = _
override def preStart(): Unit = {
val ac = getAsyncCallback[Try[Optional[Pair[S, E]]]] {
case Success(maybeValue) =>
if (maybeValue.isPresent) {
val pair = maybeValue.get()
push(out, pair.second)
state = pair.first
} else {
complete(out)
}
case Failure(ex) => fail(out, ex)
}
asyncHandler = ac.invoke
}
def onPull(): Unit =
f(state).handle((r, ex) => {
if (ex != null) {
asyncHandler(Failure(ex))
} else {
asyncHandler(Success(r))
}
null
})
setHandler(out, this)
}
}

View file

@ -8,7 +8,6 @@ import java.util
import java.util.Optional
import java.util.concurrent.{ CompletableFuture, CompletionStage }
import java.util.function.{ BiFunction, Supplier }
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable
import scala.compat.java8.FutureConverters._
@ -26,7 +25,7 @@ import akka.event.{ LogMarker, LoggingAdapter, MarkerLoggingAdapter }
import akka.japi.{ function, JavaPartialFunction, Pair, Util }
import akka.japi.function.Creator
import akka.stream._
import akka.stream.impl.LinearTraversalBuilder
import akka.stream.impl.{ LinearTraversalBuilder, UnfoldAsyncJava }
import akka.util.{ unused, _ }
import akka.util.JavaDurationConverters._
import akka.util.ccompat.JavaConverters._
@ -265,8 +264,7 @@ object Source {
* Same as [[unfold]], but uses an async function to generate the next state-element tuple.
*/
def unfoldAsync[S, E](s: S, f: function.Function[S, CompletionStage[Optional[Pair[S, E]]]]): Source[E, NotUsed] =
new Source(scaladsl.Source.unfoldAsync(s)((s: S) =>
f.apply(s).toScala.map(_.asScala.map(_.toScala))(akka.dispatch.ExecutionContexts.parasitic)))
new Source(scaladsl.Source.fromGraph(new UnfoldAsyncJava[S, E](s, f)))
/**
* Create a `Source` that immediately ends the stream with the `cause` failure to every connected `Sink`.