diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala index b8c1b62b57..f6783982a1 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/AkkaIdentityProcessorVerification.scala @@ -30,7 +30,10 @@ abstract class AkkaIdentityProcessorVerification[T](env: TestEnvironment, publis def processorFromFlow(flow: Flow[T, T, _])(implicit mat: ActorFlowMaterializer): Processor[T, T] = { val (sub: Subscriber[T], pub: Publisher[T]) = flow.runWith(Source.subscriber[T], Sink.publisher[T]) + processorFromSubscriberAndPublisher(sub, pub) + } + def processorFromSubscriberAndPublisher(sub: Subscriber[T], pub: Publisher[T]): Processor[T, T] = { new Processor[T, T] { override def onSubscribe(s: Subscription): Unit = sub.onSubscribe(s) override def onError(t: Throwable): Unit = sub.onError(t) diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala new file mode 100644 index 0000000000..2bf5dad809 --- /dev/null +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/VirtualPublisherTest.scala @@ -0,0 +1,34 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.tck + +import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl.Flow +import org.reactivestreams.Processor +import akka.stream.impl.VirtualProcessor + +class VirtualProcessorTest extends AkkaIdentityProcessorVerification[Int] { + + override def createIdentityProcessor(maxBufferSize: Int): Processor[Int, Int] = { + implicit val materializer = ActorFlowMaterializer()(system) + + val identity = processorFromFlow(Flow[Int].map(elem ⇒ elem).named("identity")) + val left, right = new VirtualProcessor[Int] + left.subscribe(identity) + identity.subscribe(right) + processorFromSubscriberAndPublisher(left, right) + } + + override def createElement(element: Int): Int = element + +} + +class VirtualProcessorSingleTest extends AkkaIdentityProcessorVerification[Int] { + + override def createIdentityProcessor(maxBufferSize: Int): Processor[Int, Int] = + new VirtualProcessor[Int] + + override def createElement(element: Int): Int = element + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala index 9fd0f0fe09..e645e1f7d3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala @@ -186,7 +186,6 @@ class StreamLayoutSpec extends AkkaSpec { assignPort(outPort, publisher) } } - override protected def createIdentityProcessor: Processor[Any, Any] = null // Not used in test } def checkMaterialized(topLevel: Module): (Set[TestPublisher], Set[TestSubscriber]) = { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala index 1c3c49dd7b..83eb079ad7 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/PublisherSinkSpec.scala @@ -43,6 +43,13 @@ class PublisherSinkSpec extends AkkaSpec { Source(1 to 100).to(Sink(sub)).run() Await.result(Source(pub).grouped(1000).runWith(Sink.head), 3.seconds) should ===(1 to 100) } + + "be able to use Publisher in materialized value transformation" in { + val f = Source(1 to 3).runWith( + Sink.publisher[Int].mapMaterializedValue(p ⇒ Source(p).runFold(0)(_ + _))) + + Await.result(f, 3.seconds) should be(6) + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala new file mode 100644 index 0000000000..33d3dde68c --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala @@ -0,0 +1,29 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream.ActorFlowMaterializer + +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.Utils._ +import scala.concurrent.duration._ + +import scala.concurrent.Await + +class SubscriberSourceSpec extends AkkaSpec("akka.loglevel=DEBUG\nakka.actor.debug.lifecycle=on") { + + implicit val materializer = ActorFlowMaterializer() + + "A SubscriberSource" must { + + "be able to use Subscriber in materialized value transformation" in { + val f = + Source.subscriber[Int].mapMaterializedValue(s ⇒ Source(1 to 3).runWith(Sink(s))) + .runWith(Sink.fold[Int, Int](0)(_ + _)) + + Await.result(f, 3.seconds) should be(6) + } + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index a82bef711e..76650e838d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -111,13 +111,11 @@ private[akka] case class ActorFlowMaterializerImpl( } } - override protected def createIdentityProcessor: Processor[Any, Any] = - processorFor(Identity(OperationAttributes.none), OperationAttributes.none, settings)._1 - private def processorFor(op: StageModule, effectiveAttributes: OperationAttributes, effectiveSettings: ActorFlowMaterializerSettings): (Processor[Any, Any], Any) = op match { case DirectProcessor(processorFactory, _) ⇒ processorFactory() + case Identity(attr) ⇒ (new VirtualProcessor, ()) case _ ⇒ val (opprops, mat) = ActorProcessorFactory.props(ActorFlowMaterializerImpl.this, op, effectiveAttributes) val processor = ActorProcessorFactory[Any, Any](actorOf( @@ -294,7 +292,7 @@ private[akka] object ActorProcessorFactory { // Also, otherwise the attributes will not affect the settings properly! val settings = materializer.effectiveSettings(att) op match { - case Identity(_) ⇒ (ActorInterpreter.props(settings, List(fusing.Map(_identity, settings.supervisionDecider)), materializer, att), ()) + case Identity(_) ⇒ throw new AssertionError("Identity cannot end up in ActorProcessorFactory") case Fused(ops, _) ⇒ (ActorInterpreter.props(settings, ops, materializer, att), ()) case Map(f, _) ⇒ (ActorInterpreter.props(settings, List(fusing.Map(f, settings.supervisionDecider)), materializer, att), ()) case Filter(p, _) ⇒ (ActorInterpreter.props(settings, List(fusing.Filter(p, settings.supervisionDecider)), materializer, att), ()) diff --git a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala index 3414d375fe..48c1f4c22c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Modules.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Modules.scala @@ -55,7 +55,7 @@ private[akka] abstract class SourceModule[+Out, +Mat](val shape: SourceShape[Out private[akka] final class SubscriberSource[Out](val attributes: OperationAttributes, shape: SourceShape[Out]) extends SourceModule[Out, Subscriber[Out]](shape) { override def create(context: MaterializationContext): (Publisher[Out], Subscriber[Out]) = { - val processor = new SubscriberSourceVirtualProcessor[Out] + val processor = new VirtualProcessor[Out] (processor, processor) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index c1ce5cfe27..51d7a5def9 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -60,9 +60,8 @@ private[akka] class PublisherSink[In](val attributes: OperationAttributes, shape override def toString: String = "PublisherSink" override def create(context: MaterializationContext): (Subscriber[In], Publisher[In]) = { - val pub = new PublisherSinkVirtualPublisher[In] - val sub = new PublisherSinkVirtualSubscriber[In](pub) - (sub, pub) + val proc = new VirtualProcessor[In] + (proc, proc) } override protected def newInstance(shape: SinkShape[In]): SinkModule[In, Publisher[In]] = new PublisherSink[In](attributes, shape) diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index a8967360d9..473ef4bce4 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -4,7 +4,6 @@ package akka.stream.impl import java.util.concurrent.atomic.{ AtomicInteger, AtomicBoolean, AtomicReference } - import akka.stream.impl.StreamLayout.Module import akka.stream.scaladsl.Keep import akka.stream._ @@ -12,6 +11,8 @@ import org.reactivestreams.{ Processor, Subscription, Publisher, Subscriber } import scala.collection.mutable import scala.util.control.NonFatal import akka.event.Logging.simpleName +import scala.annotation.tailrec +import java.util.concurrent.atomic.AtomicLong /** * INTERNAL API @@ -106,9 +107,8 @@ private[akka] object StreamLayout { AmorphousShape(shape.inlets ++ that.shape.inlets, shape.outlets ++ that.shape.outlets), downstreams ++ that.downstreams, upstreams ++ that.upstreams, - if (f eq Keep.left) matComputation1 - else if (f eq Keep.right) matComputation2 - else Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), + // would like to optimize away this allocation for Keep.{left,right} but that breaks side-effecting transformations + Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), attributes) } @@ -293,38 +293,108 @@ private[akka] object StreamLayout { } } -private[stream] final class SubscriberSourceVirtualProcessor[T] extends Processor[T, T] { - @volatile private var subscriber: Subscriber[_ >: T] = null +private[stream] object VirtualProcessor { + sealed trait Termination + case object Allowed extends Termination + case object Completed extends Termination + case class Failed(ex: Throwable) extends Termination - override def subscribe(s: Subscriber[_ >: T]): Unit = subscriber = s - - override def onError(t: Throwable): Unit = subscriber.onError(t) - override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) - override def onComplete(): Unit = subscriber.onComplete() - override def onNext(t: T): Unit = subscriber.onNext(t) + private object InertSubscriber extends Subscriber[Any] { + override def onSubscribe(s: Subscription): Unit = s.cancel() + override def onNext(elem: Any): Unit = () + override def onError(thr: Throwable): Unit = () + override def onComplete(): Unit = () + } } -/** - * INTERNAL API - */ -private[stream] final class PublisherSinkVirtualSubscriber[T](val owner: PublisherSinkVirtualPublisher[T]) extends Subscriber[T] { - override def onSubscribe(s: Subscription): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onError(t: Throwable): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onComplete(): Unit = throw new UnsupportedOperationException("This method should not be called") - override def onNext(t: T): Unit = throw new UnsupportedOperationException("This method should not be called") -} +private[stream] final class VirtualProcessor[T] extends Processor[T, T] { + import VirtualProcessor._ + import ReactiveStreamsCompliance._ + + private val subscriptionStatus = new AtomicReference[AnyRef] + private val terminationStatus = new AtomicReference[Termination] -/** - * INTERNAL API - */ -private[stream] final class PublisherSinkVirtualPublisher[T]() extends Publisher[T] { - @volatile var realPublisher: Publisher[T] = null override def subscribe(s: Subscriber[_ >: T]): Unit = { - val sub = realPublisher.subscribe(s) - // unreference the realPublisher to facilitate GC and - // Sink.publisher is supposed to reject additional subscribers anyway - realPublisher = RejectAdditionalSubscribers[T] - sub + requireNonNullSubscriber(s) + if (subscriptionStatus.compareAndSet(null, s)) () // wait for onSubscribe + else + subscriptionStatus.get match { + case sub: Subscriber[_] ⇒ rejectAdditionalSubscriber(s, "VirtualProcessor") + case sub: Sub ⇒ + try { + subscriptionStatus.set(s) + tryOnSubscribe(s, sub) + sub.closeLatch() // allow onNext only now + terminationStatus.getAndSet(Allowed) match { + case null ⇒ // nothing happened yet + case Completed ⇒ tryOnComplete(s) + case Failed(ex) ⇒ tryOnError(s, ex) + case Allowed ⇒ // all good + } + } catch { + case NonFatal(ex) ⇒ sub.cancel() + } + } + } + + override def onSubscribe(s: Subscription): Unit = { + requireNonNullSubscription(s) + val wrapped = new Sub(s) + if (subscriptionStatus.compareAndSet(null, wrapped)) () // wait for Subscriber + else + subscriptionStatus.get match { + case sub: Subscriber[_] ⇒ + terminationStatus.get match { + case Allowed ⇒ + /* + * There is a race condition here: if this thread reads the subscriptionStatus after + * set set() in subscribe() but then sees the terminationStatus before the getAndSet() + * is published then we will rely upon the downstream Subscriber for cancelling this + * Subscription. I only mention this because the TCK requires that we handle this here + * (since the manualSubscriber used there does not expose this behavior). + */ + s.cancel() + case _ ⇒ + tryOnSubscribe(sub, wrapped) + wrapped.closeLatch() // allow onNext only now + terminationStatus.set(Allowed) + } + case sub: Subscription ⇒ + s.cancel() // reject further Subscriptions + } + } + + override def onError(t: Throwable): Unit = { + requireNonNullException(t) + if (terminationStatus.compareAndSet(null, Failed(t))) () // let it be picked up by subscribe() + else tryOnError(subscriptionStatus.get.asInstanceOf[Subscriber[T]], t) + } + + override def onComplete(): Unit = + if (terminationStatus.compareAndSet(null, Completed)) () // let it be picked up by subscribe() + else tryOnComplete(subscriptionStatus.get.asInstanceOf[Subscriber[T]]) + + override def onNext(t: T): Unit = { + requireNonNullElement(t) + tryOnNext(subscriptionStatus.get.asInstanceOf[Subscriber[T]], t) + } + + private final class Sub(s: Subscription) extends AtomicLong with Subscription { + override def cancel(): Unit = { + subscriptionStatus.set(InertSubscriber) + s.cancel() + } + @tailrec + override def request(n: Long): Unit = { + val current = get + if (current < 0) s.request(n) + else if (compareAndSet(current, current + n)) () + else request(n) + } + def closeLatch(): Unit = { + val requested = getAndSet(-1) + if (requested > 0) s.request(requested) + } } } @@ -550,6 +620,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo case mv: MaterializedValueSource[_] ⇒ val pub = new MaterializedValuePublisher materializedValuePublishers ::= pub + materializedValues.put(mv, ()) assignPort(mv.shape.outlet, pub) case atomic if atomic.isAtomic ⇒ materializedValues.put(atomic, materializeAtomic(atomic, subEffectiveAttributes)) @@ -573,8 +644,6 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo protected def materializeAtomic(atomic: Module, effectiveAttributes: OperationAttributes): Any - protected def createIdentityProcessor: Processor[Any, Any] - private def resolveMaterialized(matNode: MaterializedValueNode, materializedValues: collection.Map[Module, Any]): Any = matNode match { case Atomic(m) ⇒ materializedValues(m) case Combine(f, d1, d2) ⇒ f(resolveMaterialized(d1, materializedValues), resolveMaterialized(d2, materializedValues)) @@ -582,24 +651,12 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo case Ignore ⇒ () } - private def attach(p: Publisher[Any], s: Subscriber[Any]) = s match { - case v: PublisherSinkVirtualSubscriber[Any] ⇒ - if (p.isInstanceOf[SubscriberSourceVirtualProcessor[Any]]) { - val injectedProcessor = createIdentityProcessor - v.owner.realPublisher = injectedProcessor - p.subscribe(injectedProcessor) - } else - v.owner.realPublisher = p - case _ ⇒ - p.subscribe(s) - } - final protected def assignPort(in: InPort, subscriber: Subscriber[Any]): Unit = { subscribers(in) = subscriber // Interface (unconnected) ports of the current scope will be wired when exiting the scope if (!currentLayout.inPorts(in)) { val publisher = publishers(currentLayout.upstreams(in)) - if (publisher ne null) attach(publisher, subscriber) + if (publisher ne null) publisher.subscribe(subscriber) } } @@ -608,7 +665,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo // Interface (unconnected) ports of the current scope will be wired when exiting the scope if (!currentLayout.outPorts(out)) { val subscriber = subscribers(currentLayout.downstreams(out)) - if (subscriber ne null) attach(publisher, subscriber) + if (subscriber ne null) publisher.subscribe(subscriber) } }