diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index d4bfd64dd4..ba20233ba6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -219,6 +219,44 @@ class SourceSpec extends AkkaSpec { } } + "Unfold Source" must { + val expected = List(9227465, 5702887, 3524578, 2178309, 1346269, 832040, 514229, 317811, 196418, 121393, 75025, 46368, 28657, 17711, 10946, 6765, 4181, 2584, 1597, 987, 610, 377, 233, 144, 89, 55, 34, 21, 13, 8, 5, 3, 2, 1, 1, 0) + + "generate a finite fibonacci sequence" in { + import GraphDSL.Implicits._ + + val source = Source.unfold((0, 1)) { + case (a, _) if a > 10000000 ⇒ None + case (a, b) ⇒ Some((b, a + b) → a) + } + val result = Await.result(source.runFold(List.empty[Int]) { case (xs, x) ⇒ x :: xs }, 1.second) + result should ===(expected) + } + + "generate a finite fibonacci sequence asynchronously" in { + import GraphDSL.Implicits._ + import scala.concurrent.Future + import scala.concurrent.ExecutionContext.Implicits.global + + val source = Source.unfoldAsync((0, 1)) { + case (a, _) if a > 10000000 ⇒ Future.successful(None) + case (a, b) ⇒ Future(Some((b, a + b) → a)) + } + val result = Await.result(source.runFold(List.empty[Int]) { case (xs, x) ⇒ x :: xs }, 1.second) + result should ===(expected) + } + + "generate an infinite fibonacci sequence" in { + import GraphDSL.Implicits._ + + val source = Source.unfoldInf((0, 1)) { + case (a, b) ⇒ (b, a + b) → a + } + val result = Await.result(source.take(36).runFold(List.empty[Int]) { case (xs, x) ⇒ x :: xs }, 1.second) + result should ===(expected) + } + } + "Iterator Source" must { "properly iterate" in { val result = Await.result(Source(() ⇒ Iterator.iterate(false)(!_)).grouped(10).runWith(Sink.head), 1.second) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 3e18d37177..2c5ab5a1dc 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -62,6 +62,9 @@ private[stream] object Stages { val unzip = name("unzip") val concat = name("concat") val repeat = name("repeat") + val unfold = name("unfold") + val unfoldAsync = name("unfoldAsync") + val unfoldInf = name("unfoldInf") val publisherSource = name("publisherSource") val iterableSource = name("iterableSource") diff --git a/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala b/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala new file mode 100644 index 0000000000..0a8d9217e7 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/Unfold.scala @@ -0,0 +1,79 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream.stage.{ OutHandler, GraphStageLogic, GraphStage } +import akka.stream._ + +import scala.concurrent.{ ExecutionContext, Future } +import scala.util.{ Failure, Success, Try } + +/** + * Unfold `GraphStage` class + * @param s initial state + * @param f unfold function + * @tparam S state + * @tparam E element + */ +private[akka] class Unfold[S, E](s: S, f: S ⇒ Option[(S, E)]) extends GraphStage[SourceShape[E]] { + + val out: Outlet[E] = Outlet("Unfold") + + override val shape: SourceShape[E] = SourceShape(out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = { + new GraphStageLogic(shape) { + private[this] var state = s + + setHandler(out, new OutHandler { + override def onPull(): Unit = f(state) match { + case None ⇒ complete(out) + case Some((newState, v)) ⇒ { + push(out, v) + state = newState + } + } + }) + } + } +} + +/** + * UnfoldAsync `GraphStage` class + * @param s initial state + * @param f unfold function + * @tparam S state + * @tparam E element + */ +private[akka] class UnfoldAsync[S, E](s: S, f: S ⇒ Future[Option[(S, E)]]) extends GraphStage[SourceShape[E]] { + + val out: Outlet[E] = Outlet("UnfoldAsync") + + override val shape: SourceShape[E] = SourceShape(out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = { + new GraphStageLogic(shape) { + private[this] var state = s + + private[this] var asyncHandler: Function1[Try[Option[(S, E)]], Unit] = _ + + override def preStart() = { + val ac = getAsyncCallback[Try[Option[(S, E)]]] { + case Failure(ex) ⇒ fail(out, ex) + case Success(None) ⇒ complete(out) + case Success(Some((newS, elem))) ⇒ { + push(out, elem) + state = newS + } + } + asyncHandler = ac.invoke + } + + setHandler(out, new OutHandler { + override def onPull(): Unit = + f(state).onComplete(asyncHandler)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) + }) + } + } +} diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index 47d62eb512..7728ce02e2 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -172,6 +172,26 @@ object Source { def repeat[T](element: T): Source[T, Unit] = new Source(scaladsl.Source.repeat(element)) + /** + * create a `Source` that will unfold a value of type `S` into + * a pair of the next state `S` and output elements of type `E`. + */ + def unfold[S, E](s: S, f: function.Function[S, Option[(S, E)]]): Source[E, Unit] = + new Source(scaladsl.Source.unfold(s)((s: S) ⇒ f.apply(s))) + + /** + * same as unfold, but uses an async function to generate the next state-element tuple. + */ + def unfoldAsync[S, E](s: S, f: function.Function[S, Future[Option[(S, E)]]]): Source[E, Unit] = + new Source(scaladsl.Source.unfoldAsync(s)((s: S) ⇒ f.apply(s))) + + /** + * simpler unfold, for infinite sequences. + */ + def unfoldInf[S, E](s: S, f: function.Function[S, (S, E)]): Source[E, Unit] = { + new Source(scaladsl.Source.unfoldInf(s)((s: S) ⇒ f.apply(s))) + } + /** * Create a `Source` that immediately ends the stream with the `cause` failure to every connected `Sink`. */ diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index 947865f800..7221edb022 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -244,6 +244,64 @@ object Source { shape("RepeatSource"))).mapConcat(ConstantFun.scalaIdentityFunction) } + /** + * create a `Source` that will unfold a value of type `S` into + * a pair of the next state `S` and output elements of type `E`. + * + * for example, all the fibonacci numbers under 10M: + * + * {{{ + * Source.unfold(0 → 1){ + * case (a,_) if a > 10000000 ⇒ None + * case (a,b) ⇒ Some((b → (a + b)) → a) + * } + * }}} + */ + def unfold[S, E](s: S)(f: S ⇒ Option[(S, E)]): Source[E, Unit] = + Source.fromGraph(new Unfold(s, f)).withAttributes(DefaultAttributes.unfold) + + /** + * same as unfold, but uses an async function to generate the next state-element tuple. + * + * async fibonacci example: + * + * {{{ + * Source.unfoldAsync(0 → 1){ + * case (a,_) if a > 10000000 ⇒ Future.successful(None) + * case (a,b) ⇒ Future{ + * Thread.sleep(1000) + * Some((b → (a + b)) → a) + * } + * } + * }}} + */ + def unfoldAsync[S, E](s: S)(f: S ⇒ Future[Option[(S, E)]]): Source[E, Unit] = + Source.fromGraph(new UnfoldAsync(s, f)).withAttributes(DefaultAttributes.unfoldAsync) + + /** + * simpler unfold, for infinite sequences. + * + * {{{ + * Source.unfoldInf(0 → 1){ + * case (a,b) ⇒ (b → (a + b)) → a + * } + * }}} + */ + def unfoldInf[S, E](s: S)(f: S ⇒ (S, E)): Source[E, Unit] = { + Source.fromGraph(GraphDSL.create() { implicit b ⇒ + import GraphDSL.Implicits._ + + val uzip = b.add(UnzipWith(f)) + val cnct = b.add(Concat[S]()) + val init = Source.single(s) + + init ~> cnct ~> uzip.in + cnct <~ uzip.out0 + + SourceShape(uzip.out1) + }).withAttributes(DefaultAttributes.unfoldInf) + } + /** * A `Source` with no elements, i.e. an empty stream that is completed immediately for every connected `Sink`. */