make fused graphs fusable

This commit is contained in:
Roland Kuhn 2015-12-15 16:44:48 +01:00
parent 3d20915cf4
commit e4f31b66c3
7 changed files with 308 additions and 95 deletions

View file

@ -55,6 +55,7 @@ private[stream] object Fusing {
* Perform the fusing of `struct.groups` into GraphModules (leaving them
* as they are for non-fusable modules).
*/
struct.removeInternalWires()
struct.breakUpGroupsByDispatcher()
val modules = fuse(struct)
/*
@ -65,7 +66,7 @@ private[stream] object Fusing {
shape,
immutable.Map.empty ++ struct.downstreams.asScala,
immutable.Map.empty ++ struct.upstreams.asScala,
matValue,
matValue.head._2,
Attributes.none,
info)
@ -244,26 +245,102 @@ private[stream] object Fusing {
inheritedAttributes: Attributes,
struct: BuildStructuralInfo,
openGroup: ju.Set[Module],
indent: String): MaterializedValueNode = {
indent: String): List[(Module, MaterializedValueNode)] = {
def log(msg: String): Unit = println(indent + msg)
val async = m match {
case _: GraphStageModule m.attributes.contains(AsyncBoundary)
case _: GraphModule m.attributes.contains(AsyncBoundary)
case _ if m.isAtomic true // non-GraphStage atomic or has AsyncBoundary
case _ m.attributes.contains(AsyncBoundary)
}
if (Debug) log(s"entering ${m.getClass} (async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
if (Debug) log(s"entering ${m.getClass} (hash=${m.hashCode}, async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})")
val localGroup =
if (async) struct.newGroup(indent)
else openGroup
if (m.isAtomic) {
if (Debug) log(s"atomic module $m")
struct.addModule(m, localGroup, inheritedAttributes, indent)
m match {
case gm: GraphModule if !async
// need to dissolve previously fused GraphStages to allow further fusion
if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + indent)}")
val attributes = inheritedAttributes and m.attributes
gm.matValIDs.flatMap(sub descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
case gm @ GraphModule(_, oldShape, _, mvids)
/*
* Importing a GraphModule that has an AsyncBoundary attribute is a little more work:
*
* - we need to copy all the CopiedModules that are in matValIDs
* - we need to rewrite the corresponding MaterializedValueNodes
* - we need to match up the new (copied) GraphModule shape with the individual Shape copies
* - we need to register the contained modules but take care to not include the internal
* wirings into the final result, see also `struct.removeInternalWires()`
*/
if (Debug) log(s"graph module ${m.toString.replace("\n", "\n" + indent)}")
// storing the old Shape in arrays for in-place updating as we clone the contained GraphStages
val oldIns = oldShape.inlets.toArray
val oldOuts = oldShape.outlets.toArray
val newids = mvids.map {
case CopiedModule(shape, attr, copyOf)
val newShape = shape.deepCopy
val copy = CopiedModule(newShape, attr, copyOf): Module
// rewrite shape: first the inlets
val oldIn = shape.inlets.iterator
val newIn = newShape.inlets.iterator
while (oldIn.hasNext) {
val o = oldIn.next()
val n = newIn.next()
findInArray(o, oldIns) match {
case -1 // nothing to do
case idx oldIns(idx) = n
}
}
// ... then the outlets
val oldOut = shape.outlets.iterator
val newOut = newShape.outlets.iterator
while (oldOut.hasNext) {
val o = oldOut.next()
val n = newOut.next()
findInArray(o, oldOuts) match {
case -1 // nothing to do
case idx oldOuts(idx) = n
}
}
// need to add the module so that the structural (internal) wirings can be rewritten as well
// but these modules must not be added to any of the groups
struct.addModule(copy, new ju.HashSet, inheritedAttributes, indent, shape)
struct.registerInteral(newShape, indent)
copy
}
val newgm = gm.copy(shape = oldShape.copyFromPorts(oldIns.toList, oldOuts.toList), matValIDs = newids)
// make sure to add all the port mappings from old GraphModule Shape to new shape
struct.addModule(newgm, localGroup, inheritedAttributes, indent, _oldShape = oldShape)
// now compute the list of all materialized value computation updates
var result = List.empty[(Module, MaterializedValueNode)]
var i = 0
while (i < mvids.length) {
result ::= mvids(i) -> Atomic(newids(i))
i += 1
}
result ::= m -> Atomic(newgm)
result
case _
if (Debug) log(s"atomic module $m")
List(m -> struct.addModule(m, localGroup, inheritedAttributes, indent))
}
} else {
val attributes = inheritedAttributes and m.attributes
m match {
case CopiedModule(shape, _, copyOf)
val ret = descend(copyOf, attributes, struct, localGroup, indent + " ")
val ret =
descend(copyOf, attributes, struct, localGroup, indent + " ") match {
case xs @ (_, mat) :: _ (m -> mat) :: xs
case _ throw new IllegalArgumentException("cannot happen")
}
struct.rewire(copyOf.shape, shape, indent)
ret
case _
@ -272,15 +349,21 @@ private[stream] object Fusing {
struct.enterMatCtx()
// now descend into submodules and collect their computations (plus updates to `struct`)
val subMat: Predef.Map[Module, MaterializedValueNode] =
m.subModules.map(sub sub -> descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
m.subModules.flatMap(sub descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut)
if (Debug) log(subMat.map(p s"${p._1.getClass.getName}[${p._1.hashCode}] -> ${p._2}").mkString("subMat\n " + indent, "\n " + indent, ""))
// we need to remove all wirings that this module copied from nested modules so that we
// dont do wirings twice
val down = m.subModules.foldLeft(m.downstreams.toSet)((set, m) set -- m.downstreams)
val oldDownstreams = m match {
case f: FusedModule f.info.downstreams.toSet
case _ m.downstreams.toSet
}
val down = m.subModules.foldLeft(oldDownstreams)((set, m) set -- m.downstreams)
down.foreach {
case (start, end) struct.wire(start, end, indent)
}
// now rewrite the materialized value computation based on the copied modules and their computation nodes
val newMat = rewriteMat(subMat, m.materializedValueComputation)
val matNodeMapping: ju.Map[MaterializedValueNode, MaterializedValueNode] = new ju.HashMap
val newMat = rewriteMat(subMat, m.materializedValueComputation, matNodeMapping)
// and finally rewire all MaterializedValueSources to their new computation nodes
val matSrcs = struct.exitMatCtx()
matSrcs.foreach { c
@ -288,33 +371,49 @@ private[stream] object Fusing {
val ms = c.copyOf match {
case g: GraphStageModule g.stage.asInstanceOf[MaterializedValueSource[Any]]
}
if (Debug) require(find(ms.computation, m.materializedValueComputation), s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
val replacement = CopiedModule(c.shape, c.attributes, new MaterializedValueSource[Any](newMat, ms.out).module)
val mapped = ms.computation match {
case Atomic(sub) subMat(sub)
case other matNodeMapping.get(other)
}
require(mapped != null, s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}")
val newSrc = new MaterializedValueSource[Any](mapped, ms.out)
val replacement = CopiedModule(c.shape, c.attributes, newSrc.module)
struct.replace(c, replacement, localGroup)
}
// the result for each level is the materialized value computation
newMat
List(m -> newMat)
}
}
}
private def find(node: MaterializedValueNode, inTree: MaterializedValueNode): Boolean =
if (node == inTree) true
else
inTree match {
case Atomic(_) false
case Ignore false
case Transform(_, dep) find(node, dep)
case Combine(_, left, right) find(node, left) || find(node, right)
}
@tailrec
private def findInArray[T](elem: T, arr: Array[T], idx: Int = 0): Int =
if (idx >= arr.length) -1
else if (arr(idx) == elem) idx
else findInArray(elem, arr, idx + 1)
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode],
mat: MaterializedValueNode): MaterializedValueNode =
/**
* Given a mapping from old modules to new MaterializedValueNode, rewrite the given
* computation while also populating a mapping from old computation nodes to new ones.
* That mapping is needed to rewrite the MaterializedValueSource stages later-on in
* descend().
*/
private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode], mat: MaterializedValueNode,
mapping: ju.Map[MaterializedValueNode, MaterializedValueNode]): MaterializedValueNode =
mat match {
case Atomic(sub) subMat(sub)
case Combine(f, left, right) Combine(f, rewriteMat(subMat, left), rewriteMat(subMat, right))
case Transform(f, dep) Transform(f, rewriteMat(subMat, dep))
case Ignore Ignore
case Atomic(sub)
val ret = subMat(sub)
mapping.put(mat, ret)
ret
case Combine(f, left, right)
val ret = Combine(f, rewriteMat(subMat, left, mapping), rewriteMat(subMat, right, mapping))
mapping.put(mat, ret)
ret
case Transform(f, dep)
val ret = Transform(f, rewriteMat(subMat, dep, mapping))
mapping.put(mat, ret)
ret
case Ignore Ignore
}
private implicit class NonNull[T](val x: T) extends AnyVal {
@ -335,7 +434,8 @@ private[stream] object Fusing {
immutable.Map.empty ++ upstreams.asScala,
immutable.Map.empty ++ downstreams.asScala,
immutable.Map.empty ++ inOwners.asScala,
immutable.Map.empty ++ outOwners.asScala)
immutable.Map.empty ++ outOwners.asScala,
Set.empty ++ modules.asScala)
/**
* the set of all contained modules
@ -443,6 +543,34 @@ private[stream] object Fusing {
*/
val outOwners: ju.Map[OutPort, Module] = new ju.HashMap
/**
* List of internal wirings of GraphModules that were incorporated.
*/
val internalOuts: ju.Set[OutPort] = new ju.HashSet
/**
* Register the outlets of the given Shape as sources for internal
* connections within imported (and not dissolved) GraphModules.
* See also the comment in addModule where this is partially undone.
*/
def registerInteral(s: Shape, indent: String): Unit = {
if (Debug) println(indent + s"registerInternals(${s.outlets.map(hash)})")
internalOuts.addAll(s.outlets.asJava)
}
/**
* Remove wirings that belong to the fused stages contained in GraphModules
* that were incorporated in this fusing run.
*/
def removeInternalWires(): Unit = {
val it = internalOuts.iterator()
while (it.hasNext) {
val out = it.next()
val in = downstreams.remove(out)
if (in != null) upstreams.remove(in)
}
}
def dump(): Unit = {
println("StructuralInfo:")
println(" newIns:")
@ -467,13 +595,17 @@ private[stream] object Fusing {
/**
* Add a module to the given group, performing normalization (i.e. giving it a unique port identity).
*/
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String): Atomic = {
val copy = CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(m.shape)}")
def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String,
_oldShape: Shape = null): Atomic = {
val copy =
if (_oldShape == null) CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m))
else m
val oldShape = if (_oldShape == null) m.shape else _oldShape
if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}")
group.add(copy)
modules.add(copy)
copy.shape.outlets.foreach(o outGroup.put(o, group))
val orig1 = m.shape.inlets.iterator
val orig1 = oldShape.inlets.iterator
val mapd1 = copy.shape.inlets.iterator
while (orig1.hasNext) {
val orig = orig1.next()
@ -481,7 +613,7 @@ private[stream] object Fusing {
addMapping(orig, mapd, newIns)
inOwners.put(mapd, copy)
}
val orig2 = m.shape.outlets.iterator
val orig2 = oldShape.outlets.iterator
val mapd2 = copy.shape.outlets.iterator
while (orig2.hasNext) {
val orig = orig2.next()
@ -489,8 +621,24 @@ private[stream] object Fusing {
addMapping(orig, mapd, newOuts)
outOwners.put(mapd, copy)
}
copy.copyOf match {
case GraphStageModule(_, _, _: MaterializedValueSource[_]) pushMatSrc(copy)
/*
* In descend() we add internalOuts entries for all shapes that belong to stages that
* are part of a GraphModule that is not dissolved. This includes the exposed Outlets,
* which of course are external and thus need to be removed again from the internalOuts
* set.
*/
if (m.isInstanceOf[GraphModule]) internalOuts.removeAll(m.shape.outlets.asJava)
copy match {
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) pushMatSrc(c)
case GraphModule(_, _, _, mvids)
var i = 0
while (i < mvids.length) {
mvids(i) match {
case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) pushMatSrc(c)
case _
}
i += 1
}
case _
}
Atomic(copy)
@ -502,8 +650,8 @@ private[stream] object Fusing {
*/
def wire(out: OutPort, in: InPort, indent: String): Unit = {
if (Debug) println(indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})")
val newOut = removeMapping(out, newOuts) nonNull out.toString
val newIn = removeMapping(in, newIns) nonNull in.toString
val newOut = removeMapping(out, newOuts) nonNull s"$out (${hash(out)})"
val newIn = removeMapping(in, newIns) nonNull s"$in (${hash(in)})"
downstreams.put(newOut, newIn)
upstreams.put(newIn, newOut)
}
@ -514,10 +662,10 @@ private[stream] object Fusing {
def rewire(oldShape: Shape, newShape: Shape, indent: String): Unit = {
if (Debug) println(indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}")
oldShape.inlets.iterator.zip(newShape.inlets.iterator).foreach {
case (oldIn, newIn) addMapping(newIn, removeMapping(oldIn, newIns) nonNull oldIn.toString, newIns)
case (oldIn, newIn) addMapping(newIn, removeMapping(oldIn, newIns) nonNull s"$oldIn (${hash(oldIn)})", newIns)
}
oldShape.outlets.iterator.zip(newShape.outlets.iterator).foreach {
case (oldOut, newOut) addMapping(newOut, removeMapping(oldOut, newOuts) nonNull oldOut.toString, newOuts)
case (oldOut, newOut) addMapping(newOut, removeMapping(oldOut, newOuts) nonNull s"$oldOut (${hash(oldOut)})", newOuts)
}
}
@ -534,6 +682,9 @@ private[stream] object Fusing {
old.map(o newOuts.get(o).head.outlet)
}
/**
* Determine whether the given CopiedModule has an AsyncBoundary attribute.
*/
private def isAsync(m: Module): Boolean = m match {
case CopiedModule(_, inherited, orig)
val attr = inherited and orig.attributes
@ -550,6 +701,9 @@ private[stream] object Fusing {
case x x.attributes.get[ActorAttributes.Dispatcher]
}
/**
* See through copied modules to the real module.
*/
private def realModule(m: Module): Module = m match {
case CopiedModule(_, _, of) realModule(of)
case other other