Merge pull request #19143 from hochgi/hochgi-19021-unfold

+str #19021 Add unfold (and variants) generators to Source
This commit is contained in:
Roland Kuhn 2015-12-14 11:43:49 +01:00
commit 5c55180327
5 changed files with 198 additions and 0 deletions

View file

@ -219,6 +219,44 @@ class SourceSpec extends AkkaSpec {
} }
} }
"Unfold Source" must {
val expected = List(9227465, 5702887, 3524578, 2178309, 1346269, 832040, 514229, 317811, 196418, 121393, 75025, 46368, 28657, 17711, 10946, 6765, 4181, 2584, 1597, 987, 610, 377, 233, 144, 89, 55, 34, 21, 13, 8, 5, 3, 2, 1, 1, 0)
"generate a finite fibonacci sequence" in {
import GraphDSL.Implicits._
val source = Source.unfold((0, 1)) {
case (a, _) if a > 10000000 None
case (a, b) Some((b, a + b) a)
}
val result = Await.result(source.runFold(List.empty[Int]) { case (xs, x) x :: xs }, 1.second)
result should ===(expected)
}
"generate a finite fibonacci sequence asynchronously" in {
import GraphDSL.Implicits._
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
val source = Source.unfoldAsync((0, 1)) {
case (a, _) if a > 10000000 Future.successful(None)
case (a, b) Future(Some((b, a + b) a))
}
val result = Await.result(source.runFold(List.empty[Int]) { case (xs, x) x :: xs }, 1.second)
result should ===(expected)
}
"generate an infinite fibonacci sequence" in {
import GraphDSL.Implicits._
val source = Source.unfoldInf((0, 1)) {
case (a, b) (b, a + b) a
}
val result = Await.result(source.take(36).runFold(List.empty[Int]) { case (xs, x) x :: xs }, 1.second)
result should ===(expected)
}
}
"Iterator Source" must { "Iterator Source" must {
"properly iterate" in { "properly iterate" in {
val result = Await.result(Source(() Iterator.iterate(false)(!_)).grouped(10).runWith(Sink.head), 1.second) val result = Await.result(Source(() Iterator.iterate(false)(!_)).grouped(10).runWith(Sink.head), 1.second)

View file

@ -62,6 +62,9 @@ private[stream] object Stages {
val unzip = name("unzip") val unzip = name("unzip")
val concat = name("concat") val concat = name("concat")
val repeat = name("repeat") val repeat = name("repeat")
val unfold = name("unfold")
val unfoldAsync = name("unfoldAsync")
val unfoldInf = name("unfoldInf")
val publisherSource = name("publisherSource") val publisherSource = name("publisherSource")
val iterableSource = name("iterableSource") val iterableSource = name("iterableSource")

View file

@ -0,0 +1,79 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.scaladsl
import akka.stream.stage.{ OutHandler, GraphStageLogic, GraphStage }
import akka.stream._
import scala.concurrent.{ ExecutionContext, Future }
import scala.util.{ Failure, Success, Try }
/**
* Unfold `GraphStage` class
* @param s initial state
* @param f unfold function
* @tparam S state
* @tparam E element
*/
private[akka] class Unfold[S, E](s: S, f: S Option[(S, E)]) extends GraphStage[SourceShape[E]] {
val out: Outlet[E] = Outlet("Unfold")
override val shape: SourceShape[E] = SourceShape(out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {
new GraphStageLogic(shape) {
private[this] var state = s
setHandler(out, new OutHandler {
override def onPull(): Unit = f(state) match {
case None complete(out)
case Some((newState, v)) {
push(out, v)
state = newState
}
}
})
}
}
}
/**
* UnfoldAsync `GraphStage` class
* @param s initial state
* @param f unfold function
* @tparam S state
* @tparam E element
*/
private[akka] class UnfoldAsync[S, E](s: S, f: S Future[Option[(S, E)]]) extends GraphStage[SourceShape[E]] {
val out: Outlet[E] = Outlet("UnfoldAsync")
override val shape: SourceShape[E] = SourceShape(out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {
new GraphStageLogic(shape) {
private[this] var state = s
private[this] var asyncHandler: Function1[Try[Option[(S, E)]], Unit] = _
override def preStart() = {
val ac = getAsyncCallback[Try[Option[(S, E)]]] {
case Failure(ex) fail(out, ex)
case Success(None) complete(out)
case Success(Some((newS, elem))) {
push(out, elem)
state = newS
}
}
asyncHandler = ac.invoke
}
setHandler(out, new OutHandler {
override def onPull(): Unit =
f(state).onComplete(asyncHandler)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
})
}
}
}

View file

@ -172,6 +172,26 @@ object Source {
def repeat[T](element: T): Source[T, Unit] = def repeat[T](element: T): Source[T, Unit] =
new Source(scaladsl.Source.repeat(element)) new Source(scaladsl.Source.repeat(element))
/**
* create a `Source` that will unfold a value of type `S` into
* a pair of the next state `S` and output elements of type `E`.
*/
def unfold[S, E](s: S, f: function.Function[S, Option[(S, E)]]): Source[E, Unit] =
new Source(scaladsl.Source.unfold(s)((s: S) f.apply(s)))
/**
* same as unfold, but uses an async function to generate the next state-element tuple.
*/
def unfoldAsync[S, E](s: S, f: function.Function[S, Future[Option[(S, E)]]]): Source[E, Unit] =
new Source(scaladsl.Source.unfoldAsync(s)((s: S) f.apply(s)))
/**
* simpler unfold, for infinite sequences.
*/
def unfoldInf[S, E](s: S, f: function.Function[S, (S, E)]): Source[E, Unit] = {
new Source(scaladsl.Source.unfoldInf(s)((s: S) f.apply(s)))
}
/** /**
* Create a `Source` that immediately ends the stream with the `cause` failure to every connected `Sink`. * Create a `Source` that immediately ends the stream with the `cause` failure to every connected `Sink`.
*/ */

View file

@ -244,6 +244,64 @@ object Source {
shape("RepeatSource"))).mapConcat(ConstantFun.scalaIdentityFunction) shape("RepeatSource"))).mapConcat(ConstantFun.scalaIdentityFunction)
} }
/**
* create a `Source` that will unfold a value of type `S` into
* a pair of the next state `S` and output elements of type `E`.
*
* for example, all the fibonacci numbers under 10M:
*
* {{{
* Source.unfold(0 1){
* case (a,_) if a > 10000000 None
* case (a,b) Some((b (a + b)) a)
* }
* }}}
*/
def unfold[S, E](s: S)(f: S Option[(S, E)]): Source[E, Unit] =
Source.fromGraph(new Unfold(s, f)).withAttributes(DefaultAttributes.unfold)
/**
* same as unfold, but uses an async function to generate the next state-element tuple.
*
* async fibonacci example:
*
* {{{
* Source.unfoldAsync(0 1){
* case (a,_) if a > 10000000 Future.successful(None)
* case (a,b) Future{
* Thread.sleep(1000)
* Some((b (a + b)) a)
* }
* }
* }}}
*/
def unfoldAsync[S, E](s: S)(f: S Future[Option[(S, E)]]): Source[E, Unit] =
Source.fromGraph(new UnfoldAsync(s, f)).withAttributes(DefaultAttributes.unfoldAsync)
/**
* simpler unfold, for infinite sequences.
*
* {{{
* Source.unfoldInf(0 1){
* case (a,b) (b (a + b)) a
* }
* }}}
*/
def unfoldInf[S, E](s: S)(f: S (S, E)): Source[E, Unit] = {
Source.fromGraph(GraphDSL.create() { implicit b
import GraphDSL.Implicits._
val uzip = b.add(UnzipWith(f))
val cnct = b.add(Concat[S]())
val init = Source.single(s)
init ~> cnct ~> uzip.in
cnct <~ uzip.out0
SourceShape(uzip.out1)
}).withAttributes(DefaultAttributes.unfoldInf)
}
/** /**
* A `Source` with no elements, i.e. an empty stream that is completed immediately for every connected `Sink`. * A `Source` with no elements, i.e. an empty stream that is completed immediately for every connected `Sink`.
*/ */