diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/GraphDslTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/GraphDslTest.java index 6f1cfc99a7..a28a715e0d 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/GraphDslTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/GraphDslTest.java @@ -4,6 +4,7 @@ package akka.stream.javadsl; +import akka.Done; import akka.NotUsed; import akka.japi.Pair; import akka.stream.*; @@ -216,4 +217,14 @@ public class GraphDslTest extends StreamTest { assertEquals("bx", result.get(1).toCompletableFuture().get(1, TimeUnit.SECONDS)); assertEquals("cx", result.get(2).toCompletableFuture().get(1, TimeUnit.SECONDS)); } + + @Test + public void canUseMapMaterializedValueOnGraphs() { + Graph, NotUsed> srcGraph = Source.empty(); + Graph, Pair> mappedMatValueSrcGraph = + Graph.mapMaterializedValue(srcGraph, notUsed -> new Pair(notUsed, notUsed)); + Sink> snk = Sink.ignore(); + Pair pair = Source.fromGraph(mappedMatValueSrcGraph).to(snk).run(system); + assertEquals(pair, new Pair(NotUsed.getInstance(), NotUsed.getInstance())); + } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphDSLCompileSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphDSLCompileSpec.scala index 7d79a9edb7..e4a5de307a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphDSLCompileSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphDSLCompileSpec.scala @@ -4,6 +4,7 @@ package akka.stream.scaladsl +import akka.NotUsed import akka.stream.impl.fusing.GraphStages import akka.stream._ import akka.stream.testkit._ @@ -411,5 +412,15 @@ class GraphDSLCompileSpec extends StreamSpec { ga.traversalBuilder.attributes.getFirst[Name] shouldEqual Some(Name("useless")) ga.traversalBuilder.attributes.getFirst[AsyncBoundary.type] shouldEqual (Some(AsyncBoundary)) } + + "support mapMaterializedValue" in { + val anOp = op[String, String] + val anOpWithMappedMatVal = anOp.mapMaterializedValue { + case NotUsed => (NotUsed, NotUsed) + } + val g = Source.empty[String].viaMat(anOpWithMappedMatVal)(Keep.right).to(Sink.cancelled) + val matVal = g.run() + matVal shouldEqual ((NotUsed, NotUsed)) + } } } diff --git a/akka-stream/src/main/scala/akka/stream/Graph.scala b/akka-stream/src/main/scala/akka/stream/Graph.scala index 51e67abe1f..a7b1c40f81 100644 --- a/akka-stream/src/main/scala/akka/stream/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/Graph.scala @@ -6,6 +6,7 @@ package akka.stream import akka.annotation.InternalApi import akka.stream.impl.TraversalBuilder +import akka.stream.scaladsl.GenericGraph import scala.annotation.unchecked.uncheckedVariance @@ -69,6 +70,34 @@ trait Graph[+S <: Shape, +M] { def addAttributes(attr: Attributes): Graph[S, M] = withAttributes(traversalBuilder.attributes and attr) } +object Graph { + + /** + * Java API + * Transform the materialized value of this Flow, leaving all other properties as they were. + * + * @param g the graph being transformed + * @param f function to map the graph's materialized value + * @return a graph with same semantics as the given graph, except from the materialized value which is mapped using f. + */ + def mapMaterializedValue[S <: Shape, M1, M2](g: Graph[S, M1])(f: M1 => M2): Graph[S, M2] = + new GenericGraph(g.shape, g.traversalBuilder).mapMaterializedValue(f) + + /** + * Scala API, see https://github.com/akka/akka/issues/28501 for discussion why this can't be an instance method on class Graph. + * @param self the graph whose materialized value will be mapped + */ + final implicit class GraphMapMatVal[S <: Shape, M](self: Graph[S, M]) { + + /** + * Transform the materialized value of this Graph, leaving all other properties as they were. + * + * @param f function to map the graph's materialized value + */ + def mapMaterializedValue[M2](f: M => M2): Graph[S, M2] = Graph.mapMaterializedValue(self)(f) + } +} + /** * INTERNAL API * diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 70b4637afa..76f35f00fa 100755 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -37,6 +37,9 @@ private[stream] final class GenericGraph[S <: Shape, Mat]( override def withAttributes(attr: Attributes): Graph[S, Mat] = new GenericGraphWithChangedAttributes(shape, traversalBuilder, attr) + + def mapMaterializedValue[Mat2](f: Mat => Mat2): GenericGraph[S, Mat2] = + new GenericGraph(shape, traversalBuilder.transformMat(f)) } /**