diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala index 017372e67c..9700abe95a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala @@ -5,7 +5,7 @@ package akka.stream.impl import akka.stream.scaladsl._ import akka.stream.testkit.AkkaSpec -import org.reactivestreams.{ Subscription, Subscriber, Publisher } +import org.reactivestreams.{ Processor, Subscription, Subscriber, Publisher } import akka.stream._ class StreamLayoutSpec extends AkkaSpec { @@ -188,6 +188,7 @@ class StreamLayoutSpec extends AkkaSpec { assignPort(outPort, publisher) } } + override protected def createIdentityProcessor: Processor[Any, Any] = null // Not used in test } def checkMaterialized(topLevel: Module): (Set[TestPublisher], Set[TestSubscriber]) = { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala index f326bddce9..1c3c49dd7b 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala @@ -37,6 +37,12 @@ class PublisherSinkSpec extends AkkaSpec { Await.result(f2, 3.seconds) should be(15) } + + "work with SubscriberSource" in { + val (sub, pub) = Source.subscriber[Int].toMat(Sink.publisher)(Keep.both).run() + Source(1 to 100).to(Sink(sub)).run() + Await.result(Source(pub).grouped(1000).runWith(Sink.head), 3.seconds) should ===(1 to 100) + } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index daf87ae023..a82bef711e 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -111,6 +111,9 @@ private[akka] case class ActorFlowMaterializerImpl( } } + override protected def createIdentityProcessor: Processor[Any, Any] = + processorFor(Identity(OperationAttributes.none), OperationAttributes.none, settings)._1 + private def processorFor(op: StageModule, effectiveAttributes: OperationAttributes, effectiveSettings: ActorFlowMaterializerSettings): (Processor[Any, Any], Any) = op match { diff --git a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala index cd842ef811..3414d375fe 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala @@ -55,17 +55,7 @@ private[akka] abstract class SourceModule[+Out, +Mat](val shape: SourceShape[Out private[akka] final class SubscriberSource[Out](val attributes: OperationAttributes, shape: SourceShape[Out]) extends SourceModule[Out, Subscriber[Out]](shape) { override def create(context: MaterializationContext): (Publisher[Out], Subscriber[Out]) = { - val processor = new Processor[Out, Out] { - @volatile private var subscriber: Subscriber[_ >: Out] = null - - override def subscribe(s: Subscriber[_ >: Out]): Unit = subscriber = s - - override def onError(t: Throwable): Unit = subscriber.onError(t) - override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) - override def onComplete(): Unit = subscriber.onComplete() - override def onNext(t: Out): Unit = subscriber.onNext(t) - } - + val processor = new SubscriberSourceVirtualProcessor[Out] (processor, processor) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index 86ca3b2a17..c1ce5cfe27 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -60,8 +60,8 @@ private[akka] class PublisherSink[In](val attributes: OperationAttributes, shape override def toString: String = "PublisherSink" override def create(context: MaterializationContext): (Subscriber[In], Publisher[In]) = { - val pub = new VirtualPublisher[In] - val sub = new VirtualSubscriber[In](pub) + val pub = new PublisherSinkVirtualPublisher[In] + val sub = new PublisherSinkVirtualSubscriber[In](pub) (sub, pub) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index 0e3a704e16..a8967360d9 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -8,7 +8,7 @@ import java.util.concurrent.atomic.{ AtomicInteger, AtomicBoolean, AtomicReferen import akka.stream.impl.StreamLayout.Module import akka.stream.scaladsl.Keep import akka.stream._ -import org.reactivestreams.{ Subscription, Publisher, Subscriber } +import org.reactivestreams.{ Processor, Subscription, Publisher, Subscriber } import scala.collection.mutable import scala.util.control.NonFatal import akka.event.Logging.simpleName @@ -293,10 +293,21 @@ private[akka] object StreamLayout { } } +private[stream] final class SubscriberSourceVirtualProcessor[T] extends Processor[T, T] { + @volatile private var subscriber: Subscriber[_ >: T] = null + + override def subscribe(s: Subscriber[_ >: T]): Unit = subscriber = s + + override def onError(t: Throwable): Unit = subscriber.onError(t) + override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) + override def onComplete(): Unit = subscriber.onComplete() + override def onNext(t: T): Unit = subscriber.onNext(t) +} + /** * INTERNAL API */ -private[stream] class VirtualSubscriber[T](val owner: VirtualPublisher[T]) extends Subscriber[T] { +private[stream] final class PublisherSinkVirtualSubscriber[T](val owner: PublisherSinkVirtualPublisher[T]) extends Subscriber[T] { override def onSubscribe(s: Subscription): Unit = throw new UnsupportedOperationException("This method should not be called") override def onError(t: Throwable): Unit = throw new UnsupportedOperationException("This method should not be called") override def onComplete(): Unit = throw new UnsupportedOperationException("This method should not be called") @@ -306,7 +317,7 @@ private[stream] class VirtualSubscriber[T](val owner: VirtualPublisher[T]) exten /** * INTERNAL API */ -private[stream] class VirtualPublisher[T]() extends Publisher[T] { +private[stream] final class PublisherSinkVirtualPublisher[T]() extends Publisher[T] { @volatile var realPublisher: Publisher[T] = null override def subscribe(s: Subscriber[_ >: T]): Unit = { val sub = realPublisher.subscribe(s) @@ -320,7 +331,7 @@ private[stream] class VirtualPublisher[T]() extends Publisher[T] { /** * INTERNAL API */ -private[stream] case class MaterializedValueSource[M]( +private[stream] final case class MaterializedValueSource[M]( shape: SourceShape[M] = SourceShape[M](new Outlet[M]("Materialized.out")), attributes: OperationAttributes = OperationAttributes.name("Materialized")) extends StreamLayout.Module { @@ -562,6 +573,8 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo protected def materializeAtomic(atomic: Module, effectiveAttributes: OperationAttributes): Any + protected def createIdentityProcessor: Processor[Any, Any] + private def resolveMaterialized(matNode: MaterializedValueNode, materializedValues: collection.Map[Module, Any]): Any = matNode match { case Atomic(m) ⇒ materializedValues(m) case Combine(f, d1, d2) ⇒ f(resolveMaterialized(d1, materializedValues), resolveMaterialized(d2, materializedValues)) @@ -570,8 +583,15 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo } private def attach(p: Publisher[Any], s: Subscriber[Any]) = s match { - case v: VirtualSubscriber[Any] ⇒ v.owner.realPublisher = p - case _ ⇒ p.subscribe(s) + case v: PublisherSinkVirtualSubscriber[Any] ⇒ + if (p.isInstanceOf[SubscriberSourceVirtualProcessor[Any]]) { + val injectedProcessor = createIdentityProcessor + v.owner.realPublisher = injectedProcessor + p.subscribe(injectedProcessor) + } else + v.owner.realPublisher = p + case _ ⇒ + p.subscribe(s) } final protected def assignPort(in: InPort, subscriber: Subscriber[Any]): Unit = { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala index e6a556463b..199e8eca8b 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala @@ -65,7 +65,7 @@ object Sink { */ def foreach[T](f: function.Procedure[T]): Sink[T, Future[Unit]] = new Sink(scaladsl.Sink.foreach(f.apply)) - + /** * A `Sink` that will invoke the given procedure for each received element in parallel. The sink is materialized * into a [[scala.concurrent.Future]].