+str: Implement PrefixAndTail as a GraphStage

This commit is contained in:
Endre Sándor Varga 2015-12-15 12:40:31 +01:00
parent 72e3dc84de
commit a5d29f2459
10 changed files with 281 additions and 109 deletions

View file

@ -239,11 +239,11 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
setHandler(shape.outlet, new OutHandler {
setHandler(shape.out, new OutHandler {
override def onPull(): Unit = {
completeStage()
// This cannot be propagated now since the stage is already closed
push(shape.outlet, -1)
push(shape.out, -1)
}
})

View file

@ -8,7 +8,7 @@ import scala.concurrent.{ Future, Await }
import scala.concurrent.duration._
import scala.util.Try
import scala.util.control.NoStackTrace
import akka.stream.{ Attributes, ActorMaterializer, ActorMaterializerSettings }
import akka.stream._
import org.reactivestreams.Subscriber
import akka.stream.testkit._
import akka.stream.testkit.Utils._
@ -90,6 +90,48 @@ class FlowPrefixAndTailSpec extends AkkaSpec {
subscriber.expectSubscriptionAndComplete()
}
"throw if tail is attempted to be materialized twice" in assertAllStagesStopped {
val futureSink = newHeadSink
val fut = Source(1 to 2).prefixAndTail(1).runWith(futureSink)
val (takes, tail) = Await.result(fut, 3.seconds)
takes should be(Seq(1))
val subscriber1 = TestSubscriber.probe[Int]()
tail.to(Sink(subscriber1)).run()
val subscriber2 = TestSubscriber.probe[Int]()
tail.to(Sink(subscriber2)).run()
subscriber2.expectSubscriptionAndError().getMessage should ===("Tail Source cannot be materialized more than once.")
subscriber1.requestNext(2).expectComplete()
}
"signal error if substream has been not subscribed in time" in assertAllStagesStopped {
val tightTimeoutMaterializer =
ActorMaterializer(ActorMaterializerSettings(system)
.withSubscriptionTimeoutSettings(
StreamSubscriptionTimeoutSettings(StreamSubscriptionTimeoutTerminationMode.cancel, 500.millisecond)))
val futureSink = newHeadSink
val fut = Source(1 to 2).prefixAndTail(1).runWith(futureSink)(tightTimeoutMaterializer)
val (takes, tail) = Await.result(fut, 3.seconds)
takes should be(Seq(1))
val subscriber = TestSubscriber.probe[Int]()
Thread.sleep(1000)
tail.to(Sink(subscriber)).run()(tightTimeoutMaterializer)
subscriber.expectSubscriptionAndError().getMessage should ===("Tail Source has not been materialized in 500 milliseconds.")
}
"shut down main stage if substream is empty, even when not subscribed" in assertAllStagesStopped {
val futureSink = newHeadSink
val fut = Source.single(1).prefixAndTail(1).runWith(futureSink)
val (takes, tail) = Await.result(fut, 3.seconds)
takes should be(Seq(1))
}
"handle onError when no substream open" in assertAllStagesStopped {
val publisher = TestPublisher.manualProbe[Int]()
val subscriber = TestSubscriber.manualProbe[(immutable.Seq[Int], Source[Int, _])]()

View file

@ -110,7 +110,7 @@ class GraphMatValueSpec extends AkkaSpec {
val foldFlow: Flow[Int, Int, Future[Int]] = Flow.fromGraph(GraphDSL.create(Sink.fold[Int, Int](0)(_ + _)) {
implicit builder
fold
FlowShape(fold.inlet, builder.materializedValue.mapAsync(4)(identity).outlet)
FlowShape(fold.in, builder.materializedValue.mapAsync(4)(identity).outlet)
})
Await.result(Source(1 to 10).via(foldFlow).runWith(Sink.head), 3.seconds) should ===(55)

View file

@ -38,22 +38,20 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) {
implicit val dispatcher = system.dispatcher
implicit val materializer = ActorMaterializer(settings)
"groupBy" must {
"groupBy and splitwhen" must {
"timeout and cancel substream publishers when no-one subscribes to them after some time (time them out)" in assertAllStagesStopped {
val publisherProbe = TestPublisher.manualProbe[Int]()
val publisherProbe = TestPublisher.probe[Int]()
val publisher = Source(publisherProbe).groupBy(3, _ % 3).lift(_ % 3).runWith(Sink.publisher(false))
val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]()
publisher.subscribe(subscriber)
val upstreamSubscription = publisherProbe.expectSubscription()
val downstreamSubscription = subscriber.expectSubscription()
downstreamSubscription.request(100)
upstreamSubscription.sendNext(1)
upstreamSubscription.sendNext(2)
upstreamSubscription.sendNext(3)
publisherProbe.sendNext(1)
publisherProbe.sendNext(2)
publisherProbe.sendNext(3)
val (_, s1) = subscriber.expectNext()
// should not break normal usage
@ -79,42 +77,38 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) {
val f = s3.runWith(Sink.head).recover { case _: SubscriptionTimeoutException "expected" }
Await.result(f, 300.millis) should equal("expected")
upstreamSubscription.sendComplete()
publisherProbe.sendComplete()
}
"timeout and stop groupBy parent actor if none of the substreams are actually consumed" in assertAllStagesStopped {
val publisherProbe = TestPublisher.manualProbe[Int]()
val publisherProbe = TestPublisher.probe[Int]()
val publisher = Source(publisherProbe).groupBy(2, _ % 2).lift(_ % 2).runWith(Sink.publisher(false))
val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]()
publisher.subscribe(subscriber)
val upstreamSubscription = publisherProbe.expectSubscription()
val downstreamSubscription = subscriber.expectSubscription()
downstreamSubscription.request(100)
upstreamSubscription.sendNext(1)
upstreamSubscription.sendNext(2)
upstreamSubscription.sendNext(3)
upstreamSubscription.sendComplete()
publisherProbe.sendNext(1)
publisherProbe.sendNext(2)
publisherProbe.sendNext(3)
publisherProbe.sendComplete()
val (_, s1) = subscriber.expectNext()
val (_, s2) = subscriber.expectNext()
}
"not timeout and cancel substream publishers when they have been subscribed to" in {
val publisherProbe = TestPublisher.manualProbe[Int]()
val publisherProbe = TestPublisher.probe[Int]()
val publisher = Source(publisherProbe).groupBy(2, _ % 2).lift(_ % 2).runWith(Sink.publisher(false))
val subscriber = TestSubscriber.manualProbe[(Int, Source[Int, _])]()
publisher.subscribe(subscriber)
val upstreamSubscription = publisherProbe.expectSubscription()
val downstreamSubscription = subscriber.expectSubscription()
downstreamSubscription.request(100)
upstreamSubscription.sendNext(1)
upstreamSubscription.sendNext(2)
publisherProbe.sendNext(1)
publisherProbe.sendNext(2)
val (_, s1) = subscriber.expectNext()
// should not break normal usage
@ -136,8 +130,8 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) {
s2Sub.request(100)
s2SubscriberProbe.expectNext(2)
s1Sub.request(100)
upstreamSubscription.sendNext(3)
upstreamSubscription.sendNext(4)
publisherProbe.sendNext(3)
publisherProbe.sendNext(4)
s1SubscriberProbe.expectNext(3)
s2SubscriberProbe.expectNext(4)
}

View file

@ -277,7 +277,6 @@ private[akka] object ActorProcessorFactory {
val settings = materializer.effectiveSettings(att)
op match {
case GroupBy(maxSubstreams, f, _) (GroupByProcessorImpl.props(settings, maxSubstreams, f), ())
case PrefixAndTail(n, _) (PrefixAndTailImpl.props(settings, n), ())
case Split(d, _) (SplitWhereProcessorImpl.props(settings, d), ())
case DirectProcessor(p, m) throw new AssertionError("DirectProcessor cannot end up in ActorProcessorFactory")
}

View file

@ -1,66 +0,0 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl
import scala.collection.immutable
import akka.stream.ActorMaterializerSettings
import akka.stream.scaladsl.Source
import akka.actor.{ Deploy, Props }
/**
* INTERNAL API
*/
private[akka] object PrefixAndTailImpl {
def props(settings: ActorMaterializerSettings, takeMax: Int): Props =
Props(new PrefixAndTailImpl(settings, takeMax)).withDeploy(Deploy.local)
}
/**
* INTERNAL API
*/
private[akka] class PrefixAndTailImpl(_settings: ActorMaterializerSettings, val takeMax: Int)
extends MultiStreamOutputProcessor(_settings) {
import MultiStreamOutputProcessor._
var taken = immutable.Vector.empty[Any]
var left = takeMax
val take = TransferPhase(primaryInputs.NeedsInputOrComplete && primaryOutputs.NeedsDemand) { ()
if (primaryInputs.inputsDepleted) emitEmptyTail()
else {
val elem = primaryInputs.dequeueInputElement()
taken :+= elem
left -= 1
if (left <= 0) {
if (primaryInputs.inputsDepleted) emitEmptyTail()
else emitNonEmptyTail()
}
}
}
def streamTailPhase(substream: SubstreamOutput) = TransferPhase(primaryInputs.NeedsInput && substream.NeedsDemand) { ()
substream.enqueueOutputElement(primaryInputs.dequeueInputElement())
}
val takeEmpty = TransferPhase(primaryOutputs.NeedsDemand) { ()
if (primaryInputs.inputsDepleted) emitEmptyTail()
else emitNonEmptyTail()
}
def emitEmptyTail(): Unit = {
primaryOutputs.enqueueOutputElement((taken, Source.empty))
nextPhase(completedPhase)
}
def emitNonEmptyTail(): Unit = {
val substreamOutput = createSubstreamOutput()
val substreamFlow = Source(substreamOutput) // substreamOutput is a Publisher
primaryOutputs.enqueueOutputElement((taken, substreamFlow))
primaryOutputs.complete()
nextPhase(streamTailPhase(substreamOutput))
}
if (takeMax > 0) initialPhase(1, take) else initialPhase(1, takeEmpty)
}

View file

@ -208,10 +208,6 @@ private[stream] object Stages {
override def withAttributes(attributes: Attributes) = copy(attributes = attributes)
}
final case class PrefixAndTail(n: Int, attributes: Attributes = prefixAndTail) extends StageModule {
override def withAttributes(attributes: Attributes) = copy(attributes = attributes)
}
final case class Split(p: Any SplitDecision, attributes: Attributes = split) extends StageModule {
override def withAttributes(attributes: Attributes) = copy(attributes = attributes)
}

View file

@ -214,6 +214,7 @@ private[stream] object ActorGraphInterpreter {
private var downstreamCompleted = false
// when upstream failed before we got the exposed publisher
private var upstreamFailed: Option[Throwable] = None
private var upstreamCompleted: Boolean = false
private def onNext(elem: Any): Unit = {
downstreamDemand -= 1
@ -221,21 +222,21 @@ private[stream] object ActorGraphInterpreter {
}
private def complete(): Unit = {
if (!downstreamCompleted) {
downstreamCompleted = true
// No need to complete if had already been cancelled, or we closed earlier
if (!(upstreamCompleted || downstreamCompleted)) {
upstreamCompleted = true
if (exposedPublisher ne null) exposedPublisher.shutdown(None)
if (subscriber ne null) tryOnComplete(subscriber)
}
}
def fail(e: Throwable): Unit = {
if (!downstreamCompleted) {
downstreamCompleted = true
// No need to fail if had already been cancelled, or we closed earlier
if (!(downstreamCompleted || upstreamCompleted)) {
upstreamCompleted = true
upstreamFailed = Some(e)
if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e))
if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e)
} else if (exposedPublisher == null && upstreamFailed.isEmpty) {
// fail called before the exposed publisher arrived, we must store it and fail when we're first able to
upstreamFailed = Some(e)
}
}
@ -258,6 +259,7 @@ private[stream] object ActorGraphInterpreter {
if (subscriber eq null) {
subscriber = sub
tryOnSubscribe(subscriber, new BoundarySubscription(actor, id))
if (GraphInterpreter.Debug) println(s"${interpreter.Name} subscribe subscriber=$sub")
} else
rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}")
}
@ -267,7 +269,8 @@ private[stream] object ActorGraphInterpreter {
case _: Some[_]
publisher.shutdown(upstreamFailed)
case _
exposedPublisher = publisher
if (upstreamCompleted) publisher.shutdown(None)
else exposedPublisher = publisher
}
}
@ -322,6 +325,7 @@ private[stream] class ActorGraphInterpreter(
private val outputs = Array.tabulate(shape.outlets.size)(new ActorOutputBoundary(self, _))
private var subscribesPending = inputs.length
private var publishersPending = outputs.length
/*
* Limits the number of events processed by the interpreter before scheduling
@ -398,21 +402,28 @@ private[stream] class ActorGraphInterpreter(
case SubscribePending(id: Int)
outputs(id).subscribePending()
case ExposedPublisher(id, publisher)
publishersPending -= 1
outputs(id).exposedPublisher(publisher)
}
private def waitShutdown: Receive = {
case ExposedPublisher(id, publisher)
outputs(id).exposedPublisher(publisher)
publishersPending -= 1
if (canShutDown) context.stop(self)
case OnSubscribe(_, sub)
tryCancel(sub)
subscribesPending -= 1
if (subscribesPending == 0) context.stop(self)
if (canShutDown) context.stop(self)
case ReceiveTimeout
tryAbort(new TimeoutException("Streaming actor has been already stopped processing (normally), but not all of its " +
s"inputs have been subscribed in [${settings.subscriptionTimeoutSettings.timeout}}]. Aborting actor now."))
s"inputs or outputs have been subscribed in [${settings.subscriptionTimeoutSettings.timeout}}]. Aborting actor now."))
case _ // Ignore, there is nothing to do anyway
}
private def canShutDown: Boolean = subscribesPending + publishersPending == 0
private def runBatch(): Unit = {
try {
val effectiveLimit = {
@ -425,7 +436,7 @@ private[stream] class ActorGraphInterpreter(
interpreter.execute(effectiveLimit)
if (interpreter.isCompleted) {
// Cannot stop right away if not completely subscribed
if (subscribesPending == 0) context.stop(self)
if (canShutDown) context.stop(self)
else {
context.become(waitShutdown)
context.setReceiveTimeout(settings.subscriptionTimeoutSettings.timeout)

View file

@ -3,7 +3,10 @@
*/
package akka.stream.impl.fusing
import java.util.concurrent.atomic.AtomicReference
import akka.stream._
import akka.stream.impl.SubscriptionTimeoutException
import akka.stream.stage._
import akka.stream.scaladsl._
import akka.stream.actor.ActorSubscriberMessage
@ -11,8 +14,13 @@ import akka.stream.actor.ActorSubscriberMessage._
import akka.stream.actor.ActorPublisherMessage
import akka.stream.actor.ActorPublisherMessage._
import java.{ util ju }
import scala.collection.immutable
import scala.concurrent._
import scala.concurrent.duration.FiniteDuration
/**
* INTERNAL API
*/
final class FlattenMerge[T, M](breadth: Int) extends GraphStage[FlowShape[Graph[SourceShape[T], M], T]] {
private val in = Inlet[Graph[SourceShape[T], M]]("flatten.in")
private val out = Outlet[T]("flatten.out")
@ -218,3 +226,191 @@ private[fusing] object StreamOfStreams {
}
}
}
/**
* INTERNAL API
*/
object PrefixAndTail {
sealed trait MaterializationState
case object NotMaterialized extends MaterializationState
case object AlreadyMaterialized extends MaterializationState
case object TimedOut extends MaterializationState
case object NormalCompletion extends MaterializationState
case class FailureCompletion(ex: Throwable) extends MaterializationState
trait TailInterface[T] {
def pushSubstream(elem: T): Unit
def completeSubstream(): Unit
def failSubstream(ex: Throwable)
}
final class TailSource[T](
timeout: FiniteDuration,
register: TailInterface[T] Unit,
pullParent: Unit Unit,
cancelParent: Unit Unit) extends GraphStage[SourceShape[T]] {
val out: Outlet[T] = Outlet("Tail.out")
val materializationState = new AtomicReference[MaterializationState](NotMaterialized)
override val shape: SourceShape[T] = SourceShape(out)
private final class TailSourceLogic(_shape: Shape) extends GraphStageLogic(_shape) with OutHandler with TailInterface[T] {
setHandler(out, this)
override def preStart(): Unit = {
materializationState.getAndSet(AlreadyMaterialized) match {
case AlreadyMaterialized
failStage(new IllegalStateException("Tail Source cannot be materialized more than once."))
case TimedOut
// Already detached from parent
failStage(new SubscriptionTimeoutException(s"Tail Source has not been materialized in $timeout."))
case NormalCompletion
// Already detached from parent
completeStage()
case FailureCompletion(ex)
// Already detached from parent
failStage(ex)
case NotMaterialized
register(this)
}
}
private val onParentPush = getAsyncCallback[T](push(out, _))
private val onParentFinish = getAsyncCallback[Unit](_ completeStage())
private val onParentFailure = getAsyncCallback[Throwable](failStage)
override def pushSubstream(elem: T): Unit = onParentPush.invoke(elem)
override def completeSubstream(): Unit = onParentFinish.invoke(())
override def failSubstream(ex: Throwable): Unit = onParentFailure.invoke(ex)
override def onPull(): Unit = pullParent()
override def onDownstreamFinish(): Unit = cancelParent()
}
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TailSourceLogic(shape)
}
}
/**
* INTERNAL API
*/
final class PrefixAndTail[T](n: Int) extends GraphStage[FlowShape[T, (immutable.Seq[T], Source[T, Unit])]] {
val in: Inlet[T] = Inlet("PrefixAndTail.in")
val out: Outlet[(immutable.Seq[T], Source[T, Unit])] = Outlet("PrefixAndTail.out")
override val shape: FlowShape[T, (immutable.Seq[T], Source[T, Unit])] = FlowShape(in, out)
override def initialAttributes = Attributes.name("PrefixAndTail")
private final class PrefixAndTailLogic(_shape: Shape) extends TimerGraphStageLogic(_shape) with OutHandler with InHandler {
import PrefixAndTail._
private var left = if (n < 0) 0 else n
private var builder = Vector.newBuilder[T]
private var tailSource: TailSource[T] = null
private var tail: TailInterface[T] = null
builder.sizeHint(left)
private var pendingCompletion: MaterializationState = null
private val SubscriptionTimer = "SubstreamSubscriptionTimer"
private val onSubstreamPull = getAsyncCallback[Unit](_ pull(in))
private val onSubstreamFinish = getAsyncCallback[Unit](_ completeStage())
private val onSubstreamRegister = getAsyncCallback[TailInterface[T]] { tailIf
tail = tailIf
cancelTimer(SubscriptionTimer)
pendingCompletion match {
case NormalCompletion
tail.completeSubstream()
completeStage()
case FailureCompletion(ex)
tail.failSubstream(ex)
completeStage()
case _
}
}
override protected def onTimer(timerKey: Any): Unit =
if (tailSource.materializationState.compareAndSet(NotMaterialized, TimedOut)) completeStage()
private def prefixComplete = builder eq null
private def waitingSubstreamRegistration = tail eq null
private def openSubstream(): Source[T, Unit] = {
val timeout = ActorMaterializer.downcast(interpreter.materializer).settings.subscriptionTimeoutSettings.timeout
tailSource = new TailSource[T](timeout, onSubstreamRegister.invoke, onSubstreamPull.invoke, onSubstreamFinish.invoke)
scheduleOnce(SubscriptionTimer, timeout)
builder = null
Source.fromGraph(tailSource)
}
// Needs to keep alive if upstream completes but substream has been not yet materialized
override def keepGoingAfterAllPortsClosed: Boolean = true
override def onPush(): Unit = {
if (prefixComplete) {
tail.pushSubstream(grab(in))
} else {
builder += grab(in)
left -= 1
if (left == 0) {
push(out, (builder.result(), openSubstream()))
complete(out)
} else pull(in)
}
}
override def onPull(): Unit = {
if (left == 0) {
push(out, (Nil, openSubstream()))
complete(out)
} else pull(in)
}
override def onUpstreamFinish(): Unit = {
if (!prefixComplete) {
// This handles the unpulled out case as well
emit(out, (builder.result, Source.empty), () completeStage())
} else {
if (waitingSubstreamRegistration) {
// Detach if possible.
// This allows this stage to complete without waiting for the substream to be materialized, since that
// is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be
// able to signal completion normally soon.
if (tailSource.materializationState.compareAndSet(NotMaterialized, NormalCompletion)) completeStage()
else pendingCompletion = NormalCompletion
} else {
tail.completeSubstream()
completeStage()
}
}
}
override def onUpstreamFailure(ex: Throwable): Unit = {
if (prefixComplete) {
if (waitingSubstreamRegistration) {
// Detach if possible.
// This allows this stage to complete without waiting for the substream to be materialized, since that
// is empty anyway. If it is already being registered (state was not NotMaterialized) then we will be
// able to signal completion normally soon.
if (tailSource.materializationState.compareAndSet(NotMaterialized, FailureCompletion(ex))) failStage(ex)
else pendingCompletion = FailureCompletion(ex)
} else {
tail.failSubstream(ex)
completeStage()
}
} else failStage(ex)
}
override def onDownstreamFinish(): Unit = {
if (!prefixComplete) completeStage()
// Otherwise substream is open, ignore
}
setHandler(in, this)
setHandler(out, this)
}
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new PrefixAndTailLogic(shape)
}

View file

@ -910,7 +910,7 @@ trait FlowOps[+Out, +Mat] {
* '''Cancels when''' downstream cancels or substream cancels
*/
def prefixAndTail[U >: Out](n: Int): Repr[(immutable.Seq[Out], Source[U, Unit])] =
deprecatedAndThen(PrefixAndTail(n))
via(new PrefixAndTail[Out](n))
/**
* This operation demultiplexes the incoming stream into separate output