diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/BidiFlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/BidiFlowSpec.scala index 75449a3cd9..1101dd5227 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/BidiFlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/BidiFlowSpec.scala @@ -7,9 +7,7 @@ package akka.stream.scaladsl import scala.collection.immutable import scala.concurrent.Await import scala.concurrent.duration._ - import scala.annotation.nowarn - import akka.NotUsed import akka.stream._ import akka.stream.testkit.StreamSpec @@ -121,6 +119,47 @@ class BidiFlowSpec extends StreamSpec { b.traversalBuilder.attributes.getFirst[AsyncBoundary.type] shouldEqual Some(AsyncBoundary) } + "short circuit identity in atop" in { + val myBidi = BidiFlow.fromFlows(Flow[Long].map(_ + 1L), Flow[ByteString]) + val identity = BidiFlow.identity[Long, ByteString] + + // simple ones + myBidi.atop(identity) should ===(myBidi) + identity.atopMat(myBidi)(Keep.right) should ===(myBidi) + + // optimized but not the same instance (because myBidi mat value is dropped) + identity.atop(myBidi) should !==(myBidi) + myBidi.atopMat(identity)(Keep.right) should !==(myBidi) + } + + "semi-shortcuted atop with identity should still work" in { + // atop when the NotUsed matval is kept from identity has a smaller optimization, so verify they still work + val myBidi = + BidiFlow.fromFlows(Flow[Long].map(_ + 1L), Flow[Long].map(_ + 1L)).mapMaterializedValue(_ => "bidi-matval") + val identity = BidiFlow.identity[Long, Long] + + def verify[M](atopBidi: BidiFlow[Long, Long, Long, Long, M], expectedMatVal: M): Unit = { + val joinedFlow = atopBidi.joinMat(Flow[Long])(Keep.left) + val (bidiMatVal, seqSinkMatValF) = + Source(1L :: 2L :: Nil).viaMat(joinedFlow)(Keep.right).toMat(Sink.seq)(Keep.both).run() + seqSinkMatValF.futureValue should ===(Seq(3L, 4L)) + bidiMatVal should ===(expectedMatVal) + } + + // identity atop myBidi + verify(identity.atopMat(myBidi)(Keep.left), NotUsed) + verify(identity.atopMat(myBidi)(Keep.none), NotUsed) + verify(identity.atopMat(myBidi)(Keep.right), "bidi-matval") + // arbitrary matval combine + verify(identity.atopMat(myBidi)((_, m) => m), "bidi-matval") + + // myBidi atop identity + verify(myBidi.atopMat(identity)(Keep.left), "bidi-matval") + verify(myBidi.atopMat(identity)(Keep.none), NotUsed) + verify(myBidi.atopMat(identity)(Keep.right), NotUsed) + verify(myBidi.atopMat(identity)((m, _) => m), "bidi-matval") + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala index 091296f783..483cea1cc1 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala @@ -5,7 +5,6 @@ package akka.stream.scaladsl import scala.concurrent.duration.FiniteDuration - import akka.NotUsed import akka.stream.{ BidiShape, _ } import akka.stream.impl.{ LinearTraversalBuilder, Timers, TraversalBuilder } @@ -59,24 +58,48 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat]( * flow into the materialized value of the resulting BidiFlow. */ def atopMat[OO1, II2, Mat2, M](bidi: Graph[BidiShape[O1, OO1, II2, I2], Mat2])( - combine: (Mat, Mat2) => M): BidiFlow[I1, OO1, II2, O2, M] = { - val newBidi1Shape = shape.deepCopy() - val newBidi2Shape = bidi.shape.deepCopy() + combine: (Mat, Mat2) => M): BidiFlow[I1, OO1, II2, O2, M] = + if (this eq BidiFlow.identity) { + // optimizations possible since we know that identity matval is NotUsed + if (combine eq Keep.right) + BidiFlow.fromGraph(bidi).asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + else if ((combine eq Keep.left) || (combine eq Keep.none)) + BidiFlow.fromGraph(bidi).mapMaterializedValue(_ => NotUsed).asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + else { + BidiFlow + .fromGraph(bidi) + .mapMaterializedValue(mat2 => combine(NotUsed.asInstanceOf[Mat], mat2)) + .asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + } + } else if (bidi eq BidiFlow.identity) { + // optimizations possible since we know that identity matval is NotUsed + if (combine eq Keep.left) + this.asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + else if ((combine eq Keep.right) || (combine eq Keep.none)) + this.mapMaterializedValue(_ => NotUsed).asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + else { + this + .mapMaterializedValue(mat => combine(mat, NotUsed.asInstanceOf[Mat2])) + .asInstanceOf[BidiFlow[I1, OO1, II2, O2, M]] + } + } else { + val newBidi1Shape = shape.deepCopy() + val newBidi2Shape = bidi.shape.deepCopy() - // We MUST add the current module as an explicit submodule. The composite builder otherwise *grows* the - // existing module, which is not good if there are islands present (the new module will "join" the island). - val newTraversalBuilder = - TraversalBuilder - .empty() - .add(traversalBuilder, newBidi1Shape, Keep.right) - .add(bidi.traversalBuilder, newBidi2Shape, combine) - .wire(newBidi1Shape.out1, newBidi2Shape.in1) - .wire(newBidi2Shape.out2, newBidi1Shape.in2) + // We MUST add the current module as an explicit submodule. The composite builder otherwise *grows* the + // existing module, which is not good if there are islands present (the new module will "join" the island). + val newTraversalBuilder = + TraversalBuilder + .empty() + .add(traversalBuilder, newBidi1Shape, Keep.right) + .add(bidi.traversalBuilder, newBidi2Shape, combine) + .wire(newBidi1Shape.out1, newBidi2Shape.in1) + .wire(newBidi2Shape.out2, newBidi1Shape.in2) - new BidiFlow( - newTraversalBuilder, - BidiShape(newBidi1Shape.in1, newBidi2Shape.out1, newBidi2Shape.in2, newBidi1Shape.out2)) - } + new BidiFlow( + newTraversalBuilder, + BidiShape(newBidi1Shape.in1, newBidi2Shape.out1, newBidi2Shape.in2, newBidi1Shape.out2)) + } /** * Add the given Flow as the final step in a bidirectional transformation