diff --git a/akka-stream-tests/src/test/scala/akka/stream/FusingSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/FusingSpec.scala index 64134a7682..d8668937eb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/FusingSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/FusingSpec.scala @@ -7,32 +7,82 @@ import akka.stream._ import akka.stream.scaladsl._ import akka.stream.testkit.AkkaSpec import org.scalactic.ConversionCheckedTripleEquals +import akka.stream.Attributes._ +import akka.stream.Fusing.FusedGraph +import scala.annotation.tailrec +import akka.stream.impl.StreamLayout.Module class FusingSpec extends AkkaSpec with ConversionCheckedTripleEquals { + final val Debug = false implicit val materializer = ActorMaterializer() + def graph(async: Boolean) = + Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1) + .alsoTo(Flow[Int].fold(0)(_ + _).to(Sink.head.named("otherSink")).addAttributes(if (async) Attributes.asyncBoundary else Attributes.none)) + .via(Flow[Int].fold(1)(_ + _).named("mainSink")) + + def singlePath[S <: Shape, M](fg: FusedGraph[S, M], from: Attribute, to: Attribute): Unit = { + val starts = fg.module.info.allModules.filter(_.attributes.contains(from)) + starts.size should ===(1) + val start = starts.head + val ups = fg.module.info.upstreams + val owner = fg.module.info.outOwners + + @tailrec def rec(curr: Module): Unit = { + if (Debug) println(extractName(curr, "unknown")) + if (curr.attributes.contains(to)) () // done + else { + val outs = curr.inPorts.map(ups) + outs.size should ===(1) + val out = outs.head + val next = owner(out) + rec(next) + } + } + + rec(start) + } + "Fusing" must { - "fuse a moderately complex graph" in { - val g = Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1).alsoTo(Sink.fold(0)(_ + _)).to(Sink.fold(1)(_ + _)) - val fused = Fusing.aggressive(g) + def verify[S <: Shape, M](fused: FusedGraph[S, M], modules: Int, downstreams: Int): Unit = { val module = fused.module - module.subModules.size should ===(1) - module.info.downstreams.size should be > 5 - module.info.upstreams.size should be > 5 + module.subModules.size should ===(modules) + module.downstreams.size should ===(modules - 1) + module.info.downstreams.size should be >= downstreams + module.info.upstreams.size should be >= downstreams + singlePath(fused, Attributes.Name("mainSink"), Attributes.Name("unfoldInf")) + singlePath(fused, Attributes.Name("otherSink"), Attributes.Name("unfoldInf")) + } + + "fuse a moderately complex graph" in { + val g = graph(false) + val fused = Fusing.aggressive(g) + verify(fused, modules = 1, downstreams = 5) } "not fuse across AsyncBoundary" in { - val g = - Source.unfoldInf(1)(x ⇒ (x, x)).filter(_ % 2 == 1) - .alsoTo(Sink.fold(0)(_ + (_: Int)).addAttributes(Attributes.asyncBoundary)) - .to(Sink.fold(1)(_ + _)) + val g = graph(true) val fused = Fusing.aggressive(g) - val module = fused.module - module.subModules.size should ===(2) - module.info.downstreams.size should be > 5 - module.info.upstreams.size should be > 5 + verify(fused, modules = 2, downstreams = 5) + } + + "not fuse a FusedGraph again" in { + val g = Fusing.aggressive(graph(false)) + Fusing.aggressive(g) should be theSameInstanceAs g + } + + "properly fuse a FusedGraph that has been extended (no AsyncBoundary)" in { + val src = Fusing.aggressive(graph(false)) + val fused = Fusing.aggressive(Source.fromGraph(src).to(Sink.head)) + verify(fused, modules = 1, downstreams = 6) + } + + "properly fuse a FusedGraph that has been extended (with AsyncBoundary)" in { + val src = Fusing.aggressive(graph(true)) + val fused = Fusing.aggressive(Source.fromGraph(src).to(Sink.head)) + verify(fused, modules = 2, downstreams = 6) } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala index 9e8ce95a7e..36cab1a259 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala @@ -5,6 +5,7 @@ package akka.stream.impl import akka.stream.testkit.AkkaSpec import akka.stream._ +import akka.stream.Fusing.aggressive import akka.stream.scaladsl._ import akka.stream.stage._ import akka.stream.testkit.Utils.assertAllStagesStopped @@ -66,12 +67,6 @@ class GraphStageLogicSpec extends AkkaSpec with GraphInterpreterSpecKit with Con } } - class FusedGraph[S <: Shape](ga: GraphAssembly, s: S, a: Attributes = Attributes.none) extends Graph[S, Unit] { - override def shape = s - override val module = GraphModule(ga, s, a, ga.stages.map(_.module)) - override def withAttributes(attr: Attributes) = new FusedGraph(ga, s, attr) - } - "A GraphStageLogic" must { "emit all things before completing" in assertAllStagesStopped { @@ -84,38 +79,21 @@ class GraphStageLogicSpec extends AkkaSpec with GraphInterpreterSpecKit with Con } "emit all things before completing with two fused stages" in assertAllStagesStopped { - new Builder { - val g = new FusedGraph( - builder(emit1234, emit5678) - .connect(Upstream, emit1234.in) - .connect(emit1234.out, emit5678.in) - .connect(emit5678.out, Downstream) - .buildAssembly(), - FlowShape(emit1234.in, emit5678.out)) + val g = aggressive(Flow[Int].via(emit1234).via(emit5678)) - Source.empty.via(g).runWith(TestSink.probe) - .request(9) - .expectNextN(1 to 8) - .expectComplete() - } + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() } "emit all things before completing with three fused stages" in assertAllStagesStopped { - new Builder { - val g = new FusedGraph( - builder(emit1234, passThrough, emit5678) - .connect(Upstream, emit1234.in) - .connect(emit1234.out, passThrough.in) - .connect(passThrough.out, emit5678.in) - .connect(emit5678.out, Downstream) - .buildAssembly(), - FlowShape(emit1234.in, emit5678.out)) + val g = aggressive(Flow[Int].via(emit1234).via(passThrough).via(emit5678)) - Source.empty.via(g).runWith(TestSink.probe) - .request(9) - .expectNextN(1 to 8) - .expectComplete() - } + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() } "invoke lifecycle hooks in the right order" in assertAllStagesStopped { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala index 700f585122..bb0920b2b1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMatValueSpec.scala @@ -3,7 +3,7 @@ */ package akka.stream.scaladsl -import akka.stream.{ ClosedShape, SourceShape, ActorMaterializer, ActorMaterializerSettings } +import akka.stream._ import akka.stream.testkit._ import scala.concurrent.Await @@ -106,5 +106,15 @@ class GraphMatValueSpec extends AkkaSpec { } + "work also when the source’s module is copied" in { + val foldFlow: Flow[Int, Int, Future[Int]] = Flow.fromGraph(GraphDSL.create(Sink.fold[Int, Int](0)(_ + _)) { + implicit builder ⇒ + fold ⇒ + FlowShape(fold.inlet, builder.materializedValue.mapAsync(4)(identity).outlet) + }) + + Await.result(Source(1 to 10).via(foldFlow).runWith(Sink.head), 3.seconds) should ===(55) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index e2972a2c77..6f352746fa 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -7,6 +7,7 @@ import akka.event.Logging import scala.annotation.tailrec import scala.reflect.{ classTag, ClassTag } import akka.japi.function +import akka.stream.impl.StreamLayout._ /** * Holds attributes which can be used to alter [[akka.stream.scaladsl.Flow]] / [[akka.stream.javadsl.Flow]] @@ -221,6 +222,16 @@ object Attributes { def logLevels(onElement: Logging.LogLevel = Logging.DebugLevel, onFinish: Logging.LogLevel = Logging.DebugLevel, onFailure: Logging.LogLevel = Logging.ErrorLevel) = Attributes(LogLevels(onElement, onFinish, onFailure)) + /** + * Compute a name by concatenating all Name attributes that the given module + * has, returning the given default value if none are found. + */ + def extractName(mod: Module, default: String): String = { + mod match { + case CopiedModule(_, attr, copyOf) ⇒ (attr and copyOf.attributes).nameOrDefault(default) + case _ ⇒ mod.attributes.nameOrDefault(default) + } + } } /** diff --git a/akka-stream/src/main/scala/akka/stream/Fusing.scala b/akka-stream/src/main/scala/akka/stream/Fusing.scala index e84dfaece3..fb676b6ba8 100644 --- a/akka-stream/src/main/scala/akka/stream/Fusing.scala +++ b/akka-stream/src/main/scala/akka/stream/Fusing.scala @@ -7,6 +7,8 @@ import java.{ util ⇒ ju } import scala.collection.immutable import scala.collection.JavaConverters._ import akka.stream.impl.StreamLayout._ +import akka.stream.impl.fusing.{ Fusing ⇒ Impl } +import scala.annotation.unchecked.uncheckedVariance /** * This class holds some graph transformation functions that can fuse together @@ -32,15 +34,19 @@ object Fusing { * via [[akka.stream.Attributes#AsyncBoundary]]. */ def aggressive[S <: Shape, M](g: Graph[S, M]): FusedGraph[S, M] = - akka.stream.impl.fusing.Fusing.aggressive(g) + g match { + case fg: FusedGraph[_, _] ⇒ fg + case _ ⇒ Impl.aggressive(g) + } /** * A fused graph of the right shape, containing a [[FusedModule]] which * holds more information on the operation structure of the contained stream * topology for convenient graph traversal. */ - case class FusedGraph[S <: Shape, M](override val module: FusedModule, - override val shape: S) extends Graph[S, M] { + case class FusedGraph[+S <: Shape @uncheckedVariance, +M](override val module: FusedModule, + override val shape: S) extends Graph[S, M] { + // the @uncheckedVariance look like a compiler bug ... why does it work in Graph but not here? override def withAttributes(attr: Attributes) = copy(module = module.withAttributes(attr)) } @@ -54,6 +60,7 @@ object Fusing { final case class StructuralInfo(upstreams: immutable.Map[InPort, OutPort], downstreams: immutable.Map[OutPort, InPort], inOwners: immutable.Map[InPort, Module], - outOwners: immutable.Map[OutPort, Module]) + outOwners: immutable.Map[OutPort, Module], + allModules: Set[Module]) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index 0474aa89fa..b3766d87ec 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -24,7 +24,7 @@ import akka.stream.impl.fusing.GraphModule /** * INTERNAL API */ -private[akka] object StreamLayout { +object StreamLayout { // compile-time constant final val Debug = false @@ -122,7 +122,7 @@ private[akka] object StreamLayout { override def toString: String = s"Combine($dep1,$dep2)" } case class Atomic(module: Module) extends MaterializedValueNode { - override def toString: String = s"Atomic(${module.attributes.nameOrDefault(module.getClass.getName)})" + override def toString: String = s"Atomic(${module.attributes.nameOrDefault(module.getClass.getName)}[${module.hashCode}])" } case class Transform(f: Any ⇒ Any, dep: MaterializedValueNode) extends MaterializedValueNode { override def toString: String = s"Transform($dep)" @@ -148,6 +148,7 @@ private[akka] object StreamLayout { final def isBidiFlow: Boolean = (inPorts.size == 2) && (outPorts.size == 2) def isAtomic: Boolean = subModules.isEmpty def isCopied: Boolean = false + def isFused: Boolean = false /** * Fuses this Module to `that` Module by wiring together `from` and `to`, @@ -192,7 +193,7 @@ private[akka] object StreamLayout { else s"The input port [$to] is not part of the underlying graph.") CompositeModule( - subModules, + if (isSealed) Set(this) else subModules, AmorphousShape(shape.inlets.filterNot(_ == to), shape.outlets.filterNot(_ == from)), downstreams.updated(from, to), upstreams.updated(to, from), @@ -314,7 +315,7 @@ private[akka] object StreamLayout { } def subModules: Set[Module] - final def isSealed: Boolean = isAtomic || isCopied + final def isSealed: Boolean = isAtomic || isCopied || isFused def downstreams: Map[OutPort, InPort] = Map.empty def upstreams: Map[InPort, OutPort] = Map.empty @@ -411,6 +412,8 @@ private[akka] object StreamLayout { override val attributes: Attributes, info: Fusing.StructuralInfo) extends Module { + override def isFused: Boolean = true + override def replaceShape(s: Shape): Module = { shape.requireSamePortsAs(s) copy(shape = s) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala index d8905f982f..5032b13fb1 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Fusing.scala @@ -55,6 +55,7 @@ private[stream] object Fusing { * Perform the fusing of `struct.groups` into GraphModules (leaving them * as they are for non-fusable modules). */ + struct.removeInternalWires() struct.breakUpGroupsByDispatcher() val modules = fuse(struct) /* @@ -65,7 +66,7 @@ private[stream] object Fusing { shape, immutable.Map.empty ++ struct.downstreams.asScala, immutable.Map.empty ++ struct.upstreams.asScala, - matValue, + matValue.head._2, Attributes.none, info) @@ -244,26 +245,102 @@ private[stream] object Fusing { inheritedAttributes: Attributes, struct: BuildStructuralInfo, openGroup: ju.Set[Module], - indent: String): MaterializedValueNode = { + indent: String): List[(Module, MaterializedValueNode)] = { def log(msg: String): Unit = println(indent + msg) val async = m match { case _: GraphStageModule ⇒ m.attributes.contains(AsyncBoundary) + case _: GraphModule ⇒ m.attributes.contains(AsyncBoundary) case _ if m.isAtomic ⇒ true // non-GraphStage atomic or has AsyncBoundary case _ ⇒ m.attributes.contains(AsyncBoundary) } - if (Debug) log(s"entering ${m.getClass} (async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})") + if (Debug) log(s"entering ${m.getClass} (hash=${m.hashCode}, async=$async, name=${m.attributes.nameLifted}, dispatcher=${dispatcher(m)})") val localGroup = if (async) struct.newGroup(indent) else openGroup if (m.isAtomic) { - if (Debug) log(s"atomic module $m") - struct.addModule(m, localGroup, inheritedAttributes, indent) + m match { + case gm: GraphModule if !async ⇒ + // need to dissolve previously fused GraphStages to allow further fusion + if (Debug) log(s"dissolving graph module ${m.toString.replace("\n", "\n" + indent)}") + val attributes = inheritedAttributes and m.attributes + gm.matValIDs.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut) + case gm @ GraphModule(_, oldShape, _, mvids) ⇒ + /* + * Importing a GraphModule that has an AsyncBoundary attribute is a little more work: + * + * - we need to copy all the CopiedModules that are in matValIDs + * - we need to rewrite the corresponding MaterializedValueNodes + * - we need to match up the new (copied) GraphModule shape with the individual Shape copies + * - we need to register the contained modules but take care to not include the internal + * wirings into the final result, see also `struct.removeInternalWires()` + */ + if (Debug) log(s"graph module ${m.toString.replace("\n", "\n" + indent)}") + + // storing the old Shape in arrays for in-place updating as we clone the contained GraphStages + val oldIns = oldShape.inlets.toArray + val oldOuts = oldShape.outlets.toArray + + val newids = mvids.map { + case CopiedModule(shape, attr, copyOf) ⇒ + val newShape = shape.deepCopy + val copy = CopiedModule(newShape, attr, copyOf): Module + + // rewrite shape: first the inlets + val oldIn = shape.inlets.iterator + val newIn = newShape.inlets.iterator + while (oldIn.hasNext) { + val o = oldIn.next() + val n = newIn.next() + findInArray(o, oldIns) match { + case -1 ⇒ // nothing to do + case idx ⇒ oldIns(idx) = n + } + } + // ... then the outlets + val oldOut = shape.outlets.iterator + val newOut = newShape.outlets.iterator + while (oldOut.hasNext) { + val o = oldOut.next() + val n = newOut.next() + findInArray(o, oldOuts) match { + case -1 ⇒ // nothing to do + case idx ⇒ oldOuts(idx) = n + } + } + + // need to add the module so that the structural (internal) wirings can be rewritten as well + // but these modules must not be added to any of the groups + struct.addModule(copy, new ju.HashSet, inheritedAttributes, indent, shape) + struct.registerInteral(newShape, indent) + + copy + } + val newgm = gm.copy(shape = oldShape.copyFromPorts(oldIns.toList, oldOuts.toList), matValIDs = newids) + // make sure to add all the port mappings from old GraphModule Shape to new shape + struct.addModule(newgm, localGroup, inheritedAttributes, indent, _oldShape = oldShape) + // now compute the list of all materialized value computation updates + var result = List.empty[(Module, MaterializedValueNode)] + var i = 0 + while (i < mvids.length) { + result ::= mvids(i) -> Atomic(newids(i)) + i += 1 + } + result ::= m -> Atomic(newgm) + result + case _ ⇒ + if (Debug) log(s"atomic module $m") + List(m -> struct.addModule(m, localGroup, inheritedAttributes, indent)) + } } else { val attributes = inheritedAttributes and m.attributes m match { case CopiedModule(shape, _, copyOf) ⇒ - val ret = descend(copyOf, attributes, struct, localGroup, indent + " ") + val ret = + descend(copyOf, attributes, struct, localGroup, indent + " ") match { + case xs @ (_, mat) :: _ ⇒ (m -> mat) :: xs + case _ ⇒ throw new IllegalArgumentException("cannot happen") + } struct.rewire(copyOf.shape, shape, indent) ret case _ ⇒ @@ -272,15 +349,21 @@ private[stream] object Fusing { struct.enterMatCtx() // now descend into submodules and collect their computations (plus updates to `struct`) val subMat: Predef.Map[Module, MaterializedValueNode] = - m.subModules.map(sub ⇒ sub -> descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut) + m.subModules.flatMap(sub ⇒ descend(sub, attributes, struct, localGroup, indent + " "))(collection.breakOut) + if (Debug) log(subMat.map(p ⇒ s"${p._1.getClass.getName}[${p._1.hashCode}] -> ${p._2}").mkString("subMat\n " + indent, "\n " + indent, "")) // we need to remove all wirings that this module copied from nested modules so that we // don’t do wirings twice - val down = m.subModules.foldLeft(m.downstreams.toSet)((set, m) ⇒ set -- m.downstreams) + val oldDownstreams = m match { + case f: FusedModule ⇒ f.info.downstreams.toSet + case _ ⇒ m.downstreams.toSet + } + val down = m.subModules.foldLeft(oldDownstreams)((set, m) ⇒ set -- m.downstreams) down.foreach { case (start, end) ⇒ struct.wire(start, end, indent) } // now rewrite the materialized value computation based on the copied modules and their computation nodes - val newMat = rewriteMat(subMat, m.materializedValueComputation) + val matNodeMapping: ju.Map[MaterializedValueNode, MaterializedValueNode] = new ju.HashMap + val newMat = rewriteMat(subMat, m.materializedValueComputation, matNodeMapping) // and finally rewire all MaterializedValueSources to their new computation nodes val matSrcs = struct.exitMatCtx() matSrcs.foreach { c ⇒ @@ -288,33 +371,49 @@ private[stream] object Fusing { val ms = c.copyOf match { case g: GraphStageModule ⇒ g.stage.asInstanceOf[MaterializedValueSource[Any]] } - if (Debug) require(find(ms.computation, m.materializedValueComputation), s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}") - val replacement = CopiedModule(c.shape, c.attributes, new MaterializedValueSource[Any](newMat, ms.out).module) + val mapped = ms.computation match { + case Atomic(sub) ⇒ subMat(sub) + case other ⇒ matNodeMapping.get(other) + } + require(mapped != null, s"mismatch:\n ${ms.computation}\n ${m.materializedValueComputation}") + val newSrc = new MaterializedValueSource[Any](mapped, ms.out) + val replacement = CopiedModule(c.shape, c.attributes, newSrc.module) struct.replace(c, replacement, localGroup) } // the result for each level is the materialized value computation - newMat + List(m -> newMat) } } } - private def find(node: MaterializedValueNode, inTree: MaterializedValueNode): Boolean = - if (node == inTree) true - else - inTree match { - case Atomic(_) ⇒ false - case Ignore ⇒ false - case Transform(_, dep) ⇒ find(node, dep) - case Combine(_, left, right) ⇒ find(node, left) || find(node, right) - } + @tailrec + private def findInArray[T](elem: T, arr: Array[T], idx: Int = 0): Int = + if (idx >= arr.length) -1 + else if (arr(idx) == elem) idx + else findInArray(elem, arr, idx + 1) - private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode], - mat: MaterializedValueNode): MaterializedValueNode = + /** + * Given a mapping from old modules to new MaterializedValueNode, rewrite the given + * computation while also populating a mapping from old computation nodes to new ones. + * That mapping is needed to rewrite the MaterializedValueSource stages later-on in + * descend(). + */ + private def rewriteMat(subMat: Predef.Map[Module, MaterializedValueNode], mat: MaterializedValueNode, + mapping: ju.Map[MaterializedValueNode, MaterializedValueNode]): MaterializedValueNode = mat match { - case Atomic(sub) ⇒ subMat(sub) - case Combine(f, left, right) ⇒ Combine(f, rewriteMat(subMat, left), rewriteMat(subMat, right)) - case Transform(f, dep) ⇒ Transform(f, rewriteMat(subMat, dep)) - case Ignore ⇒ Ignore + case Atomic(sub) ⇒ + val ret = subMat(sub) + mapping.put(mat, ret) + ret + case Combine(f, left, right) ⇒ + val ret = Combine(f, rewriteMat(subMat, left, mapping), rewriteMat(subMat, right, mapping)) + mapping.put(mat, ret) + ret + case Transform(f, dep) ⇒ + val ret = Transform(f, rewriteMat(subMat, dep, mapping)) + mapping.put(mat, ret) + ret + case Ignore ⇒ Ignore } private implicit class NonNull[T](val x: T) extends AnyVal { @@ -335,7 +434,8 @@ private[stream] object Fusing { immutable.Map.empty ++ upstreams.asScala, immutable.Map.empty ++ downstreams.asScala, immutable.Map.empty ++ inOwners.asScala, - immutable.Map.empty ++ outOwners.asScala) + immutable.Map.empty ++ outOwners.asScala, + Set.empty ++ modules.asScala) /** * the set of all contained modules @@ -443,6 +543,34 @@ private[stream] object Fusing { */ val outOwners: ju.Map[OutPort, Module] = new ju.HashMap + /** + * List of internal wirings of GraphModules that were incorporated. + */ + val internalOuts: ju.Set[OutPort] = new ju.HashSet + + /** + * Register the outlets of the given Shape as sources for internal + * connections within imported (and not dissolved) GraphModules. + * See also the comment in addModule where this is partially undone. + */ + def registerInteral(s: Shape, indent: String): Unit = { + if (Debug) println(indent + s"registerInternals(${s.outlets.map(hash)})") + internalOuts.addAll(s.outlets.asJava) + } + + /** + * Remove wirings that belong to the fused stages contained in GraphModules + * that were incorporated in this fusing run. + */ + def removeInternalWires(): Unit = { + val it = internalOuts.iterator() + while (it.hasNext) { + val out = it.next() + val in = downstreams.remove(out) + if (in != null) upstreams.remove(in) + } + } + def dump(): Unit = { println("StructuralInfo:") println(" newIns:") @@ -467,13 +595,17 @@ private[stream] object Fusing { /** * Add a module to the given group, performing normalization (i.e. giving it a unique port identity). */ - def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String): Atomic = { - val copy = CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m)) - if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(m.shape)}") + def addModule(m: Module, group: ju.Set[Module], inheritedAttributes: Attributes, indent: String, + _oldShape: Shape = null): Atomic = { + val copy = + if (_oldShape == null) CopiedModule(m.shape.deepCopy(), inheritedAttributes, realModule(m)) + else m + val oldShape = if (_oldShape == null) m.shape else _oldShape + if (Debug) println(indent + s"adding copy ${hash(copy)} ${printShape(copy.shape)} of ${printShape(oldShape)}") group.add(copy) modules.add(copy) copy.shape.outlets.foreach(o ⇒ outGroup.put(o, group)) - val orig1 = m.shape.inlets.iterator + val orig1 = oldShape.inlets.iterator val mapd1 = copy.shape.inlets.iterator while (orig1.hasNext) { val orig = orig1.next() @@ -481,7 +613,7 @@ private[stream] object Fusing { addMapping(orig, mapd, newIns) inOwners.put(mapd, copy) } - val orig2 = m.shape.outlets.iterator + val orig2 = oldShape.outlets.iterator val mapd2 = copy.shape.outlets.iterator while (orig2.hasNext) { val orig = orig2.next() @@ -489,8 +621,24 @@ private[stream] object Fusing { addMapping(orig, mapd, newOuts) outOwners.put(mapd, copy) } - copy.copyOf match { - case GraphStageModule(_, _, _: MaterializedValueSource[_]) ⇒ pushMatSrc(copy) + /* + * In descend() we add internalOuts entries for all shapes that belong to stages that + * are part of a GraphModule that is not dissolved. This includes the exposed Outlets, + * which of course are external and thus need to be removed again from the internalOuts + * set. + */ + if (m.isInstanceOf[GraphModule]) internalOuts.removeAll(m.shape.outlets.asJava) + copy match { + case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c) + case GraphModule(_, _, _, mvids) ⇒ + var i = 0 + while (i < mvids.length) { + mvids(i) match { + case c @ CopiedModule(_, _, GraphStageModule(_, _, _: MaterializedValueSource[_])) ⇒ pushMatSrc(c) + case _ ⇒ + } + i += 1 + } case _ ⇒ } Atomic(copy) @@ -502,8 +650,8 @@ private[stream] object Fusing { */ def wire(out: OutPort, in: InPort, indent: String): Unit = { if (Debug) println(indent + s"wiring $out (${hash(out)}) -> $in (${hash(in)})") - val newOut = removeMapping(out, newOuts) nonNull out.toString - val newIn = removeMapping(in, newIns) nonNull in.toString + val newOut = removeMapping(out, newOuts) nonNull s"$out (${hash(out)})" + val newIn = removeMapping(in, newIns) nonNull s"$in (${hash(in)})" downstreams.put(newOut, newIn) upstreams.put(newIn, newOut) } @@ -514,10 +662,10 @@ private[stream] object Fusing { def rewire(oldShape: Shape, newShape: Shape, indent: String): Unit = { if (Debug) println(indent + s"rewiring ${printShape(oldShape)} -> ${printShape(newShape)}") oldShape.inlets.iterator.zip(newShape.inlets.iterator).foreach { - case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull oldIn.toString, newIns) + case (oldIn, newIn) ⇒ addMapping(newIn, removeMapping(oldIn, newIns) nonNull s"$oldIn (${hash(oldIn)})", newIns) } oldShape.outlets.iterator.zip(newShape.outlets.iterator).foreach { - case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull oldOut.toString, newOuts) + case (oldOut, newOut) ⇒ addMapping(newOut, removeMapping(oldOut, newOuts) nonNull s"$oldOut (${hash(oldOut)})", newOuts) } } @@ -534,6 +682,9 @@ private[stream] object Fusing { old.map(o ⇒ newOuts.get(o).head.outlet) } + /** + * Determine whether the given CopiedModule has an AsyncBoundary attribute. + */ private def isAsync(m: Module): Boolean = m match { case CopiedModule(_, inherited, orig) ⇒ val attr = inherited and orig.attributes @@ -550,6 +701,9 @@ private[stream] object Fusing { case x ⇒ x.attributes.get[ActorAttributes.Dispatcher] } + /** + * See through copied modules to the “real” module. + */ private def realModule(m: Module): Module = m match { case CopiedModule(_, _, of) ⇒ realModule(of) case other ⇒ other