+str: Support for arbitrary, fusable graph processing stages (octopus)

This commit is contained in:
Endre Sándor Varga 2015-08-19 15:22:02 +02:00
parent 530bd13d2a
commit 19c20dfb20
7 changed files with 2022 additions and 1 deletions

View file

@ -0,0 +1,236 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
import java.util.concurrent.TimeoutException
import akka.actor.{ ActorSystem, Cancellable, Scheduler }
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.testkit.{ TestPublisher, AkkaSpec, TestSubscriber }
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.stream.testkit.Utils._
class ActorGraphInterpreterSpec extends AkkaSpec {
implicit val mat = ActorMaterializer()
"ActorGraphInterpreter" must {
"be able to interpret a simple identity graph stage" in assertAllStagesStopped {
val identity = new GraphStages.Identity[Int]
Await.result(
Source(1 to 100).via(identity).grouped(200).runWith(Sink.head),
3.seconds) should ===(1 to 100)
}
"be able to interpret a simple bidi stage" in assertAllStagesStopped {
val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
val in1 = Inlet[Int]("in1")
val in2 = Inlet[Int]("in2")
val out1 = Outlet[Int]("out1")
val out2 = Outlet[Int]("out2")
val shape = BidiShape(in1, out1, in2, out2)
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in1, new InHandler {
override def onPush(): Unit = push(out1, grab(in1))
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = push(out2, grab(in2))
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
}
override def toString = "IdentityBidi"
}
val identity = BidiFlow.wrap(identityBidi).join(Flow[Int].map { x x })
Await.result(
Source(1 to 10).via(identity).grouped(100).runWith(Sink.head),
3.seconds) should ===(1 to 10)
}
"be able to interpret a rotated identity bidi stage" in assertAllStagesStopped {
// This is a "rotated" identity BidiStage, as it loops back upstream elements
// to its upstream, and loops back downstream elementd to its downstream.
val rotatedBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
val in1 = Inlet[Int]("in1")
val in2 = Inlet[Int]("in2")
val out1 = Outlet[Int]("out1")
val out2 = Outlet[Int]("out2")
val shape = BidiShape(in1, out1, in2, out2)
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in1, new InHandler {
override def onPush(): Unit = push(out2, grab(in1))
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = push(out1, grab(in2))
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
}
override def toString = "IdentityBidi"
}
val takeAll = Flow[Int].grouped(200).toMat(Sink.head)(Keep.right)
val (f1, f2) = FlowGraph.closed(takeAll, takeAll)(Keep.both) { implicit b
(out1, out2)
import FlowGraph.Implicits._
val bidi = b.add(rotatedBidi)
Source(1 to 10) ~> bidi.in1
out2 <~ bidi.out2
bidi.in2 <~ Source(1 to 100)
bidi.out1 ~> out1
}.run()
Await.result(f1, 3.seconds) should ===(1 to 100)
Await.result(f2, 3.seconds) should ===(1 to 10)
}
"be able to implement a timeout bidiStage" in {
class IdleTimeout[I, O](
val system: ActorSystem,
val timeout: FiniteDuration) extends GraphStage[BidiShape[I, I, O, O]] {
val in1 = Inlet[I]("in1")
val in2 = Inlet[O]("in2")
val out1 = Outlet[I]("out1")
val out2 = Outlet[O]("out2")
val shape = BidiShape(in1, out1, in2, out2)
override def toString = "IdleTimeout"
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var timerCancellable: Option[Cancellable] = None
private var nextDeadline: Deadline = Deadline.now + timeout
setHandler(in1, new InHandler {
override def onPush(): Unit = {
onActivity()
push(out1, grab(in1))
}
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = {
onActivity()
push(out2, grab(in2))
}
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
private def onActivity(): Unit = nextDeadline = Deadline.now + timeout
private def onTimerTick(): Unit =
if (nextDeadline.isOverdue())
failStage(new TimeoutException(s"No reads or writes happened in $timeout."))
override def preStart(): Unit = {
super.preStart()
val checkPeriod = timeout / 8
val callback = getAsyncCallback[Unit]((_) onTimerTick())
import system.dispatcher
timerCancellable = Some(system.scheduler.schedule(timeout, checkPeriod)(callback.invoke(())))
}
override def postStop(): Unit = {
super.postStop()
timerCancellable.foreach(_.cancel())
}
}
}
val upWrite = TestPublisher.probe[String]()
val upRead = TestSubscriber.probe[Int]()
val downWrite = TestPublisher.probe[Int]()
val downRead = TestSubscriber.probe[String]()
FlowGraph.closed() { implicit b
import FlowGraph.Implicits._
val timeoutStage = b.add(new IdleTimeout[String, Int](system, 2.seconds))
Source(upWrite) ~> timeoutStage.in1; timeoutStage.out1 ~> Sink(downRead)
Sink(upRead) <~ timeoutStage.out2; timeoutStage.in2 <~ Source(downWrite)
}.run()
// Request enough for the whole test
upRead.request(100)
downRead.request(100)
upWrite.sendNext("DATA1")
downRead.expectNext("DATA1")
Thread.sleep(1500)
downWrite.sendNext(1)
upRead.expectNext(1)
Thread.sleep(1500)
upWrite.sendNext("DATA2")
downRead.expectNext("DATA2")
Thread.sleep(1000)
downWrite.sendNext(2)
upRead.expectNext(2)
upRead.expectNoMsg(500.millis)
val error1 = upRead.expectError()
val error2 = downRead.expectError()
error1.isInstanceOf[TimeoutException] should be(true)
error1.getMessage should be("No reads or writes happened in 2 seconds.")
error2 should ===(error1)
upWrite.expectCancellation()
downWrite.expectCancellation()
}
}
}

View file

@ -0,0 +1,457 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
import akka.stream._
import akka.stream.impl.fusing.GraphInterpreterSpec.TestSetup
import akka.stream.stage.{ InHandler, OutHandler, GraphStage, GraphStageLogic }
import akka.stream.testkit.AkkaSpec
import GraphInterpreter._
import scala.collection.immutable
class GraphInterpreterSpec extends AkkaSpec {
import GraphInterpreterSpec._
import GraphStages._
"GraphInterpreter" must {
// Reusable components
val identity = new Identity[Int]
val detacher = new Detacher[Int]
val zip = new Zip[Int, String]
val bcast = new Broadcast[Int](2)
val merge = new Merge[Int](2)
val balance = new Balance[Int](2)
"implement identity" in new TestSetup {
val source = UpstreamProbe[Int]("source")
val sink = DownstreamProbe[Int]("sink")
builder(identity)
.connect(source, identity.in)
.connect(identity.out, sink)
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1)))
}
"implement chained identity" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink = new DownstreamProbe[Int]("sink")
// Constructing an assembly by hand and resolving ambiguities
val assembly = GraphAssembly(
stages = Array(identity, identity),
ins = Array(identity.in, identity.in, null),
inOwners = Array(0, 1, -1),
outs = Array(null, identity.out, identity.out),
outOwners = Array(-1, 0, 1))
manualInit(assembly)
interpreter.attachDownstreamBoundary(2, sink)
interpreter.attachUpstreamBoundary(0, source)
interpreter.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1)))
}
"implement detacher stage" in new TestSetup {
val source = UpstreamProbe[Int]("source")
val sink = DownstreamProbe[Int]("sink")
builder(detacher)
.connect(source, detacher.in)
.connect(detacher.out, sink)
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source)))
// Source waits
source.onNext(2)
lastEvents() should ===(Set.empty)
// "pushAndPull"
sink.requestOne()
lastEvents() should ===(Set(OnNext(sink, 2), RequestOne(source)))
// Sink waits
sink.requestOne()
lastEvents() should ===(Set.empty)
// "pushAndPull"
source.onNext(3)
lastEvents() should ===(Set(OnNext(sink, 3), RequestOne(source)))
}
"implement Zip" in new TestSetup {
val source1 = new UpstreamProbe[Int]("source1")
val source2 = new UpstreamProbe[String]("source2")
val sink = new DownstreamProbe[(Int, String)]("sink")
builder(zip)
.connect(source1, zip.in0)
.connect(source2, zip.in1)
.connect(zip.out, sink)
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))
source1.onNext(42)
lastEvents() should ===(Set.empty)
source2.onNext("Meaning of life")
lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life"))))
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))
}
"implement Broadcast" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink1 = new DownstreamProbe[Int]("sink1")
val sink2 = new DownstreamProbe[Int]("sink2")
builder(bcast)
.connect(source, bcast.in)
.connect(bcast.out(0), sink1)
.connect(bcast.out(1), sink2)
.init()
lastEvents() should ===(Set.empty)
sink1.requestOne()
lastEvents() should ===(Set.empty)
sink2.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink1, 1), OnNext(sink2, 1)))
}
"implement broadcast-zip" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink = new DownstreamProbe[(Int, Int)]("sink")
val zip = new Zip[Int, Int]
builder(zip, bcast)
.connect(source, bcast.in)
.connect(bcast.out(0), zip.in0)
.connect(bcast.out(1), zip.in1)
.connect(zip.out, sink)
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, (1, 1))))
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(2)
lastEvents() should ===(Set(OnNext(sink, (2, 2))))
}
"implement zip-broadcast" in new TestSetup {
val source1 = new UpstreamProbe[Int]("source1")
val source2 = new UpstreamProbe[Int]("source2")
val sink1 = new DownstreamProbe[(Int, Int)]("sink")
val sink2 = new DownstreamProbe[(Int, Int)]("sink2")
val zip = new Zip[Int, Int]
val bcast = new Broadcast[(Int, Int)](2)
builder(bcast, zip)
.connect(source1, zip.in0)
.connect(source2, zip.in1)
.connect(zip.out, bcast.in)
.connect(bcast.out(0), sink1)
.connect(bcast.out(1), sink2)
.init()
lastEvents() should ===(Set.empty)
sink1.requestOne()
lastEvents() should ===(Set.empty)
sink2.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))
source1.onNext(1)
lastEvents() should ===(Set.empty)
source2.onNext(2)
lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2))))
}
"implement merge" in new TestSetup {
val source1 = new UpstreamProbe[Int]("source1")
val source2 = new UpstreamProbe[Int]("source2")
val sink = new DownstreamProbe[Int]("sink")
builder(merge)
.connect(source1, merge.in(0))
.connect(source2, merge.in(1))
.connect(merge.out, sink)
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))
source1.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source1)))
source2.onNext(2)
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(OnNext(sink, 2), RequestOne(source2)))
sink.requestOne()
lastEvents() should ===(Set.empty)
source2.onNext(3)
lastEvents() should ===(Set(OnNext(sink, 3), RequestOne(source2)))
sink.requestOne()
lastEvents() should ===(Set.empty)
source1.onNext(4)
lastEvents() should ===(Set(OnNext(sink, 4), RequestOne(source1)))
}
"implement balance" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink1 = new DownstreamProbe[Int]("sink1")
val sink2 = new DownstreamProbe[Int]("sink2")
builder(balance)
.connect(source, balance.in)
.connect(balance.out(0), sink1)
.connect(balance.out(1), sink2)
.init()
lastEvents() should ===(Set.empty)
sink1.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
sink2.requestOne()
lastEvents() should ===(Set.empty)
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink1, 1), RequestOne(source)))
source.onNext(2)
lastEvents() should ===(Set(OnNext(sink2, 2)))
}
"implement bidi-stage" in pending
"implement non-divergent cycle" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink = new DownstreamProbe[Int]("sink")
builder(merge, balance)
.connect(source, merge.in(0))
.connect(merge.out, balance.in)
.connect(balance.out(0), sink)
.connect(balance.out(1), merge.in(1))
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source)))
// Token enters merge-balance cycle and gets stuck
source.onNext(2)
lastEvents() should ===(Set(RequestOne(source)))
// Unstuck it
sink.requestOne()
lastEvents() should ===(Set(OnNext(sink, 2)))
}
"implement divergent cycle" in new TestSetup {
val source = new UpstreamProbe[Int]("source")
val sink = new DownstreamProbe[Int]("sink")
builder(detacher, balance, merge)
.connect(source, merge.in(0))
.connect(merge.out, balance.in)
.connect(balance.out(0), sink)
.connect(balance.out(1), detacher.in)
.connect(detacher.out, merge.in(1))
.init()
lastEvents() should ===(Set.empty)
sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))
source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source)))
// Token enters merge-balance cycle and spins until event limit
// Without the limit this would spin forever (where forever = Int.MaxValue iterations)
source.onNext(2, eventLimit = 1000)
lastEvents() should ===(Set(RequestOne(source)))
// The cycle is still alive and kicking, just suspended due to the event limit
interpreter.isSuspended should be(true)
// Do to the fairness properties of both the interpreter event queue and the balance stage
// the element will eventually leave the cycle and reaches the sink.
// This should not hang even though we do not have an event limit set
sink.requestOne()
lastEvents() should ===(Set(OnNext(sink, 2)))
// The cycle is now empty
interpreter.isSuspended should be(false)
}
}
}
object GraphInterpreterSpec {
sealed trait TestEvent {
def source: GraphStageLogic
}
case class OnComplete(source: GraphStageLogic) extends TestEvent
case class Cancel(source: GraphStageLogic) extends TestEvent
case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent
case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent
case class RequestOne(source: GraphStageLogic) extends TestEvent
case class RequestAnother(source: GraphStageLogic) extends TestEvent
abstract class TestSetup {
private var lastEvent: Set[TestEvent] = Set.empty
private var _interpreter: GraphInterpreter = _
protected def interpreter: GraphInterpreter = _interpreter
class AssemblyBuilder(stages: Seq[GraphStage[_ <: Shape]]) {
var upstreams = Vector.empty[(UpstreamBoundaryStageLogic[_], Inlet[_])]
var downstreams = Vector.empty[(Outlet[_], DownstreamBoundaryStageLogic[_])]
var connections = Vector.empty[(Outlet[_], Inlet[_])]
def connect[T](upstream: UpstreamBoundaryStageLogic[T], in: Inlet[T]): AssemblyBuilder = {
upstreams :+= upstream -> in
this
}
def connect[T](out: Outlet[T], downstream: DownstreamBoundaryStageLogic[T]): AssemblyBuilder = {
downstreams :+= out -> downstream
this
}
def connect[T](out: Outlet[T], in: Inlet[T]): AssemblyBuilder = {
connections :+= out -> in
this
}
def init(): Unit = {
val ins = upstreams.map(_._2) ++ connections.map(_._2)
val outs = connections.map(_._1) ++ downstreams.map(_._1)
val inOwners = ins.map { in stages.indexWhere(_.shape.inlets.contains(in)) }
val outOwners = outs.map { out stages.indexWhere(_.shape.outlets.contains(out)) }
val assembly = GraphAssembly(
stages.toArray,
(ins ++ Vector.fill(downstreams.size)(null)).toArray,
(inOwners ++ Vector.fill(downstreams.size)(-1)).toArray,
(Vector.fill(upstreams.size)(null) ++ outs).toArray,
(Vector.fill(upstreams.size)(-1) ++ outOwners).toArray)
_interpreter = new GraphInterpreter(assembly, (_, _, _) ())
for ((upstream, i) upstreams.zipWithIndex) {
_interpreter.attachUpstreamBoundary(i, upstream._1)
}
for ((downstream, i) downstreams.zipWithIndex) {
_interpreter.attachDownstreamBoundary(i + upstreams.size + connections.size, downstream._2)
}
_interpreter.init()
}
}
def manualInit(assembly: GraphAssembly): Unit = _interpreter = new GraphInterpreter(assembly, (_, _, _) ())
def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages.toSeq)
def lastEvents(): Set[TestEvent] = {
val result = lastEvent
lastEvent = Set.empty
result
}
case class UpstreamProbe[T](override val toString: String) extends UpstreamBoundaryStageLogic[T] {
val out = Outlet[T]("out")
setHandler(out, new OutHandler {
override def onPull(): Unit = lastEvent += RequestOne(UpstreamProbe.this)
})
def onNext(elem: T, eventLimit: Int = Int.MaxValue): Unit = {
if (GraphInterpreter.Debug) println(s"----- NEXT: $this $elem")
push(out, elem)
interpreter.execute(eventLimit)
}
}
case class DownstreamProbe[T](override val toString: String) extends DownstreamBoundaryStageLogic[T] {
val in = Inlet[T]("in")
setHandler(in, new InHandler {
override def onPush(): Unit = lastEvent += OnNext(DownstreamProbe.this, grab(in))
})
def requestOne(eventLimit: Int = Int.MaxValue): Unit = {
if (GraphInterpreter.Debug) println(s"----- REQ $this")
pull(in)
interpreter.execute(eventLimit)
}
}
}
}

View file

@ -13,7 +13,7 @@ import akka.stream.impl.GenJunctions.ZipWithModule
import akka.stream.impl.GenJunctions.UnzipWithModule import akka.stream.impl.GenJunctions.UnzipWithModule
import akka.stream.impl.Junctions._ import akka.stream.impl.Junctions._
import akka.stream.impl.StreamLayout.Module import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.fusing.ActorInterpreter import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule, ActorInterpreter }
import akka.stream.impl.io.SslTlsCipherActor import akka.stream.impl.io.SslTlsCipherActor
import akka.stream._ import akka.stream._
import akka.stream.io.SslTls.TlsModule import akka.stream.io.SslTls.TlsModule
@ -112,6 +112,20 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem,
assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn)) assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn))
assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn)) assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn))
case graph: GraphModule
val calculatedSettings = effectiveSettings(effectiveAttributes)
val props = ActorGraphInterpreter.props(graph.assembly, graph.shape, calculatedSettings)
val impl = actorOf(props, stageName(effectiveAttributes), calculatedSettings.dispatcher)
for ((inlet, i) graph.shape.inlets.iterator.zipWithIndex) {
val subscriber = new ActorGraphInterpreter.BoundarySubscriber(impl, i)
assignPort(inlet, subscriber)
}
for ((outlet, i) graph.shape.outlets.iterator.zipWithIndex) {
val publisher = new ActorPublisher[Any](impl) { override val wakeUpMsg = ActorGraphInterpreter.SubscribePending(i) }
impl ! ActorGraphInterpreter.ExposedPublisher(i, publisher)
assignPort(outlet, publisher)
}
case junction: JunctionModule case junction: JunctionModule
materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes)) materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes))
} }

View file

@ -0,0 +1,390 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
import akka.actor._
import akka.event.Logging
import akka.stream._
import akka.stream.impl.ReactiveStreamsCompliance._
import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic, GraphAssembly }
import akka.stream.impl.{ ActorPublisher, ReactiveStreamsCompliance }
import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler }
import org.reactivestreams.{ Subscriber, Subscription }
import scala.util.control.NonFatal
/**
* INTERNAL API
*/
private[stream] case class GraphModule(assembly: GraphAssembly, shape: Shape, attributes: Attributes) extends Module {
override def subModules: Set[Module] = Set.empty
override def withAttributes(newAttr: Attributes): Module = copy(attributes = newAttr)
override def carbonCopy: Module = copy()
override def replaceShape(s: Shape): Module = ???
}
/**
* INTERNAL API
*/
private[stream] object ActorGraphInterpreter {
trait BoundaryEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded
final case class OnError(id: Int, cause: Throwable) extends BoundaryEvent
final case class OnComplete(id: Int) extends BoundaryEvent
final case class OnNext(id: Int, e: Any) extends BoundaryEvent
final case class OnSubscribe(id: Int, subscription: Subscription) extends BoundaryEvent
final case class RequestMore(id: Int, demand: Long) extends BoundaryEvent
final case class Cancel(id: Int) extends BoundaryEvent
final case class SubscribePending(id: Int) extends BoundaryEvent
final case class ExposedPublisher(id: Int, publisher: ActorPublisher[Any]) extends BoundaryEvent
final case class AsyncInput(logic: GraphStageLogic, evt: Any, handler: (Any) Unit) extends BoundaryEvent
case object Resume extends BoundaryEvent
final class BoundarySubscription(val parent: ActorRef, val id: Int) extends Subscription {
override def request(elements: Long): Unit = parent ! RequestMore(id, elements)
override def cancel(): Unit = parent ! Cancel(id)
override def toString = "BoundarySubscription" + System.identityHashCode(this)
}
final class BoundarySubscriber(val parent: ActorRef, id: Int) extends Subscriber[Any] {
override def onError(cause: Throwable): Unit = {
ReactiveStreamsCompliance.requireNonNullException(cause)
parent ! OnError(id, cause)
}
override def onComplete(): Unit = parent ! OnComplete(id)
override def onNext(element: Any): Unit = {
ReactiveStreamsCompliance.requireNonNullElement(element)
parent ! OnNext(id, element)
}
override def onSubscribe(subscription: Subscription): Unit = {
ReactiveStreamsCompliance.requireNonNullSubscription(subscription)
parent ! OnSubscribe(id, subscription)
}
}
def props(assembly: GraphAssembly, shape: Shape, settings: ActorMaterializerSettings): Props =
Props(new ActorGraphInterpreter(assembly, shape, settings)).withDeploy(Deploy.local)
class BatchingActorInputBoundary(size: Int, id: Int) extends UpstreamBoundaryStageLogic[Any] {
require(size > 0, "buffer size cannot be zero")
require((size & (size - 1)) == 0, "buffer size must be a power of two")
private var upstream: Subscription = _
private val inputBuffer = Array.ofDim[AnyRef](size)
private var inputBufferElements = 0
private var nextInputElementCursor = 0
private var upstreamCompleted = false
private var downstreamCanceled = false
private val IndexMask = size - 1
private def requestBatchSize = math.max(1, inputBuffer.length / 2)
private var batchRemaining = requestBatchSize
val out: Outlet[Any] = Outlet[Any]("UpstreamBoundary" + id)
private def dequeue(): Any = {
val elem = inputBuffer(nextInputElementCursor)
require(elem ne null, "Internal queue must never contain a null")
inputBuffer(nextInputElementCursor) = null
batchRemaining -= 1
if (batchRemaining == 0 && !upstreamCompleted) {
tryRequest(upstream, requestBatchSize)
batchRemaining = requestBatchSize
}
inputBufferElements -= 1
nextInputElementCursor = (nextInputElementCursor + 1) & IndexMask
elem
}
private def clear(): Unit = {
java.util.Arrays.fill(inputBuffer, 0, inputBuffer.length, null)
inputBufferElements = 0
}
def cancel(): Unit = {
if (!upstreamCompleted) {
upstreamCompleted = true
if (upstream ne null) tryCancel(upstream)
clear()
}
}
def onNext(elem: Any): Unit = {
if (!upstreamCompleted) {
if (inputBufferElements == size) throw new IllegalStateException("Input buffer overrun")
inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef]
inputBufferElements += 1
if (isAvailable(out)) push(out, dequeue())
}
}
def onError(e: Throwable): Unit =
if (!upstreamCompleted) {
upstreamCompleted = true
clear()
fail(out, e)
}
// Call this when an error happens that does not come from the usual onError channel
// (exceptions while calling RS interfaces, abrupt termination etc)
def onInternalError(e: Throwable): Unit = {
if (!(upstreamCompleted || downstreamCanceled) && (upstream ne null)) {
upstream.cancel()
}
onError(e)
}
def onComplete(): Unit =
if (!upstreamCompleted) {
upstreamCompleted = true
if (inputBufferElements == 0) complete(out)
}
def onSubscribe(subscription: Subscription): Unit = {
require(subscription != null, "Subscription cannot be null")
if (upstreamCompleted)
tryCancel(subscription)
else if (downstreamCanceled) {
upstreamCompleted = true
tryCancel(subscription)
} else {
upstream = subscription
// Prefetch
tryRequest(upstream, inputBuffer.length)
}
}
setHandler(out, new OutHandler {
override def onPull(): Unit = {
if (inputBufferElements > 1) push(out, dequeue())
else if (inputBufferElements == 1) {
if (upstreamCompleted) {
push(out, dequeue())
complete(out)
} else push(out, dequeue())
} else if (upstreamCompleted) {
complete(out)
}
}
override def onDownstreamFinish(): Unit = cancel()
})
}
class ActorOutputBoundary(actor: ActorRef, id: Int) extends DownstreamBoundaryStageLogic[Any] {
val in: Inlet[Any] = Inlet[Any]("UpstreamBoundary" + id)
private var exposedPublisher: ActorPublisher[Any] = _
private var subscriber: Subscriber[Any] = _
private var downstreamDemand: Long = 0L
// This flag is only used if complete/fail is called externally since this op turns into a Finished one inside the
// interpreter (i.e. inside this op this flag has no effects since if it is completed the op will not be invoked)
private var downstreamCompleted = false
// this is true while we hold the ball; while false incoming demand will just be queued up
private var upstreamWaiting = true
// when upstream failed before we got the exposed publisher
private var upstreamFailed: Option[Throwable] = None
private def onNext(elem: Any): Unit = {
downstreamDemand -= 1
tryOnNext(subscriber, elem)
}
private def complete(): Unit = {
if (!downstreamCompleted) {
downstreamCompleted = true
if (exposedPublisher ne null) exposedPublisher.shutdown(None)
if (subscriber ne null) tryOnComplete(subscriber)
}
}
def fail(e: Throwable): Unit = {
if (!downstreamCompleted) {
downstreamCompleted = true
if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e))
if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e)
} else if (exposedPublisher == null && upstreamFailed.isEmpty) {
// fail called before the exposed publisher arrived, we must store it and fail when we're first able to
upstreamFailed = Some(e)
}
}
setHandler(in, new InHandler {
override def onPush(): Unit = {
onNext(grab(in))
if (downstreamCompleted) cancel(in)
else if (downstreamDemand > 0) pull(in)
}
override def onUpstreamFinish(): Unit = complete()
override def onUpstreamFailure(cause: Throwable): Unit = fail(cause)
})
def subscribePending(): Unit =
exposedPublisher.takePendingSubscribers() foreach { sub
if (subscriber eq null) {
subscriber = sub
tryOnSubscribe(subscriber, new BoundarySubscription(actor, id))
} else
rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}")
}
def exposedPublisher(publisher: ActorPublisher[Any]): Unit = {
upstreamFailed match {
case _: Some[_]
publisher.shutdown(upstreamFailed)
case _
exposedPublisher = publisher
}
}
def requestMore(elements: Long): Unit = {
if (elements < 1) {
cancel(in)
fail(ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException)
} else {
downstreamDemand += elements
if (downstreamDemand < 0)
downstreamDemand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded
if (!hasBeenPulled(in)) pull(in)
}
}
def cancel(): Unit = {
downstreamCompleted = true
subscriber = null
exposedPublisher.shutdown(Some(new ActorPublisher.NormalShutdownException))
cancel(in)
}
}
}
/**
* INTERNAL API
*/
private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shape, settings: ActorMaterializerSettings) extends Actor {
import ActorGraphInterpreter._
val interpreter = new GraphInterpreter(assembly, (logic, event, handler) self ! AsyncInput(logic, event, handler))
val inputs = Array.tabulate(shape.inlets.size)(new BatchingActorInputBoundary(settings.maxInputBufferSize, _))
val outputs = Array.tabulate(shape.outlets.size)(new ActorOutputBoundary(self, _))
// Limits the number of events processed by the interpreter before scheduling a self-message for fairness with other
// actors.
// TODO: Better heuristic here
val eventLimit = settings.maxInputBufferSize * assembly.stages.length * 4 // Roughly 4 events per element transfer
// Limits the number of events processed by the interpreter on an abort event.
// TODO: Better heuristic here
val abortLimit = eventLimit * 2
var resumeScheduled = false
override def preStart(): Unit = {
var i = 0
while (i < inputs.length) {
interpreter.attachUpstreamBoundary(i, inputs(i))
i += 1
}
val offset = assembly.connectionCount - outputs.length
i = 0
while (i < outputs.length) {
interpreter.attachDownstreamBoundary(i + offset, outputs(i))
i += 1
}
interpreter.init()
}
override def receive: Receive = {
case OnError(id: Int, cause: Throwable)
if (GraphInterpreter.Debug) println(s" onError id=$id")
inputs(id).onError(cause)
runBatch()
case OnComplete(id: Int)
if (GraphInterpreter.Debug) println(s" onComplete id=$id")
inputs(id).onComplete()
runBatch()
case OnNext(id: Int, e: Any)
if (GraphInterpreter.Debug) println(s" onNext $e id=$id")
inputs(id).onNext(e)
runBatch()
case OnSubscribe(id: Int, subscription: Subscription)
inputs(id).onSubscribe(subscription)
case RequestMore(id: Int, demand: Long)
if (GraphInterpreter.Debug) println(s" request $demand id=$id")
outputs(id).requestMore(demand)
runBatch()
case Cancel(id: Int)
if (GraphInterpreter.Debug) println(s" cancel id=$id")
outputs(id).cancel()
runBatch()
case SubscribePending(id: Int)
outputs(id).subscribePending()
case ExposedPublisher(id, publisher)
outputs(id).exposedPublisher(publisher)
case AsyncInput(_, event, handler)
if (GraphInterpreter.Debug) println(s"ASYNC $event")
handler(event)
runBatch()
case Resume
resumeScheduled = false
if (interpreter.isSuspended) runBatch()
}
override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = {
super.aroundReceive(receive, msg)
}
private def runBatch(): Unit = {
try {
interpreter.execute(eventLimit)
if (interpreter.isCompleted) context.stop(self)
else if (interpreter.isSuspended && !resumeScheduled) {
resumeScheduled = true
self ! Resume
}
} catch {
case NonFatal(e)
context.stop(self)
tryAbort(e)
}
}
/**
* Attempts to abort execution, by first propagating the reason given until either
* - the interpreter successfully finishes
* - the event limit is reached
* - a new error is encountered
*/
private def tryAbort(ex: Throwable): Unit = {
// This should handle termination while interpreter is running. If the upstream have been closed already this
// call has no effect and therefore do the right thing: nothing.
try {
inputs.foreach(_.onInternalError(ex))
interpreter.execute(abortLimit)
interpreter.finish()
} // Will only have an effect if the above call to the interpreter failed to emit a proper failure to the downstream
// otherwise this will have no effect
finally {
outputs.foreach(_.fail(ex))
inputs.foreach(_.cancel())
}
}
override def postStop(): Unit = tryAbort(AbruptTerminationException(self))
}

View file

@ -0,0 +1,434 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
import akka.stream.stage.{ OutHandler, InHandler, GraphStage, GraphStageLogic }
import akka.stream.{ Shape, Inlet, Outlet }
/**
* INTERNAL API
*
* (See the class for the documentation of the internals)
*/
private[stream] object GraphInterpreter {
/**
* Compile time constant, enable it for debug logging to the console.
*/
final val Debug = false
/**
* Marker object that indicates that a port holds no element since it was already grabbed. The port is still pullable,
* but there is no more element to grab.
*/
case object Empty
sealed trait ConnectionState
sealed trait CompletedState extends ConnectionState
case object Pushable extends ConnectionState
case object Completed extends CompletedState
final case class PushCompleted(element: Any) extends ConnectionState
case object Cancelled extends CompletedState
final case class Failed(ex: Throwable) extends CompletedState
val NoEvent = -1
val Boundary = -1
abstract class UpstreamBoundaryStageLogic[T] extends GraphStageLogic {
def out: Outlet[T]
}
abstract class DownstreamBoundaryStageLogic[T] extends GraphStageLogic {
def in: Inlet[T]
}
/**
* INTERNAL API
*
* A GraphAssembly represents a small stream processing graph to be executed by the interpreter. Instances of this
* class **must not** be mutated after construction.
*
* The arrays [[ins]] and [[outs]] correspond to the notion of a *connection* in the [[GraphInterpreter]]. Each slot
* *i* contains the input and output port corresponding to connection *i*. Slots where the graph is not closed (i.e.
* ports are exposed to the external world) are marked with *null* values. For example if an input port *p* is
* exposed, then outs(p) will contain a *null*.
*
* The arrays [[inOwners]] and [[outOwners]] are lookup tables from a connection id (the index of the slot)
* to a slot in the [[stages]] array, indicating which stage is the owner of the given input or output port.
* Slots which would correspond to non-existent stages (where the corresponding port is null since it represents
* the currently unknown external context) contain the value [[GraphInterpreter#Boundary]].
*
* The current assumption by the infrastructure is that the layout of these arrays looks like this:
*
* +---------------------------------------+-----------------+
* inOwners: | index to stages array | Boundary (-1) |
* +----------------+----------------------+-----------------+
* ins: | exposed inputs | internal connections | nulls |
* +----------------+----------------------+-----------------+
* outs: | nulls | internal connections | exposed outputs |
* +----------------+----------------------+-----------------+
* outOwners: | Boundary (-1) | index to stages array |
* +----------------+----------------------------------------+
*
* In addition, it is also assumed by the infrastructure that the order of exposed inputs and outputs in the
* corresponding segments of these arrays matches the exact same order of the ports in the [[Shape]].
*
*/
final case class GraphAssembly(stages: Array[GraphStage[_]],
ins: Array[Inlet[_]],
inOwners: Array[Int],
outs: Array[Outlet[_]],
outOwners: Array[Int]) {
val connectionCount: Int = ins.length
/**
* Takes an interpreter and returns three arrays required by the interpreter containing the input, output port
* handlers and the stage logic instances.
*/
def materialize(interpreter: GraphInterpreter): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic]) = {
val logics = Array.ofDim[GraphStageLogic](stages.length)
for (i stages.indices) {
logics(i) = stages(i).createLogic
logics(i).interpreter = interpreter
}
val inHandlers = Array.ofDim[InHandler](connectionCount)
val outHandlers = Array.ofDim[OutHandler](connectionCount)
for (i 0 until connectionCount) {
if (ins(i) ne null) {
inHandlers(i) = logics(inOwners(i)).inHandlers(ins(i))
logics(inOwners(i)).inToConn += ins(i) -> i
}
if (outs(i) ne null) {
outHandlers(i) = logics(outOwners(i)).outHandlers(outs(i))
logics(outOwners(i)).outToConn += outs(i) -> i
}
}
(inHandlers, outHandlers, logics)
}
override def toString: String =
"GraphAssembly(" +
stages.mkString("[", ",", "]") + ", " +
ins.mkString("[", ",", "]") + ", " +
inOwners.mkString("[", ",", "]") + ", " +
outs.mkString("[", ",", "]") + ", " +
outOwners.mkString("[", ",", "]") +
")"
}
}
/**
* INERNAL API
*
* From an external viewpoint, the GraphInterpreter takes an assembly of graph processing stages encoded as a
* [[GraphInterpreter#GraphAssembly]] object and provides facilities to execute and interact with this assembly.
* The lifecylce of the Interpreter is roughly the following:
* - Boundary logics are attached via [[attachDownstreamBoundary()]] and [[attachUpstreamBoundary()]]
* - [[init()]] is called
* - [[execute()]] is called whenever there is need for execution, providing an upper limit on the processed events
* - [[finish()]] is called before the interpreter is disposed, preferably after [[isCompleted]] returned true, although
* in abort cases this is not strictly necessary
*
* The [[execute()]] method of the interpreter accepts an upper bound on the events it will process. After this limit
* is reached or there are no more pending events to be processed, the call returns. It is possible to inspect
* if there are unprocessed events left via the [[isSuspended]] method. [[isCompleted]] returns true once all stages
* reported completion inside the interpreter.
*
* The internal architecture of the interpreter is based on the usage of arrays and optimized for reducing allocations
* on the hot paths.
*
* One of the basic abstractions inside the interpreter is the notion of *connection*. In the abstract sense a
* connection represents an output-input port pair (an analogue for a connected RS Publisher-Subscriber pair),
* while in the practical sense a connection is a number which represents slots in certain arrays.
* In particular
* - connectionStates is a mapping from a connection id to a current (or future) state of the connection
* - inAvailable is a mapping from a connection to a boolean that indicates whether the input corresponding
* to the connection is currently pullable
* - outAvailable is a mapping from a connection to a boolean that indicates whether the input corresponding
* to the connection is currently pushable
* - inHandlers is a mapping from a connection id to the [[InHandler]] instance that handles the events corresponding
* to the input port of the connection
* - outHandlers is a mapping from a connection id to the [[OutHandler]] instance that handles the events corresponding
* to the output port of the connection
*
* On top of these lookup tables there is an eventQueue, represented as a circular buffer of integers. The integers
* it contains represents connections that have pending events to be processed. The pending event itself is encoded
* in the connectionStates table. This implies that there can be only one event in flight for a given connection, which
* is true in almost all cases, except a complete-after-push which is therefore handled with a special event
* [[GraphInterpreter#PushCompleted]].
*
* Sending an event is usually the following sequence:
* - An action is requested by a stage logic (push, pull, complete, etc.)
* - the availability of the port is set on the sender side to false (inAvailable or outAvailable)
* - the scheduled event is put in the slot of the connection in the connectionStates table
* - the id of the affected connection is enqueued
*
* Receiving an event is usually the following sequence:
* - id of connection to be processed is dequeued
* - the type of the event is determined by the object in the corresponding connectionStates slot
* - the availability of the port is set on the receiver side to be true (inAvailable or outAvailable)
* - using the inHandlers/outHandlers table the corresponding callback is called on the stage logic.
*
* Because of the FIFO construction of the queue the interpreter is fair, i.e. a pending event is always executed
* after a bounded number of other events. This property, together with suspendability means that even infinite cycles can
* be modeled, or even dissolved (if preempted and a "stealing" external even is injected; for example the non-cycle
* edge of a balance is pulled, dissolving the original cycle).
*/
private[stream] final class GraphInterpreter(
private val assembly: GraphInterpreter.GraphAssembly,
val onAsyncInput: (GraphStageLogic, Any, (Any) Unit) Unit) {
import GraphInterpreter._
// Maintains the next event (and state) of the connection.
// Technically the connection cannot be considered being in the state that is encoded here before the enqueued
// connection event has been processed. The inAvailable and outAvailable arrays usually protect access to this
// field while it is in transient state.
val connectionStates = Array.fill[Any](assembly.connectionCount)(Empty)
// Indicates whether the input port is pullable. After pulling it becomes false
// Be aware that when inAvailable goes to false outAvailable does not become true immediately, only after
// the corresponding event in the queue has been processed
val inAvailable = Array.fill[Boolean](assembly.connectionCount)(true)
// Indicates whether the output port is pushable. After pushing it becomes false
// Be aware that when inAvailable goes to false outAvailable does not become true immediately, only after
// the corresponding event in the queue has been processed
val outAvailable = Array.fill[Boolean](assembly.connectionCount)(false)
// Lookup tables for the InHandler and OutHandler for a given connection ID, and a lookup table for the
// GraphStageLogic instances
val (inHandlers, outHandlers, logics) = assembly.materialize(this)
// The number of currently running stages. Once this counter reaches zero, the interpreter is considered to be
// completed
private var runningStages = assembly.stages.length
// Counts how many active connections a stage has. Once it reaches zero, the stage is automatically stopped.
private val shutdownCounter = Array.tabulate(assembly.stages.length) { i
val shape = assembly.stages(i).shape.asInstanceOf[Shape]
shape.inlets.size + shape.outlets.size
}
// An event queue implemented as a circular buffer
private val mask = 255
private val eventQueue = Array.ofDim[Int](256)
private var queueHead: Int = 0
private var queueTail: Int = 0
/**
* Assign the boundary logic to a given connection. This will serve as the interface to the external world
* (outside the interpreter) to process and inject events.
*/
def attachUpstreamBoundary(connection: Int, logic: UpstreamBoundaryStageLogic[_]): Unit = {
logic.outToConn += logic.out -> connection
logic.interpreter = this
outHandlers(connection) = logic.outHandlers.head._2
}
/**
* Assign the boundary logic to a given connection. This will serve as the interface to the external world
* (outside the interpreter) to process and inject events.
*/
def attachDownstreamBoundary(connection: Int, logic: DownstreamBoundaryStageLogic[_]): Unit = {
logic.inToConn += logic.in -> connection
logic.interpreter = this
inHandlers(connection) = logic.inHandlers.head._2
}
/**
* Returns true if there are pending unprocessed events in the event queue.
*/
def isSuspended: Boolean = queueHead != queueTail
/**
* Returns true if there are no more running stages and pending events.
*/
def isCompleted: Boolean = runningStages == 0 && !isSuspended
/**
* Initializes the states of all the stage logics by calling preStart()
*/
def init(): Unit = {
var i = 0
while (i < logics.length) {
logics(i).preStart()
i += 1
}
}
/**
* Finalizes the state of all stages by calling postStop() (if necessary).
*/
def finish(): Unit = {
var i = 0
while (i < logics.length) {
if (!isStageCompleted(i)) logics(i).postStop()
i += 1
}
}
// Debug name for a connections input part
private def inOwnerName(connection: Int): String =
if (assembly.inOwners(connection) == Boundary) "DownstreamBoundary"
else assembly.stages(assembly.inOwners(connection)).toString
// Debug name for a connections ouput part
private def outOwnerName(connection: Int): String =
if (assembly.outOwners(connection) == Boundary) "UpstreamBoundary"
else assembly.stages(assembly.outOwners(connection)).toString
/**
* Executes pending events until the given limit is met. If there were remaining events, isSuspended will return
* true.
*/
def execute(eventLimit: Int): Unit = {
var eventsRemaining = eventLimit
var connection = dequeue()
while (eventsRemaining > 0 && connection != NoEvent) {
processEvent(connection)
eventsRemaining -= 1
if (eventsRemaining > 0) connection = dequeue()
}
// TODO: deadlock detection
}
// Decodes and processes a single event for the given connection
private def processEvent(connection: Int): Unit = {
def processElement(elem: Any): Unit = {
if (!isStageCompleted(assembly.inOwners(connection))) {
if (GraphInterpreter.Debug) println(s"PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, $elem")
inAvailable(connection) = true
inHandlers(connection).onPush()
}
}
connectionStates(connection) match {
case Pushable
if (!isStageCompleted(assembly.outOwners(connection))) {
if (GraphInterpreter.Debug) println(s"PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)}")
outAvailable(connection) = true
outHandlers(connection).onPull()
}
case Completed
val stageId = assembly.inOwners(connection)
if (!isStageCompleted(stageId)) {
if (GraphInterpreter.Debug) println(s"COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)}")
inAvailable(connection) = false
inHandlers(connection).onUpstreamFinish()
completeConnection(stageId)
}
case Failed(ex)
val stageId = assembly.inOwners(connection)
if (!isStageCompleted(stageId)) {
if (GraphInterpreter.Debug) println(s"FAIL ${outOwnerName(connection)} -> ${inOwnerName(connection)}")
inAvailable(connection) = false
inHandlers(connection).onUpstreamFailure(ex)
completeConnection(stageId)
}
case Cancelled
val stageId = assembly.outOwners(connection)
if (!isStageCompleted(stageId)) {
if (GraphInterpreter.Debug) println(s"CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)}")
outAvailable(connection) = false
outHandlers(connection).onDownstreamFinish()
completeConnection(stageId)
}
case PushCompleted(elem)
inAvailable(connection) = true
connectionStates(connection) = elem
processElement(elem)
enqueue(connection, Completed)
case pushedElem processElement(pushedElem)
}
}
private def dequeue(): Int = {
if (queueHead == queueTail) NoEvent
else {
val idx = queueHead & mask
val elem = eventQueue(idx)
eventQueue(idx) = NoEvent
queueHead += 1
elem
}
}
private def enqueue(connection: Int, event: Any): Unit = {
connectionStates(connection) = event
eventQueue(queueTail & mask) = connection
queueTail += 1
}
// Returns true if a connection has been completed *or if the completion event is already enqueued*. This is useful
// to prevent redundant completion events in case of concurrent invocation on both sides of the connection.
// I.e. when one side already enqueued the completion event, then the other side will not enqueue the event since
// there is noone to process it anymore.
def isConnectionCompleted(connection: Int): Boolean = connectionStates(connection).isInstanceOf[CompletedState]
// Returns true if the given stage is alredy completed
private def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0
private def isPushInFlight(connection: Int): Boolean =
!inAvailable(connection) &&
!connectionStates(connection).isInstanceOf[ConnectionState] &&
connectionStates(connection) != Empty
// Register that a connection in which the given stage participated has been completed and therefore the stage
// itself might stop, too.
private def completeConnection(stageId: Int): Unit = {
if (stageId != Boundary) {
val activeConnections = shutdownCounter(stageId)
if (activeConnections > 0) {
shutdownCounter(stageId) = activeConnections - 1
// This was the last active connection keeping this stage alive
if (activeConnections == 1) {
runningStages -= 1
logics(stageId).postStop()
}
}
}
}
private[stream] def push(connection: Int, elem: Any): Unit = {
outAvailable(connection) = false
enqueue(connection, elem)
}
private[stream] def pull(connection: Int): Unit = {
inAvailable(connection) = false
enqueue(connection, Pushable)
}
private[stream] def complete(connection: Int): Unit = {
outAvailable(connection) = false
if (!isConnectionCompleted(connection)) {
// There is a pending push, we change the signal to be a PushCompleted (there can be only one signal in flight
// for a connection)
if (isPushInFlight(connection))
connectionStates(connection) = PushCompleted(connectionStates(connection))
else
enqueue(connection, Completed)
}
completeConnection(assembly.outOwners(connection))
}
private[stream] def fail(connection: Int, ex: Throwable): Unit = {
outAvailable(connection) = false
if (!isConnectionCompleted(connection)) enqueue(connection, Failed(ex))
completeConnection(assembly.outOwners(connection))
}
private[stream] def cancel(connection: Int): Unit = {
inAvailable(connection) = false
if (!isConnectionCompleted(connection)) enqueue(connection, Cancelled)
completeConnection(assembly.inOwners(connection))
}
}

View file

@ -0,0 +1,235 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
import akka.stream._
import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage }
/**
* INTERNAL API
*/
object GraphStages {
class Identity[T] extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("in")
val out = Outlet[T]("out")
val shape = FlowShape(in, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in, new InHandler {
override def onPush(): Unit = push(out, grab(in))
override def onUpstreamFinish(): Unit = completeStage()
override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex)
})
setHandler(out, new OutHandler {
override def onPull(): Unit = pull(in)
override def onDownstreamFinish(): Unit = completeStage()
})
}
override def toString = "Identity"
}
class Detacher[T] extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("in")
val out = Outlet[T]("out")
val shape = FlowShape(in, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
var initialized = false
setHandler(in, new InHandler {
override def onPush(): Unit = {
if (isAvailable(out)) {
push(out, grab(in))
pull(in)
}
}
})
setHandler(out, new OutHandler {
override def onPull(): Unit = {
if (!initialized) {
pull(in)
initialized = true
} else if (isAvailable(in)) {
push(out, grab(in))
if (!hasBeenPulled(in)) pull(in)
}
}
})
}
override def toString = "Detacher"
}
class Broadcast[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in = Inlet[T]("in")
val out = Vector.fill(outCount)(Outlet[T]("out"))
val shape = UniformFanOutShape(in, out: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var pending = outCount
setHandler(in, new InHandler {
override def onPush(): Unit = {
pending = outCount
val elem = grab(in)
out.foreach(push(_, elem))
}
})
val outHandler = new OutHandler {
override def onPull(): Unit = {
pending -= 1
if (pending == 0) pull(in)
}
}
out.foreach(setHandler(_, outHandler))
}
override def toString = "Broadcast"
}
class Zip[A, B] extends GraphStage[FanInShape2[A, B, (A, B)]] {
val in0 = Inlet[A]("in0")
val in1 = Inlet[B]("in1")
val out = Outlet[(A, B)]("out")
val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
var pending = 2
val inHandler = new InHandler {
override def onPush(): Unit = {
pending -= 1
if (pending == 0) push(out, (grab(in0), grab(in1)))
}
}
setHandler(in0, inHandler)
setHandler(in1, inHandler)
setHandler(out, new OutHandler {
override def onPull(): Unit = {
pending = 2
pull(in0)
pull(in1)
}
})
}
override def toString = "Zip"
}
class Merge[T](private val inCount: Int) extends GraphStage[UniformFanInShape[T, T]] {
val in = Vector.fill(inCount)(Inlet[T]("in"))
val out = Outlet[T]("out")
val shape = UniformFanInShape(out, in: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var initialized = false
private val pendingQueue = Array.ofDim[Inlet[T]](inCount)
private var pendingHead: Int = 0
private var pendingTail: Int = 0
private def noPending: Boolean = pendingHead == pendingTail
private def enqueue(in: Inlet[T]): Unit = {
pendingQueue(pendingTail % inCount) = in
pendingTail += 1
}
private def dequeueAndDispatch(): Unit = {
val in = pendingQueue(pendingHead % inCount)
pendingHead += 1
push(out, grab(in))
pull(in)
}
in.foreach { i
setHandler(i, new InHandler {
override def onPush(): Unit = {
if (isAvailable(out)) {
if (noPending) {
push(out, grab(i))
pull(i)
} else {
enqueue(i)
dequeueAndDispatch()
}
} else enqueue(i)
}
})
}
setHandler(out, new OutHandler {
override def onPull(): Unit = {
if (!initialized) {
initialized = true
in.foreach(pull(_))
} else {
if (!noPending) {
dequeueAndDispatch()
}
}
}
})
}
override def toString = "Merge"
}
class Balance[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in = Inlet[T]("in")
val out = Vector.fill(outCount)(Outlet[T]("out"))
val shape = UniformFanOutShape[T, T](in, out: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private val pendingQueue = Array.ofDim[Outlet[T]](outCount)
private var pendingHead: Int = 0
private var pendingTail: Int = 0
private def noPending: Boolean = pendingHead == pendingTail
private def enqueue(out: Outlet[T]): Unit = {
pendingQueue(pendingTail % outCount) = out
pendingTail += 1
}
private def dequeueAndDispatch(): Unit = {
val out = pendingQueue(pendingHead % outCount)
pendingHead += 1
push(out, grab(in))
if (!noPending) pull(in)
}
setHandler(in, new InHandler {
override def onPush(): Unit = dequeueAndDispatch()
})
out.foreach { o
setHandler(o, new OutHandler {
override def onPull(): Unit = {
if (isAvailable(in)) {
if (noPending) {
push(o, grab(in))
} else {
enqueue(o)
dequeueAndDispatch()
}
} else {
if (!hasBeenPulled(in)) pull(in)
enqueue(o)
}
}
})
}
}
override def toString = "Balance"
}
}

View file

@ -0,0 +1,255 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.stage
import akka.stream._
import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.fusing.{ GraphModule, GraphInterpreter }
import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly
/**
* A GraphStage represents a reusable graph stream processing stage. A GraphStage consists of a [[Shape]] which describes
* its input and output ports and a factory function that creates a [[GraphStageLogic]] which implements the processing
* logic that ties the ports together.
*/
abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
def shape: S
def createLogic: GraphStageLogic
final override private[stream] lazy val module: Module = {
val connectionCount = shape.inlets.size + shape.outlets.size
val assembly = GraphAssembly(
Array(this),
Array.ofDim(connectionCount),
Array.fill(connectionCount)(-1),
Array.ofDim(connectionCount),
Array.fill(connectionCount)(-1))
for ((inlet, i) shape.inlets.iterator.zipWithIndex) {
assembly.ins(i) = inlet
assembly.inOwners(i) = 0
}
for ((outlet, i) shape.outlets.iterator.zipWithIndex) {
assembly.outs(i + shape.inlets.size) = outlet
assembly.outOwners(i + shape.inlets.size) = 0
}
GraphModule(assembly, shape, Attributes.none)
}
/**
* This method throws an [[UnsupportedOperationException]] by default. The subclass can override this method
* and provide a correct implementation that creates an exact copy of the stage with the provided new attributes.
*/
override def withAttributes(attr: Attributes): Graph[S, Unit] =
throw new UnsupportedOperationException(
"withAttributes not supported by default by FlexiMerge, subclass may override and implement it")
}
/**
* Represents the processing logic behind a [[GraphStage]]. Roughly speaking, a subclass of [[GraphStageLogic]] is a
* collection of the following parts:
* * A set of [[InHandler]] and [[OutHandler]] instances and their assignments to the [[Inlet]]s and [[Outlet]]s
* of the enclosing [[GraphStage]]
* * Possible mutable state, accessible from the [[InHandler]] and [[OutHandler]] callbacks, but not from anywhere
* else (as such access would not be thread-safe)
* * The lifecycle hooks [[preStart()]] and [[postStop()]]
* * Methods for performing stream processing actions, like pulling or pushing elements
*
* The stage logic is always stopped once all its input and output ports have been closed, i.e. it is not possible to
* keep the stage alive for further processing once it does not have any open ports.
*/
abstract class GraphStageLogic {
import GraphInterpreter._
/**
* INTERNAL API
*/
private[stream] var inHandlers = scala.collection.Map.empty[Inlet[_], InHandler]
/**
* INTERNAL API
*/
private[stream] var outHandlers = scala.collection.Map.empty[Outlet[_], OutHandler]
/**
* INTERNAL API
*/
private[stream] var inToConn = scala.collection.Map.empty[Inlet[_], Int]
/**
* INTERNAL API
*/
private[stream] var outToConn = scala.collection.Map.empty[Outlet[_], Int]
/**
* INTERNAL API
*/
private[stream] var interpreter: GraphInterpreter = _
/**
* Assigns callbacks for the events for an [[Inlet]]
*/
final protected def setHandler(in: Inlet[_], handler: InHandler): Unit = inHandlers += in -> handler
/**
* Assigns callbacks for the events for an [[Outlet]]
*/
final protected def setHandler(out: Outlet[_], handler: OutHandler): Unit = outHandlers += out -> handler
private def conn[T](in: Inlet[T]): Int = inToConn(in)
private def conn[T](out: Outlet[T]): Int = outToConn(out)
/**
* Requests an element on the given port. Calling this method twice before an element arrived will fail.
* There can only be one outstanding request at any given time. The method [[hasBeenPulled()]] can be used
* query whether pull is allowed to be called or not.
*/
final def pull[T](in: Inlet[T]): Unit = {
require(!hasBeenPulled(in), "Cannot pull port twice")
interpreter.pull(conn(in))
}
/**
* Requests to stop receiving events from a given input port.
*/
final def cancel[T](in: Inlet[T]): Unit = interpreter.cancel(conn(in))
/**
* Once the callback [[InHandler.onPush()]] for an input port has been invoked, the element that has been pushed
* can be retrieved via this method. After [[grab()]] has been called the port is considered to be empty, and further
* calls to [[grab()]] will fail until the port is pulled again and a new element is pushed as a response.
*
* The method [[isAvailable()]] can be used to query if the port has an element that can be grabbed or not.
*/
final def grab[T](in: Inlet[T]): T = {
require(isAvailable(in), "Cannot get element from already empty input port")
val connection = conn(in)
val elem = interpreter.connectionStates(connection)
interpreter.connectionStates(connection) = Empty
elem.asInstanceOf[T]
}
/**
* Indicates whether there is already a pending pull for the given input port. If this method returns true
* then [[isAvailable()]] must return false for that same port.
*/
final def hasBeenPulled[T](in: Inlet[T]): Boolean = !interpreter.inAvailable(conn(in))
/**
* Indicates whether there is an element waiting at the given input port. [[grab()]] can be used to retrieve the
* element. After calling [[grab()]] this method will return false.
*
* If this method returns true then [[hasBeenPulled()]] will return false for that same port.
*/
final def isAvailable[T](in: Inlet[T]): Boolean = {
val connection = conn(in)
interpreter.inAvailable(connection) && !(interpreter.connectionStates(connection) == Empty)
}
/**
* Emits an element through the given output port. Calling this method twice before a [[pull()]] has been arrived
* will fail. There can be only one outstanding push request at any given time. The method [[isAvailable()]] can be
* used to check if the port is ready to be pushed or not.
*/
final def push[T](out: Outlet[T], elem: T): Unit = {
require(isAvailable(out), "Cannot push port twice")
interpreter.push(conn(out), elem)
}
/**
* Signals that there will be no more elements emitted on the given port.
*/
final def complete[T](out: Outlet[T]): Unit = interpreter.complete(conn(out))
/**
* Signals failure through the given port.
*/
final def fail[T](out: Outlet[T], ex: Throwable): Unit = interpreter.fail(conn(out), ex)
/**
* Automatically invokes [[cancel()]] or [[complete()]] on all the input or output ports that have been called,
* then stops the stage, then [[postStop()]] is called.
*/
final def completeStage(): Unit = {
inToConn.valuesIterator.foreach(interpreter.cancel)
outToConn.valuesIterator.foreach(interpreter.complete)
}
/**
* Automatically invokes [[cancel()]] or [[fail()]] on all the input or output ports that have been called,
* then stops the stage, then [[postStop()]] is called.
*/
final def failStage(ex: Throwable): Unit = {
inToConn.valuesIterator.foreach(interpreter.cancel)
outToConn.valuesIterator.foreach(interpreter.fail(_, ex))
}
/**
* Return true if the given output port is ready to be pushed.
*/
final def isAvailable[T](out: Outlet[T]): Boolean = interpreter.outAvailable(conn(out))
/**
* Obtain a callback object that can be used asynchronously to re-enter the
* current [[AsyncStage]] with an asynchronous notification. The [[invoke()]] method of the returned
* [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely
* delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and
* the passed handler will be invoked eventually in a thread-safe way by the execution environment.
*
* This object can be cached and reused within the same [[GraphStageLogic]].
*/
final def getAsyncCallback[T](handler: T Unit): AsyncCallback[T] = {
new AsyncCallback[T] {
override def invoke(event: T): Unit =
interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any Unit])
}
}
/**
* Invoked before any external events are processed, at the startup of the stage.
*/
def preStart(): Unit = ()
/**
* Invoked after processing of external events stopped because the stage is about to stop or fail.
*/
def postStop(): Unit = ()
}
/**
* Collection of callbacks for an input port of a [[GraphStage]]
*/
trait InHandler {
/**
* Called when the input port has a new element available. The actual element can be retrieved via the
* [[GraphStageLogic.grab()]] method.
*/
def onPush(): Unit
/**
* Called when the input port is finished. After this callback no other callbacks will be called for this port.
*/
def onUpstreamFinish(): Unit = ()
/**
* Called when the input port has failed. After this callback no other callbacks will be called for this port.
*/
def onUpstreamFailure(ex: Throwable): Unit = ()
}
/**
* Collection of callbacks for an output port of a [[GraphStage]]
*/
trait OutHandler {
/**
* Called when the output port has received a pull, and therefore ready to emit an element, i.e. [[GraphStageLogic.push()]]
* is now allowed to be called on this port.
*/
def onPull(): Unit
/**
* Called when the output port will no longer accept any new elements. After this callback no other callbacks will
* be called for this port.
*/
def onDownstreamFinish(): Unit = ()
}