make fused graphs fusable
This commit is contained in:
parent
3d20915cf4
commit
e4f31b66c3
7 changed files with 308 additions and 95 deletions
|
|
@ -7,32 +7,82 @@ import akka.stream._
|
|||
import akka.stream.scaladsl._
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import akka.stream.Attributes._
|
||||
import akka.stream.Fusing.FusedGraph
|
||||
import scala.annotation.tailrec
|
||||
import akka.stream.impl.StreamLayout.Module
|
||||
|
||||
class FusingSpec extends AkkaSpec with ConversionCheckedTripleEquals {
|
||||
|
||||
final val Debug = false
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
def graph(async: Boolean) =
|
||||
Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1)
|
||||
.alsoTo(Flow[Int].fold(0)(_ + _).to(Sink.head.named("otherSink")).addAttributes(if (async) Attributes.asyncBoundary else Attributes.none))
|
||||
.via(Flow[Int].fold(1)(_ + _).named("mainSink"))
|
||||
|
||||
def singlePath[S <: Shape, M](fg: FusedGraph[S, M], from: Attribute, to: Attribute): Unit = {
|
||||
val starts = fg.module.info.allModules.filter(_.attributes.contains(from))
|
||||
starts.size should ===(1)
|
||||
val start = starts.head
|
||||
val ups = fg.module.info.upstreams
|
||||
val owner = fg.module.info.outOwners
|
||||
|
||||
@tailrec def rec(curr: Module): Unit = {
|
||||
if (Debug) println(extractName(curr, "unknown"))
|
||||
if (curr.attributes.contains(to)) () // done
|
||||
else {
|
||||
val outs = curr.inPorts.map(ups)
|
||||
outs.size should ===(1)
|
||||
val out = outs.head
|
||||
val next = owner(out)
|
||||
rec(next)
|
||||
}
|
||||
}
|
||||
|
||||
rec(start)
|
||||
}
|
||||
|
||||
"Fusing" must {
|
||||
|
||||
"fuse a moderately complex graph" in {
|
||||
val g = Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1).alsoTo(Sink.fold(0)(_ + _)).to(Sink.fold(1)(_ + _))
|
||||
val fused = Fusing.aggressive(g)
|
||||
def verify[S <: Shape, M](fused: FusedGraph[S, M], modules: Int, downstreams: Int): Unit = {
|
||||
val module = fused.module
|
||||
module.subModules.size should ===(1)
|
||||
module.info.downstreams.size should be > 5
|
||||
module.info.upstreams.size should be > 5
|
||||
module.subModules.size should ===(modules)
|
||||
module.downstreams.size should ===(modules - 1)
|
||||
module.info.downstreams.size should be >= downstreams
|
||||
module.info.upstreams.size should be >= downstreams
|
||||
singlePath(fused, Attributes.Name("mainSink"), Attributes.Name("unfoldInf"))
|
||||
singlePath(fused, Attributes.Name("otherSink"), Attributes.Name("unfoldInf"))
|
||||
}
|
||||
|
||||
"fuse a moderately complex graph" in {
|
||||
val g = graph(false)
|
||||
val fused = Fusing.aggressive(g)
|
||||
verify(fused, modules = 1, downstreams = 5)
|
||||
}
|
||||
|
||||
"not fuse across AsyncBoundary" in {
|
||||
val g =
|
||||
Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1)
|
||||
.alsoTo(Sink.fold(0)(_ + (_: Int)).addAttributes(Attributes.asyncBoundary))
|
||||
.to(Sink.fold(1)(_ + _))
|
||||
val g = graph(true)
|
||||
val fused = Fusing.aggressive(g)
|
||||
val module = fused.module
|
||||
module.subModules.size should ===(2)
|
||||
module.info.downstreams.size should be > 5
|
||||
module.info.upstreams.size should be > 5
|
||||
verify(fused, modules = 2, downstreams = 5)
|
||||
}
|
||||
|
||||
"not fuse a FusedGraph again" in {
|
||||
val g = Fusing.aggressive(graph(false))
|
||||
Fusing.aggressive(g) should be theSameInstanceAs g
|
||||
}
|
||||
|
||||
"properly fuse a FusedGraph that has been extended (no AsyncBoundary)" in {
|
||||
val src = Fusing.aggressive(graph(false))
|
||||
val fused = Fusing.aggressive(Source.fromGraph(src).to(Sink.head))
|
||||
verify(fused, modules = 1, downstreams = 6)
|
||||
}
|
||||
|
||||
"properly fuse a FusedGraph that has been extended (with AsyncBoundary)" in {
|
||||
val src = Fusing.aggressive(graph(true))
|
||||
val fused = Fusing.aggressive(Source.fromGraph(src).to(Sink.head))
|
||||
verify(fused, modules = 2, downstreams = 6)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package akka.stream.impl
|
|||
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import akka.stream._
|
||||
import akka.stream.Fusing.aggressive
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.stage._
|
||||
import akka.stream.testkit.Utils.assertAllStagesStopped
|
||||
|
|
@ -66,12 +67,6 @@ class GraphStageLogicSpec extends AkkaSpec with GraphInterpreterSpecKit with Con
|
|||
}
|
||||
}
|
||||
|
||||
class FusedGraph[S <: Shape](ga: GraphAssembly, s: S, a: Attributes = Attributes.none) extends Graph[S, Unit] {
|
||||
override def shape = s
|
||||
override val module = GraphModule(ga, s, a, ga.stages.map(_.module))
|
||||
override def withAttributes(attr: Attributes) = new FusedGraph(ga, s, attr)
|
||||
}
|
||||
|
||||
"A GraphStageLogic" must {
|
||||
|
||||
"emit all things before completing" in assertAllStagesStopped {
|
||||
|
|
@ -84,38 +79,21 @@ class GraphStageLogicSpec extends AkkaSpec with GraphInterpreterSpecKit with Con
|
|||
}
|
||||
|
||||
"emit all things before completing with two fused stages" in assertAllStagesStopped {
|
||||
new Builder {
|
||||
val g = new FusedGraph(
|
||||
builder(emit1234, emit5678)
|
||||
.connect(Upstream, emit1234.in)
|
||||
.connect(emit1234.out, emit5678.in)
|
||||
.connect(emit5678.out, Downstream)
|
||||
.buildAssembly(),
|
||||
FlowShape(emit1234.in, emit5678.out))
|
||||
val g = aggressive(Flow[Int].via(emit1234).via(emit5678))
|
||||
|
||||
Source.empty.via(g).runWith(TestSink.probe)
|
||||
.request(9)
|
||||
.expectNextN(1 to 8)
|
||||
.expectComplete()
|
||||
}
|
||||
Source.empty.via(g).runWith(TestSink.probe)
|
||||
.request(9)
|
||||
.expectNextN(1 to 8)
|
||||
.expectComplete()
|
||||
}
|
||||
|
||||
"emit all things before completing with three fused stages" in assertAllStagesStopped {
|
||||
new Builder {
|
||||
val g = new FusedGraph(
|
||||
builder(emit1234, passThrough, emit5678)
|
||||
.connect(Upstream, emit1234.in)
|
||||
.connect(emit1234.out, passThrough.in)
|
||||
.connect(passThrough.out, emit5678.in)
|
||||
.connect(emit5678.out, Downstream)
|
||||
.buildAssembly(),
|
||||
FlowShape(emit1234.in, emit5678.out))
|
||||
val g = aggressive(Flow[Int].via(emit1234).via(passThrough).via(emit5678))
|
||||
|
||||
Source.empty.via(g).runWith(TestSink.probe)
|
||||
.request(9)
|
||||
.expectNextN(1 to 8)
|
||||
.expectComplete()
|
||||
}
|
||||
Source.empty.via(g).runWith(TestSink.probe)
|
||||
.request(9)
|
||||
.expectNextN(1 to 8)
|
||||
.expectComplete()
|
||||
}
|
||||
|
||||
"invoke lifecycle hooks in the right order" in assertAllStagesStopped {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
*/
|
||||
package akka.stream.scaladsl
|
||||
|
||||
import akka.stream.{ ClosedShape, SourceShape, ActorMaterializer, ActorMaterializerSettings }
|
||||
import akka.stream._
|
||||
import akka.stream.testkit._
|
||||
|
||||
import scala.concurrent.Await
|
||||
|
|
@ -106,5 +106,15 @@ class GraphMatValueSpec extends AkkaSpec {
|
|||
|
||||
}
|
||||
|
||||
"work also when the source’s module is copied" in {
|
||||
val foldFlow: Flow[Int, Int, Future[Int]] = Flow.fromGraph(GraphDSL.create(Sink.fold[Int, Int](0)(_ + _)) {
|
||||
implicit builder ⇒
|
||||
fold ⇒
|
||||
FlowShape(fold.inlet, builder.materializedValue.mapAsync(4)(identity).outlet)
|
||||
})
|
||||
|
||||
Await.result(Source(1 to 10).via(foldFlow).runWith(Sink.head), 3.seconds) should ===(55)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import akka.event.Logging
|
|||
import scala.annotation.tailrec
|
||||
import scala.reflect.{ classTag, ClassTag }
|
||||
import akka.japi.function
|
||||
import akka.stream.impl.StreamLayout._
|
||||
|
||||
/**
|
||||
* Holds attributes which can be used to alter [[akka.stream.scaladsl.Flow]] / [[akka.stream.javadsl.Flow]]
|
||||
|
|
@ -221,6 +222,16 @@ object Attributes {
|
|||
def logLevels(onElement: Logging.LogLevel = Logging.DebugLevel, onFinish: Logging.LogLevel = Logging.DebugLevel, onFailure: Logging.LogLevel = Logging.ErrorLevel) =
|
||||
Attributes(LogLevels(onElement, onFinish, onFailure))
|
||||
|
||||
/**
|
||||
* Compute a name by concatenating all Name attributes that the given module
|
||||
* has, returning the given default value if none are found.
|
||||
*/
|
||||
def extractName(mod: Module, default: String): String = {
|
||||
mod match {
|
||||
case CopiedModule(_, attr, copyOf) ⇒ (attr and copyOf.attributes).nameOrDefault(default)
|
||||
case _ ⇒ mod.attributes.nameOrDefault(default)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import java.{ util ⇒ ju }
|
|||
import scala.collection.immutable
|
||||
import scala.collection.JavaConverters._
|
||||
import akka.stream.impl.StreamLayout._
|
||||
import akka.stream.impl.fusing.{ Fusing ⇒ Impl }
|
||||
import scala.annotation.unchecked.uncheckedVariance
|
||||
|
||||
/**
|
||||
* This class holds some graph transformation functions that can fuse together
|
||||
|
|
@ -32,15 +34,19 @@ object Fusing {
|
|||
* via [[akka.stream.Attributes#AsyncBoundary]].
|
||||
*/
|
||||
def aggressive[S <: Shape, M](g: Graph[S, M]): FusedGraph[S, M] =
|
||||
akka.stream.impl.fusing.Fusing.aggressive(g)
|
||||
g match {
|
||||
case fg: FusedGraph[_, _] ⇒ fg
|
||||
case _ ⇒ Impl.aggressive(g)
|
||||
}
|
||||
|
||||
/**
|
||||
* A fused graph of the right shape, containing a [[FusedModule]] which
|
||||
* holds more information on the operation structure of the contained stream
|
||||
* topology for convenient graph traversal.
|
||||
*/
|
||||
case class FusedGraph[S <: Shape, M](override val module: FusedModule,
|
||||
override val shape: S) extends Graph[S, M] {
|
||||
case class FusedGraph[+S <: Shape @uncheckedVariance, +M](override val module: FusedModule,
|
||||
override val shape: S) extends Graph[S, M] {
|
||||
// the @uncheckedVariance look like a compiler bug ... why does it work in Graph but not here?
|
||||
override def withAttributes(attr: Attributes) = copy(module = module.withAttributes(attr))
|
||||
}
|
||||
|
||||
|
|
@ -54,6 +60,7 @@ object Fusing {
|
|||
final case class StructuralInfo(upstreams: immutable.Map[InPort, OutPort],
|
||||
downstreams: immutable.Map[OutPort, InPort],
|
||||
inOwners: immutable.Map[InPort, Module],
|
||||
outOwners: immutable.Map[OutPort, Module])
|
||||
outOwners: immutable.Map[OutPort, Module],
|
||||
allModules: Set[Module])
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import akka.stream.impl.fusing.GraphModule
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] object StreamLayout {
|
||||
object StreamLayout {
|
||||
|
||||
// compile-time constant
|
||||
final val Debug = false
|
||||
|
|
@ -122,7 +122,7 @@ private[akka] object StreamLayout {
|
|||
override def toString: String = s"Combine($dep1,$dep2)"
|
||||
}
|
||||
case class Atomic(module: Module) extends MaterializedValueNode {
|
||||
override def toString: String = s"Atomic(${module.attributes.nameOrDefault(module.getClass.getName)})"
|
||||
override def toString: String = s"Atomic(${module.attributes.nameOrDefault(module.getClass.getName)}[${module.hashCode}])"
|
||||
}
|
||||
case class Transform(f: Any ⇒ Any, dep: MaterializedValueNode) extends MaterializedValueNode {
|
||||
override def toString: String = s"Transform($dep)"
|
||||
|
|
@ -148,6 +148,7 @@ private[akka] object StreamLayout {
|
|||
final def isBidiFlow: Boolean = (inPorts.size == 2) && (outPorts.size == 2)
|
||||
def isAtomic: Boolean = subModules.isEmpty
|
||||
def isCopied: Boolean = false
|
||||
def isFused: Boolean = false
|
||||
|
||||
/**
|
||||
* Fuses this Module to `that` Module by wiring together `from` and `to`,
|
||||
|
|
@ -192,7 +193,7 @@ private[akka] object StreamLayout {
|
|||
else s"The input port [$to] is not part of the underlying graph.")
|
||||
|
||||
CompositeModule(
|
||||
subModules,
|
||||
if (isSealed) Set(this) else subModules,
|
||||
AmorphousShape(shape.inlets.filterNot(_ == to), shape.outlets.filterNot(_ == from)),
|
||||
downstreams.updated(from, to),
|
||||
upstreams.updated(to, from),
|
||||
|
|
@ -314,7 +315,7 @@ private[akka] object StreamLayout {
|
|||
}
|
||||
|
||||
def subModules: Set[Module]
|
||||
final def isSealed: Boolean = isAtomic || isCopied
|
||||
final def isSealed: Boolean = isAtomic || isCopied || isFused
|
||||
|
||||
def downstreams: Map[OutPort, InPort] = Map.empty
|
||||
def upstreams: Map[InPort, OutPort] = Map.empty
|
||||
|
|
@ -411,6 +412,8 @@ private[akka] object StreamLayout {
|
|||
override val attributes: Attributes,
|
||||
info: Fusing.StructuralInfo) extends Module {
|
||||
|
||||
override def isFused: Boolean = true
|
||||
|
||||
override def replaceShape(s: Shape): Module = {
|
||||
shape.requireSamePortsAs(s)
|
||||
copy(shape = s)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ private[stream] object Fusing {
|
|||
* Perform the fusing of `struct.groups` into GraphModules (leaving them
|
||||
* as they are for non-fusable modules).
|
||||
*/
|
||||
struct.removeInternalWires()
|
||||
struct.breakUpGroupsByDispatcher()
|
||||
val modules = fuse(struct)
|
||||
/*
|
||||
|
|
@ -65,7 +66,7 @@ private[stream] object Fusing {
|
|||
shape,
|
||||
immutable.Map.empty ++ struct.downstreams.asScala,
|
||||
immutable.Map.empty ++ struct.upstreams.asScala,
|
||||
matValue,
|
||||
matValue.head._2,
|
||||
Attributes.none,
|
||||
info)
|
||||
|
||||
|
|
@ -244,26 +245,102 @@ private[stream] object Fusing {
|
|||
inheritedAttributes: Attributes,
|
||||
struct: BuildStructuralInfo,
|
||||
openGroup: ju.Set[Module],
|
||||
indent: String): MaterializedValueNode = {
|
||||
indent: String): List[(Module, MaterializedValueNode)] = {
|
||||
def log(msg: String): Unit = println(indent + msg)
|
||||
val async = m match {
|
||||
case _: GraphStageModule ⇒ m.attributes.contains(AsyncBoundary)
|
||||
case _: GraphModule ⇒ m.attributes.contains(AsyncBoundary)
|
||||
case _ if m.isAtomic ⇒ true // non-GraphStage atomic or has AsyncBoundary
|
||||
case _ ⇒ m.attributes.contains(AsyncBoundary)
|
||||
}
|
||||
if (Debug) log(s"entering ${m.getClass} (async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
|
||||
if (Debug) log(s"entering ${m.getClass} (hash=${m.hashCode}, async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
|
||||
val localGroup =
|
||||
if (async) struct.newGroup(indent)
|
||||
else openGroup
|
||||
|
||||
if (m.isAtomic) {
|
||||
if (Debug) log(s"atomic module $m")
|
||||
struct.addModule(m, localGroup, inheritedAttributes, indent)
|
||||
m match {
|
||||
case gm: GraphModule if !async ⇒
|
||||
// need to dissolve previously fused GraphStages to allow further fusion
|
||||
if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + indent)}")
|
||||
val attributes = inheritedAttributes and m.attributes
|
||||
gm.matValIDs.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
case gm @ GraphModule(_, oldShape, _, mvids) ⇒
|
||||
/*
|
||||
* Importing a GraphModule that has an AsyncBoundary attribute is a little more work:
|
||||
*
|
||||
* - we need to copy all the CopiedModules that are in matValIDs
|
||||
* - we need to rewrite the corresponding MaterializedValueNodes
|
||||
* - we need to match up the new (copied) GraphModule shape with the individual Shape copies
|
||||
* - we need to register the contained modules but take care to not include the internal
|
||||
* wirings into the final result, see also `struct.removeInternalWires()`
|
||||
*/
|
||||
if (Debug) log(s"graph module ${m.toString.replace("\n", "\n" + indent)}")
|
||||
|
||||
// storing the old Shape in arrays for in-place updating as we clone the contained GraphStages
|
||||
val oldIns = oldShape.inlets.toArray
|
||||
val oldOuts = oldShape.outlets.toArray
|
||||
|
||||
val newids = mvids.map {
|
||||
case CopiedModule(shape, attr, copyOf) ⇒
|
||||
val newShape = shape.deepCopy
|
||||
val copy = CopiedModule(newShape, attr, copyOf): Module
|
||||
|
||||
// rewrite shape: first the inlets
|
||||
val oldIn = shape.inlets.iterator
|
||||
val newIn = newShape.inlets.iterator
|
||||
while (oldIn.hasNext) {
|
||||
val o = oldIn.next()
|
||||
val n = newIn.next()
|
||||
findInArray(o, oldIns) match {
|
||||
case -1 ⇒ // nothing to do
|
||||
case idx ⇒ oldIns(idx) = n
|
||||
}
|
||||
}
|
||||
// ... then the outlets
|
||||
val oldOut = shape.outlets.iterator
|
||||
val newOut = newShape.outlets.iterator
|
||||
while (oldOut.hasNext) {
|
||||
val o = oldOut.next()
|
||||
val n = newOut.next()
|
||||
findInArray(o, oldOuts) match {
|
||||
case -1 ⇒ // nothing to do
|
||||
case idx ⇒ oldOuts(idx) = n
|
||||
}
|
||||
}
|
||||
|
||||
// need to add the module so that the structural (internal) wirings can be rewritten as well
|
||||
// but these modules must not be added to any of the groups
|
||||
struct.addModule(copy, new ju.HashSet, inheritedAttributes, indent, shape)
|
||||
struct.registerInteral(newShape, indent)
|
||||
|
||||
copy
|
||||
}
|
||||
val newgm = gm.copy(shape = oldShape.copyFromPorts(oldIns.toList, oldOuts.toList), matValIDs = newids)
|
||||
// make sure to add all the port mappings from old GraphModule Shape to new shape
|
||||
struct.addModule(newgm, localGroup, inheritedAttributes, indent, _oldShape = oldShape)
|
||||
// now compute the list of all materialized value computation updates
|
||||
var result = List.empty[(Module, MaterializedValueNode)]
|
||||
var i = 0
|
||||
while (i < mvids.length) {
|
||||
result ::= mvids(i) -> Atomic(newids(i))
|
||||
i += 1
|
||||
}
|
||||
result ::= m -> Atomic(newgm)
|
||||
result
|
||||
case _ ⇒
|
||||
if (Debug) log(s"atomic module $m")
|
||||
List(m -> struct.addModule(m, localGroup, inheritedAttributes, indent))
|
||||
}
|
||||
} else {
|
||||
val attributes = inheritedAttributes and m.attributes
|
||||
m match {
|
||||
case CopiedModule(shape, _, copyOf) ⇒
|
||||
val ret = descend(copyOf, attributes, struct, localGroup, indent + " ")
|
||||
val ret =
|
||||
descend(copyOf, attributes, struct, localGroup, indent + " ") match {
|
||||
case xs @ (_, mat) :: _ ⇒ (m -> mat) :: xs
|
||||
case _ ⇒ throw new IllegalArgumentException("cannot happen")
|
||||
}
|
||||
struct.rewire(copyOf.shape, shape, indent)
|
||||
ret
|
||||
case _ ⇒
|
||||
|
|
@ -272,15 +349,21 @@ private[stream] object Fusing {
|
|||
struct.enterMatCtx()
|
||||
// now descend into submodules and collect their computations (plus updates to `struct`)
|
||||
val subMat: Predef.Map[Module, MaterializedValueNode] =
|
||||
m.subModules.map(sub ⇒ sub -> descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
m.subModules.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
if (Debug) log(subMat.map(p ⇒ s"${p._1.getClass.getName}[${p._1.hashCode}] -> ${p._2}").mkString("subMat\n " + indent, "\n " + indent, ""))
|
||||
// we need to remove all wirings that this module copied from nested modules so that we
|
||||
// don’t do wirings twice
|
||||
val down = m.subModules.foldLeft(m.downstreams.toSet)((set, m) ⇒ set -- m.downstreams)
|
||||
val oldDownstreams = m match {
|
||||
case f: FusedModule ⇒ f.info.downstreams.toSet
|
||||
case _ ⇒ m.downstreams.toSet
|
||||
}
|
||||
val down = m.subModules.foldLeft(oldDownstreams)((set, m) ⇒ set -- m.downstreams)
|
||||
down.foreach {
|
||||
case (start, end) ⇒ struct.wire(start, end, indent)
|
||||
}
|
||||
// now rewrite the materialized value computation based on the copied modules and their computation nodes
|
||||
val newMat = rewriteMat(subMat, m.materializedValueComputation)
|
||||
val matNodeMapping: ju.Map[MaterializedValueNode, MaterializedValueNode] = new ju.HashMap
|
||||
val newMat = rewriteMat(subMat, m.materializedValueComputation, matNodeMapping)
|
||||
// and finally rewire all MaterializedValueSources to their new computation nodes
|
||||
val matSrcs = struct.exitMatCtx()
|
||||
matSrcs.foreach { c ⇒
|
||||
|
|
@ -288,33 +371,49 @@ private[stream] object Fusing {
|
|||
val ms = c.copyOf match {
|
||||
case g: GraphStageModule ⇒ g.stage.asInstanceOf[MaterializedValueSource[Any]]
|
||||
}
|
||||
if (Debug) require(find(ms.computation, m.materializedValueComputation), s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
|
||||
val replacement = CopiedModule(c.shape, c.attributes, new MaterializedValueSource[Any](newMat, ms.out).module)
|
||||
val mapped = ms.computation match {
|
||||
case Atomic(sub) ⇒ subMat(sub)
|
||||
case other ⇒ matNodeMapping.get(other)
|
||||
}
|
||||
require(mapped != null, s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
|
||||
val newSrc = new MaterializedValueSource[Any](mapped, ms.out)
|
||||
val replacement = CopiedModule(c.shape, c.attributes, newSrc.module)
|
||||
struct.replace(c, replacement, localGroup)
|
||||
}
|
||||
// the result for each level is the materialized value computation
|
||||
newMat
|
||||
List(m -> newMat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def find(node: MaterializedValueNode, inTree: MaterializedValueNode): Boolean =
|
||||
if (node == inTree) true
|
||||
else
|
||||
inTree match {
|
||||
case Atomic(_) ⇒ false
|
||||
case Ignore ⇒ false
|
||||
case Transform(_, dep) ⇒ find(node, dep)
|
||||
case Combine(_, left, right) ⇒ find(node, left) || find(node, right)
|
||||
}
|
||||
@tailrec
|
||||
private def findInArray[T](elem: T, arr: Array[T], idx: Int = 0): Int =
|
||||
if (idx >= arr.length) -1
|
||||
else if (arr(idx) == elem) idx
|
||||
else findInArray(elem, arr, idx + 1)
|
||||
|
||||
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode],
|
||||
mat: MaterializedValueNode): MaterializedValueNode =
|
||||
/**
|
||||
* Given a mapping from old modules to new MaterializedValueNode, rewrite the given
|
||||
* computation while also populating a mapping from old computation nodes to new ones.
|
||||
* That mapping is needed to rewrite the MaterializedValueSource stages later-on in
|
||||
* descend().
|
||||
*/
|
||||
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode], mat: MaterializedValueNode,
|
||||
mapping: ju.Map[MaterializedValueNode, MaterializedValueNode]): MaterializedValueNode =
|
||||
mat match {
|
||||
case Atomic(sub) ⇒ subMat(sub)
|
||||
case Combine(f, left, right) ⇒ Combine(f, rewriteMat(subMat, left), rewriteMat(subMat, right))
|
||||
case Transform(f, dep) ⇒ Transform(f, rewriteMat(subMat, dep))
|
||||
case Ignore ⇒ Ignore
|
||||
case Atomic(sub) ⇒
|
||||
val ret = subMat(sub)
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Combine(f, left, right) ⇒
|
||||
val ret = Combine(f, rewriteMat(subMat, left, mapping), rewriteMat(subMat, right, mapping))
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Transform(f, dep) ⇒
|
||||
val ret = Transform(f, rewriteMat(subMat, dep, mapping))
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Ignore ⇒ Ignore
|
||||
}
|
||||
|
||||
private implicit class NonNull[T](val x: T) extends AnyVal {
|
||||
|
|
@ -335,7 +434,8 @@ private[stream] object Fusing {
|
|||
immutable.Map.empty ++ upstreams.asScala,
|
||||
immutable.Map.empty ++ downstreams.asScala,
|
||||
immutable.Map.empty ++ inOwners.asScala,
|
||||
immutable.Map.empty ++ outOwners.asScala)
|
||||
immutable.Map.empty ++ outOwners.asScala,
|
||||
Set.empty ++ modules.asScala)
|
||||
|
||||
/**
|
||||
* the set of all contained modules
|
||||
|
|
@ -443,6 +543,34 @@ private[stream] object Fusing {
|
|||
*/
|
||||
val outOwners: ju.Map[OutPort, Module] = new ju.HashMap
|
||||
|
||||
/**
|
||||
* List of internal wirings of GraphModules that were incorporated.
|
||||
*/
|
||||
val internalOuts: ju.Set[OutPort] = new ju.HashSet
|
||||
|
||||
/**
|
||||
* Register the outlets of the given Shape as sources for internal
|
||||
* connections within imported (and not dissolved) GraphModules.
|
||||
* See also the comment in addModule where this is partially undone.
|
||||
*/
|
||||
def registerInteral(s: Shape, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"registerInternals(${s.outlets.map(hash)})")
|
||||
internalOuts.addAll(s.outlets.asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove wirings that belong to the fused stages contained in GraphModules
|
||||
* that were incorporated in this fusing run.
|
||||
*/
|
||||
def removeInternalWires(): Unit = {
|
||||
val it = internalOuts.iterator()
|
||||
while (it.hasNext) {
|
||||
val out = it.next()
|
||||
val in = downstreams.remove(out)
|
||||
if (in != null) upstreams.remove(in)
|
||||
}
|
||||
}
|
||||
|
||||
def dump(): Unit = {
|
||||
println("StructuralInfo:")
|
||||
println(" newIns:")
|
||||
|
|
@ -467,13 +595,17 @@ private[stream] object Fusing {
|
|||
/**
|
||||
* Add a module to the given group, performing normalization (i.e. giving it a unique port identity).
|
||||
*/
|
||||
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String): Atomic = {
|
||||
val copy = CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
|
||||
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(m.shape)}")
|
||||
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String,
|
||||
_oldShape: Shape = null): Atomic = {
|
||||
val copy =
|
||||
if (_oldShape == null) CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
|
||||
else m
|
||||
val oldShape = if (_oldShape == null) m.shape else _oldShape
|
||||
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}")
|
||||
group.add(copy)
|
||||
modules.add(copy)
|
||||
copy.shape.outlets.foreach(o ⇒ outGroup.put(o, group))
|
||||
val orig1 = m.shape.inlets.iterator
|
||||
val orig1 = oldShape.inlets.iterator
|
||||
val mapd1 = copy.shape.inlets.iterator
|
||||
while (orig1.hasNext) {
|
||||
val orig = orig1.next()
|
||||
|
|
@ -481,7 +613,7 @@ private[stream] object Fusing {
|
|||
addMapping(orig, mapd, newIns)
|
||||
inOwners.put(mapd, copy)
|
||||
}
|
||||
val orig2 = m.shape.outlets.iterator
|
||||
val orig2 = oldShape.outlets.iterator
|
||||
val mapd2 = copy.shape.outlets.iterator
|
||||
while (orig2.hasNext) {
|
||||
val orig = orig2.next()
|
||||
|
|
@ -489,8 +621,24 @@ private[stream] object Fusing {
|
|||
addMapping(orig, mapd, newOuts)
|
||||
outOwners.put(mapd, copy)
|
||||
}
|
||||
copy.copyOf match {
|
||||
case GraphStageModule(_, _, _: MaterializedValueSource[_]) ⇒ pushMatSrc(copy)
|
||||
/*
|
||||
* In descend() we add internalOuts entries for all shapes that belong to stages that
|
||||
* are part of a GraphModule that is not dissolved. This includes the exposed Outlets,
|
||||
* which of course are external and thus need to be removed again from the internalOuts
|
||||
* set.
|
||||
*/
|
||||
if (m.isInstanceOf[GraphModule]) internalOuts.removeAll(m.shape.outlets.asJava)
|
||||
copy match {
|
||||
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c)
|
||||
case GraphModule(_, _, _, mvids) ⇒
|
||||
var i = 0
|
||||
while (i < mvids.length) {
|
||||
mvids(i) match {
|
||||
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c)
|
||||
case _ ⇒
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
case _ ⇒
|
||||
}
|
||||
Atomic(copy)
|
||||
|
|
@ -502,8 +650,8 @@ private[stream] object Fusing {
|
|||
*/
|
||||
def wire(out: OutPort, in: InPort, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})")
|
||||
val newOut = removeMapping(out, newOuts) nonNull out.toString
|
||||
val newIn = removeMapping(in, newIns) nonNull in.toString
|
||||
val newOut = removeMapping(out, newOuts) nonNull s"$out (${hash(out)})"
|
||||
val newIn = removeMapping(in, newIns) nonNull s"$in (${hash(in)})"
|
||||
downstreams.put(newOut, newIn)
|
||||
upstreams.put(newIn, newOut)
|
||||
}
|
||||
|
|
@ -514,10 +662,10 @@ private[stream] object Fusing {
|
|||
def rewire(oldShape: Shape, newShape: Shape, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}")
|
||||
oldShape.inlets.iterator.zip(newShape.inlets.iterator).foreach {
|
||||
case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull oldIn.toString, newIns)
|
||||
case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull s"$oldIn (${hash(oldIn)})", newIns)
|
||||
}
|
||||
oldShape.outlets.iterator.zip(newShape.outlets.iterator).foreach {
|
||||
case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull oldOut.toString, newOuts)
|
||||
case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull s"$oldOut (${hash(oldOut)})", newOuts)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -534,6 +682,9 @@ private[stream] object Fusing {
|
|||
old.map(o ⇒ newOuts.get(o).head.outlet)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine whether the given CopiedModule has an AsyncBoundary attribute.
|
||||
*/
|
||||
private def isAsync(m: Module): Boolean = m match {
|
||||
case CopiedModule(_, inherited, orig) ⇒
|
||||
val attr = inherited and orig.attributes
|
||||
|
|
@ -550,6 +701,9 @@ private[stream] object Fusing {
|
|||
case x ⇒ x.attributes.get[ActorAttributes.Dispatcher]
|
||||
}
|
||||
|
||||
/**
|
||||
* See through copied modules to the “real” module.
|
||||
*/
|
||||
private def realModule(m: Module): Module = m match {
|
||||
case CopiedModule(_, _, of) ⇒ realModule(of)
|
||||
case other ⇒ other
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue