make fused graphs fusable
This commit is contained in:
parent
3d20915cf4
commit
e4f31b66c3
7 changed files with 308 additions and 95 deletions
|
|
@ -55,6 +55,7 @@ private[stream] object Fusing {
|
|||
* Perform the fusing of `struct.groups` into GraphModules (leaving them
|
||||
* as they are for non-fusable modules).
|
||||
*/
|
||||
struct.removeInternalWires()
|
||||
struct.breakUpGroupsByDispatcher()
|
||||
val modules = fuse(struct)
|
||||
/*
|
||||
|
|
@ -65,7 +66,7 @@ private[stream] object Fusing {
|
|||
shape,
|
||||
immutable.Map.empty ++ struct.downstreams.asScala,
|
||||
immutable.Map.empty ++ struct.upstreams.asScala,
|
||||
matValue,
|
||||
matValue.head._2,
|
||||
Attributes.none,
|
||||
info)
|
||||
|
||||
|
|
@ -244,26 +245,102 @@ private[stream] object Fusing {
|
|||
inheritedAttributes: Attributes,
|
||||
struct: BuildStructuralInfo,
|
||||
openGroup: ju.Set[Module],
|
||||
indent: String): MaterializedValueNode = {
|
||||
indent: String): List[(Module, MaterializedValueNode)] = {
|
||||
def log(msg: String): Unit = println(indent + msg)
|
||||
val async = m match {
|
||||
case _: GraphStageModule ⇒ m.attributes.contains(AsyncBoundary)
|
||||
case _: GraphModule ⇒ m.attributes.contains(AsyncBoundary)
|
||||
case _ if m.isAtomic ⇒ true // non-GraphStage atomic or has AsyncBoundary
|
||||
case _ ⇒ m.attributes.contains(AsyncBoundary)
|
||||
}
|
||||
if (Debug) log(s"entering ${m.getClass} (async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
|
||||
if (Debug) log(s"entering ${m.getClass} (hash=${m.hashCode}, async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
|
||||
val localGroup =
|
||||
if (async) struct.newGroup(indent)
|
||||
else openGroup
|
||||
|
||||
if (m.isAtomic) {
|
||||
if (Debug) log(s"atomic module $m")
|
||||
struct.addModule(m, localGroup, inheritedAttributes, indent)
|
||||
m match {
|
||||
case gm: GraphModule if !async ⇒
|
||||
// need to dissolve previously fused GraphStages to allow further fusion
|
||||
if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + indent)}")
|
||||
val attributes = inheritedAttributes and m.attributes
|
||||
gm.matValIDs.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
case gm @ GraphModule(_, oldShape, _, mvids) ⇒
|
||||
/*
|
||||
* Importing a GraphModule that has an AsyncBoundary attribute is a little more work:
|
||||
*
|
||||
* - we need to copy all the CopiedModules that are in matValIDs
|
||||
* - we need to rewrite the corresponding MaterializedValueNodes
|
||||
* - we need to match up the new (copied) GraphModule shape with the individual Shape copies
|
||||
* - we need to register the contained modules but take care to not include the internal
|
||||
* wirings into the final result, see also `struct.removeInternalWires()`
|
||||
*/
|
||||
if (Debug) log(s"graph module ${m.toString.replace("\n", "\n" + indent)}")
|
||||
|
||||
// storing the old Shape in arrays for in-place updating as we clone the contained GraphStages
|
||||
val oldIns = oldShape.inlets.toArray
|
||||
val oldOuts = oldShape.outlets.toArray
|
||||
|
||||
val newids = mvids.map {
|
||||
case CopiedModule(shape, attr, copyOf) ⇒
|
||||
val newShape = shape.deepCopy
|
||||
val copy = CopiedModule(newShape, attr, copyOf): Module
|
||||
|
||||
// rewrite shape: first the inlets
|
||||
val oldIn = shape.inlets.iterator
|
||||
val newIn = newShape.inlets.iterator
|
||||
while (oldIn.hasNext) {
|
||||
val o = oldIn.next()
|
||||
val n = newIn.next()
|
||||
findInArray(o, oldIns) match {
|
||||
case -1 ⇒ // nothing to do
|
||||
case idx ⇒ oldIns(idx) = n
|
||||
}
|
||||
}
|
||||
// ... then the outlets
|
||||
val oldOut = shape.outlets.iterator
|
||||
val newOut = newShape.outlets.iterator
|
||||
while (oldOut.hasNext) {
|
||||
val o = oldOut.next()
|
||||
val n = newOut.next()
|
||||
findInArray(o, oldOuts) match {
|
||||
case -1 ⇒ // nothing to do
|
||||
case idx ⇒ oldOuts(idx) = n
|
||||
}
|
||||
}
|
||||
|
||||
// need to add the module so that the structural (internal) wirings can be rewritten as well
|
||||
// but these modules must not be added to any of the groups
|
||||
struct.addModule(copy, new ju.HashSet, inheritedAttributes, indent, shape)
|
||||
struct.registerInteral(newShape, indent)
|
||||
|
||||
copy
|
||||
}
|
||||
val newgm = gm.copy(shape = oldShape.copyFromPorts(oldIns.toList, oldOuts.toList), matValIDs = newids)
|
||||
// make sure to add all the port mappings from old GraphModule Shape to new shape
|
||||
struct.addModule(newgm, localGroup, inheritedAttributes, indent, _oldShape = oldShape)
|
||||
// now compute the list of all materialized value computation updates
|
||||
var result = List.empty[(Module, MaterializedValueNode)]
|
||||
var i = 0
|
||||
while (i < mvids.length) {
|
||||
result ::= mvids(i) -> Atomic(newids(i))
|
||||
i += 1
|
||||
}
|
||||
result ::= m -> Atomic(newgm)
|
||||
result
|
||||
case _ ⇒
|
||||
if (Debug) log(s"atomic module $m")
|
||||
List(m -> struct.addModule(m, localGroup, inheritedAttributes, indent))
|
||||
}
|
||||
} else {
|
||||
val attributes = inheritedAttributes and m.attributes
|
||||
m match {
|
||||
case CopiedModule(shape, _, copyOf) ⇒
|
||||
val ret = descend(copyOf, attributes, struct, localGroup, indent + " ")
|
||||
val ret =
|
||||
descend(copyOf, attributes, struct, localGroup, indent + " ") match {
|
||||
case xs @ (_, mat) :: _ ⇒ (m -> mat) :: xs
|
||||
case _ ⇒ throw new IllegalArgumentException("cannot happen")
|
||||
}
|
||||
struct.rewire(copyOf.shape, shape, indent)
|
||||
ret
|
||||
case _ ⇒
|
||||
|
|
@ -272,15 +349,21 @@ private[stream] object Fusing {
|
|||
struct.enterMatCtx()
|
||||
// now descend into submodules and collect their computations (plus updates to `struct`)
|
||||
val subMat: Predef.Map[Module, MaterializedValueNode] =
|
||||
m.subModules.map(sub ⇒ sub -> descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
m.subModules.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
|
||||
if (Debug) log(subMat.map(p ⇒ s"${p._1.getClass.getName}[${p._1.hashCode}] -> ${p._2}").mkString("subMat\n " + indent, "\n " + indent, ""))
|
||||
// we need to remove all wirings that this module copied from nested modules so that we
|
||||
// don’t do wirings twice
|
||||
val down = m.subModules.foldLeft(m.downstreams.toSet)((set, m) ⇒ set -- m.downstreams)
|
||||
val oldDownstreams = m match {
|
||||
case f: FusedModule ⇒ f.info.downstreams.toSet
|
||||
case _ ⇒ m.downstreams.toSet
|
||||
}
|
||||
val down = m.subModules.foldLeft(oldDownstreams)((set, m) ⇒ set -- m.downstreams)
|
||||
down.foreach {
|
||||
case (start, end) ⇒ struct.wire(start, end, indent)
|
||||
}
|
||||
// now rewrite the materialized value computation based on the copied modules and their computation nodes
|
||||
val newMat = rewriteMat(subMat, m.materializedValueComputation)
|
||||
val matNodeMapping: ju.Map[MaterializedValueNode, MaterializedValueNode] = new ju.HashMap
|
||||
val newMat = rewriteMat(subMat, m.materializedValueComputation, matNodeMapping)
|
||||
// and finally rewire all MaterializedValueSources to their new computation nodes
|
||||
val matSrcs = struct.exitMatCtx()
|
||||
matSrcs.foreach { c ⇒
|
||||
|
|
@ -288,33 +371,49 @@ private[stream] object Fusing {
|
|||
val ms = c.copyOf match {
|
||||
case g: GraphStageModule ⇒ g.stage.asInstanceOf[MaterializedValueSource[Any]]
|
||||
}
|
||||
if (Debug) require(find(ms.computation, m.materializedValueComputation), s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
|
||||
val replacement = CopiedModule(c.shape, c.attributes, new MaterializedValueSource[Any](newMat, ms.out).module)
|
||||
val mapped = ms.computation match {
|
||||
case Atomic(sub) ⇒ subMat(sub)
|
||||
case other ⇒ matNodeMapping.get(other)
|
||||
}
|
||||
require(mapped != null, s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
|
||||
val newSrc = new MaterializedValueSource[Any](mapped, ms.out)
|
||||
val replacement = CopiedModule(c.shape, c.attributes, newSrc.module)
|
||||
struct.replace(c, replacement, localGroup)
|
||||
}
|
||||
// the result for each level is the materialized value computation
|
||||
newMat
|
||||
List(m -> newMat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
* Determine whether `node` occurs in the computation tree `inTree`: either it is equal
* to `inTree` itself, or it is found by recursing into the dependency of a Transform
* or into the left/right branches of a Combine. Atomic and Ignore are leaves.
*/
private def find(node: MaterializedValueNode, inTree: MaterializedValueNode): Boolean =
|
||||
if (node == inTree) true
|
||||
else
|
||||
inTree match {
|
||||
case Atomic(_) ⇒ false
|
||||
case Ignore ⇒ false
|
||||
case Transform(_, dep) ⇒ find(node, dep)
|
||||
case Combine(_, left, right) ⇒ find(node, left) || find(node, right)
|
||||
}
|
||||
/**
* Linear search for `elem` in `arr`, comparing with `==` and starting at index `idx`.
* Returns the index of the first match at or after `idx`, or -1 once `idx` has run
* past the end of the array without finding one.
*/
@tailrec
|
||||
private def findInArray[T](elem: T, arr: Array[T], idx: Int = 0): Int =
|
||||
if (idx >= arr.length) -1
|
||||
else if (arr(idx) == elem) idx
|
||||
else findInArray(elem, arr, idx + 1)
|
||||
|
||||
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode],
|
||||
mat: MaterializedValueNode): MaterializedValueNode =
|
||||
/**
|
||||
* Given a mapping from old modules to new MaterializedValueNode, rewrite the given
|
||||
* computation while also populating a mapping from old computation nodes to new ones.
|
||||
* That mapping is needed to rewrite the MaterializedValueSource stages later on in
|
||||
* descend().
|
||||
*/
|
||||
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode], mat: MaterializedValueNode,
|
||||
mapping: ju.Map[MaterializedValueNode, MaterializedValueNode]): MaterializedValueNode =
|
||||
mat match {
|
||||
case Atomic(sub) ⇒ subMat(sub)
|
||||
case Combine(f, left, right) ⇒ Combine(f, rewriteMat(subMat, left), rewriteMat(subMat, right))
|
||||
case Transform(f, dep) ⇒ Transform(f, rewriteMat(subMat, dep))
|
||||
case Ignore ⇒ Ignore
|
||||
case Atomic(sub) ⇒
|
||||
val ret = subMat(sub)
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Combine(f, left, right) ⇒
|
||||
val ret = Combine(f, rewriteMat(subMat, left, mapping), rewriteMat(subMat, right, mapping))
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Transform(f, dep) ⇒
|
||||
val ret = Transform(f, rewriteMat(subMat, dep, mapping))
|
||||
mapping.put(mat, ret)
|
||||
ret
|
||||
case Ignore ⇒ Ignore
|
||||
}
|
||||
|
||||
private implicit class NonNull[T](val x: T) extends AnyVal {
|
||||
|
|
@ -335,7 +434,8 @@ private[stream] object Fusing {
|
|||
immutable.Map.empty ++ upstreams.asScala,
|
||||
immutable.Map.empty ++ downstreams.asScala,
|
||||
immutable.Map.empty ++ inOwners.asScala,
|
||||
immutable.Map.empty ++ outOwners.asScala)
|
||||
immutable.Map.empty ++ outOwners.asScala,
|
||||
Set.empty ++ modules.asScala)
|
||||
|
||||
/**
|
||||
* the set of all contained modules
|
||||
|
|
@ -443,6 +543,34 @@ private[stream] object Fusing {
|
|||
*/
|
||||
val outOwners: ju.Map[OutPort, Module] = new ju.HashMap
|
||||
|
||||
/**
|
||||
* Set of outlets that are the sources of internal wirings of GraphModules that were incorporated.
|
||||
*/
|
||||
val internalOuts: ju.Set[OutPort] = new ju.HashSet
|
||||
|
||||
/**
|
||||
* Register the outlets of the given Shape as sources for internal
|
||||
* connections within imported (and not dissolved) GraphModules.
|
||||
* See also the comment in addModule where this is partially undone.
|
||||
*/
|
||||
// NOTE(review): the method name is missing an 'n' ("registerInternal"); the debug message
// below already prints "registerInternals". Renaming requires updating every caller as well.
def registerInteral(s: Shape, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"registerInternals(${s.outlets.map(hash)})")
|
||||
// only the outlets are recorded; removeInternalWires() iterates over this set
internalOuts.addAll(s.outlets.asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove wirings that belong to the fused stages contained in GraphModules
|
||||
* that were incorporated in this fusing run.
|
||||
*/
|
||||
def removeInternalWires(): Unit = {
|
||||
val it = internalOuts.iterator()
|
||||
while (it.hasNext) {
|
||||
val out = it.next()
|
||||
// drop the out -> in entry; the result is checked for null below, which presumably
// means that this outlet had no downstream entry
val in = downstreams.remove(out)
|
||||
// also drop the reverse in -> out entry so that both wiring maps stay in sync
if (in != null) upstreams.remove(in)
|
||||
}
|
||||
}
|
||||
|
||||
def dump(): Unit = {
|
||||
println("StructuralInfo:")
|
||||
println(" newIns:")
|
||||
|
|
@ -467,13 +595,17 @@ private[stream] object Fusing {
|
|||
/**
|
||||
* Add a module to the given group, performing normalization (i.e. giving it a unique port identity).
|
||||
*/
|
||||
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String): Atomic = {
|
||||
val copy = CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
|
||||
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(m.shape)}")
|
||||
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String,
|
||||
_oldShape: Shape = null): Atomic = {
|
||||
val copy =
|
||||
if (_oldShape == null) CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
|
||||
else m
|
||||
val oldShape = if (_oldShape == null) m.shape else _oldShape
|
||||
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}")
|
||||
group.add(copy)
|
||||
modules.add(copy)
|
||||
copy.shape.outlets.foreach(o ⇒ outGroup.put(o, group))
|
||||
val orig1 = m.shape.inlets.iterator
|
||||
val orig1 = oldShape.inlets.iterator
|
||||
val mapd1 = copy.shape.inlets.iterator
|
||||
while (orig1.hasNext) {
|
||||
val orig = orig1.next()
|
||||
|
|
@ -481,7 +613,7 @@ private[stream] object Fusing {
|
|||
addMapping(orig, mapd, newIns)
|
||||
inOwners.put(mapd, copy)
|
||||
}
|
||||
val orig2 = m.shape.outlets.iterator
|
||||
val orig2 = oldShape.outlets.iterator
|
||||
val mapd2 = copy.shape.outlets.iterator
|
||||
while (orig2.hasNext) {
|
||||
val orig = orig2.next()
|
||||
|
|
@ -489,8 +621,24 @@ private[stream] object Fusing {
|
|||
addMapping(orig, mapd, newOuts)
|
||||
outOwners.put(mapd, copy)
|
||||
}
|
||||
copy.copyOf match {
|
||||
case GraphStageModule(_, _, _: MaterializedValueSource[_]) ⇒ pushMatSrc(copy)
|
||||
/*
|
||||
* In descend() we add internalOuts entries for all shapes that belong to stages that
|
||||
* are part of a GraphModule that is not dissolved. This includes the exposed Outlets,
|
||||
* which of course are external and thus need to be removed again from the internalOuts
|
||||
* set.
|
||||
*/
|
||||
if (m.isInstanceOf[GraphModule]) internalOuts.removeAll(m.shape.outlets.asJava)
|
||||
copy match {
|
||||
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c)
|
||||
case GraphModule(_, _, _, mvids) ⇒
|
||||
var i = 0
|
||||
while (i < mvids.length) {
|
||||
mvids(i) match {
|
||||
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c)
|
||||
case _ ⇒
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
case _ ⇒
|
||||
}
|
||||
Atomic(copy)
|
||||
|
|
@ -502,8 +650,8 @@ private[stream] object Fusing {
|
|||
*/
|
||||
def wire(out: OutPort, in: InPort, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})")
|
||||
val newOut = removeMapping(out, newOuts) nonNull out.toString
|
||||
val newIn = removeMapping(in, newIns) nonNull in.toString
|
||||
val newOut = removeMapping(out, newOuts) nonNull s"$out (${hash(out)})"
|
||||
val newIn = removeMapping(in, newIns) nonNull s"$in (${hash(in)})"
|
||||
downstreams.put(newOut, newIn)
|
||||
upstreams.put(newIn, newOut)
|
||||
}
|
||||
|
|
@ -514,10 +662,10 @@ private[stream] object Fusing {
|
|||
def rewire(oldShape: Shape, newShape: Shape, indent: String): Unit = {
|
||||
if (Debug) println(indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}")
|
||||
oldShape.inlets.iterator.zip(newShape.inlets.iterator).foreach {
|
||||
case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull oldIn.toString, newIns)
|
||||
case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull s"$oldIn (${hash(oldIn)})", newIns)
|
||||
}
|
||||
oldShape.outlets.iterator.zip(newShape.outlets.iterator).foreach {
|
||||
case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull oldOut.toString, newOuts)
|
||||
case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull s"$oldOut (${hash(oldOut)})", newOuts)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -534,6 +682,9 @@ private[stream] object Fusing {
|
|||
old.map(o ⇒ newOuts.get(o).head.outlet)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine whether the given CopiedModule has an AsyncBoundary attribute.
|
||||
*/
|
||||
private def isAsync(m: Module): Boolean = m match {
|
||||
case CopiedModule(_, inherited, orig) ⇒
|
||||
val attr = inherited and orig.attributes
|
||||
|
|
@ -550,6 +701,9 @@ private[stream] object Fusing {
|
|||
case x ⇒ x.attributes.get[ActorAttributes.Dispatcher]
|
||||
}
|
||||
|
||||
/**
|
||||
* See through copied modules to the “real” module.
|
||||
*/
|
||||
private def realModule(m: Module): Module = m match {
|
||||
case CopiedModule(_, _, of) ⇒ realModule(of)
|
||||
case other ⇒ other
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue