diff --git a/akka-actor-tests/src/test/scala/akka/actor/FunctionRefSpec.scala b/akka-actor-tests/src/test/scala/akka/actor/FunctionRefSpec.scala new file mode 100644 index 0000000000..f3c608d8d3 --- /dev/null +++ b/akka-actor-tests/src/test/scala/akka/actor/FunctionRefSpec.scala @@ -0,0 +1,100 @@ +/** + * Copyright (C) 2016 Typesafe Inc. + */ +package akka.actor + +import akka.testkit.AkkaSpec +import akka.testkit.ImplicitSender +import scala.concurrent.duration._ +import akka.testkit.EventFilter + +object FunctionRefSpec { + + case class GetForwarder(replyTo: ActorRef) + case class DropForwarder(ref: FunctionRef) + case class Forwarded(msg: Any, sender: ActorRef) + + class Super extends Actor { + def receive = { + case GetForwarder(replyTo) ⇒ + val cell = context.asInstanceOf[ActorCell] + val ref = cell.addFunctionRef((sender, msg) ⇒ replyTo ! Forwarded(msg, sender)) + replyTo ! ref + case DropForwarder(ref) ⇒ + val cell = context.asInstanceOf[ActorCell] + cell.removeFunctionRef(ref) + } + } + + class SupSuper extends Actor { + val s = context.actorOf(Props[Super], "super") + def receive = { + case msg ⇒ s ! msg + } + } + +} + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class FunctionRefSpec extends AkkaSpec with ImplicitSender { + import FunctionRefSpec._ + + def commonTests(s: ActorRef) = { + s ! GetForwarder(testActor) + val forwarder = expectMsgType[FunctionRef] + + "forward messages" in { + forwarder ! "hello" + expectMsg(Forwarded("hello", testActor)) + } + + "be watchable" in { + s ! GetForwarder(testActor) + val f = expectMsgType[FunctionRef] + watch(f) + s ! DropForwarder(f) + expectTerminated(f) + } + + "be able to watch" in { + s ! GetForwarder(testActor) + val f = expectMsgType[FunctionRef] + forwarder.watch(f) + s ! DropForwarder(f) + expectMsg(Forwarded(Terminated(f)(true, false), null)) + } + + "terminate when their parent terminates" in { + watch(forwarder) + s ! PoisonPill + expectTerminated(forwarder) + } + } + + "A FunctionRef" when { + + "created by a toplevel actor" must { + val s = system.actorOf(Props[Super], "super") + commonTests(s) + } + + "created by a non-toplevel actor" must { + val s = system.actorOf(Props[SupSuper], "supsuper") + commonTests(s) + } + + "not registered" must { + "not be found" in { + val provider = system.asInstanceOf[ExtendedActorSystem].provider + val ref = new FunctionRef(testActor.path / "blabla", provider, system.eventStream, (x, y) ⇒ ()) + EventFilter[ClassCastException](occurrences = 1) intercept { + // needs to be something that fails when the deserialized form is not a FunctionRef + // this relies upon serialize-messages during tests + testActor ! DropForwarder(ref) + expectNoMsg(1.second) + } + } + } + + } +} diff --git a/akka-actor/src/main/java/akka/actor/dungeon/AbstractActorCell.java b/akka-actor/src/main/java/akka/actor/dungeon/AbstractActorCell.java index c852ddde9c..8b9a77b35a 100644 --- a/akka-actor/src/main/java/akka/actor/dungeon/AbstractActorCell.java +++ b/akka-actor/src/main/java/akka/actor/dungeon/AbstractActorCell.java @@ -11,12 +11,14 @@ final class AbstractActorCell { final static long mailboxOffset; final static long childrenOffset; final static long nextNameOffset; + final static long functionRefsOffset; static { try { mailboxOffset = Unsafe.instance.objectFieldOffset(ActorCell.class.getDeclaredField("akka$actor$dungeon$Dispatch$$_mailboxDoNotCallMeDirectly")); childrenOffset = Unsafe.instance.objectFieldOffset(ActorCell.class.getDeclaredField("akka$actor$dungeon$Children$$_childrenRefsDoNotCallMeDirectly")); nextNameOffset = Unsafe.instance.objectFieldOffset(ActorCell.class.getDeclaredField("akka$actor$dungeon$Children$$_nextNameDoNotCallMeDirectly")); + functionRefsOffset = Unsafe.instance.objectFieldOffset(ActorCell.class.getDeclaredField("akka$actor$dungeon$Children$$_functionRefsDoNotCallMeDirectly")); } catch(Throwable t){ throw new ExceptionInInitializerError(t); } diff --git a/akka-actor/src/main/scala/akka/actor/ActorRef.scala b/akka-actor/src/main/scala/akka/actor/ActorRef.scala index ccd439cdb8..23c6ecc297 100644 --- a/akka-actor/src/main/scala/akka/actor/ActorRef.scala +++ b/akka-actor/src/main/scala/akka/actor/ActorRef.scala @@ -12,7 +12,9 @@ import akka.serialization.{ Serialization, JavaSerializer } import akka.event.EventStream import scala.annotation.tailrec import java.util.concurrent.ConcurrentHashMap -import akka.event.LoggingAdapter +import akka.event.{ Logging, LoggingAdapter } +import java.util.concurrent.atomic.AtomicReference +import scala.util.control.NonFatal object ActorRef { @@ -687,3 +689,139 @@ private[akka] class VirtualPathContainer( while (iter.hasNext) f(iter.next) } } + +/** + * INTERNAL API + * + * This kind of ActorRef passes all received messages to the given function for + * performing a non-blocking side-effect. The intended use is to transform the + * message before sending to the real target actor. Such references can be created + * by calling `ActorCell.addFunctionRef` and must be deregistered when no longer + * needed by calling `ActorCell.removeFunctionRef`. FunctionRefs do not count + * towards the live children of an actor, they do not receive the Terminate command + * and do not prevent the parent from terminating. FunctionRef is properly + * registered for remote lookup and ActorSelection. + * + * When using the watch() feature you must ensure that upon reception of the + * Terminated message the watched actorRef is unwatch()ed. + */ +private[akka] final class FunctionRef(override val path: ActorPath, + override val provider: ActorRefProvider, + val eventStream: EventStream, + f: (ActorRef, Any) ⇒ Unit) extends MinimalActorRef { + + override def !(message: Any)(implicit sender: ActorRef = Actor.noSender): Unit = { + f(sender, message) + } + + override def sendSystemMessage(message: SystemMessage): Unit = { + message match { + case w: Watch ⇒ addWatcher(w.watchee, w.watcher) + case u: Unwatch ⇒ remWatcher(u.watchee, u.watcher) + case DeathWatchNotification(actorRef, _, _) ⇒ + this.!(Terminated(actorRef)(existenceConfirmed = true, addressTerminated = false)) + case _ ⇒ //ignore all other messages + } + } + + private[this] var watching = ActorCell.emptyActorRefSet + private[this] val _watchedBy = new AtomicReference[Set[ActorRef]](ActorCell.emptyActorRefSet) + + override def isTerminated = _watchedBy.get() == null + + //noinspection EmptyCheck + protected def sendTerminated(): Unit = { + val watchedBy = _watchedBy.getAndSet(null) + if (watchedBy != null) { + if (watchedBy.nonEmpty) { + watchedBy foreach sendTerminated(ifLocal = false) + watchedBy foreach sendTerminated(ifLocal = true) + } + if (watching.nonEmpty) { + watching foreach unwatchWatched + watching = Set.empty + } + } + } + + private def sendTerminated(ifLocal: Boolean)(watcher: ActorRef): Unit = + if (watcher.asInstanceOf[ActorRefScope].isLocal == ifLocal) + watcher.asInstanceOf[InternalActorRef].sendSystemMessage(DeathWatchNotification(this, existenceConfirmed = true, addressTerminated = false)) + + private def unwatchWatched(watched: ActorRef): Unit = + watched.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(watched, this)) + + override def stop(): Unit = sendTerminated() + + @tailrec private def addWatcher(watchee: ActorRef, watcher: ActorRef): Unit = + _watchedBy.get() match { + case null ⇒ + sendTerminated(ifLocal = true)(watcher) + sendTerminated(ifLocal = false)(watcher) + + case watchedBy ⇒ + val watcheeSelf = watchee == this + val watcherSelf = watcher == this + + if (watcheeSelf && !watcherSelf) { + if (!watchedBy.contains(watcher)) + if (!_watchedBy.compareAndSet(watchedBy, watchedBy + watcher)) + addWatcher(watchee, watcher) // try again + } else if (!watcheeSelf && watcherSelf) { + publish(Logging.Warning(path.toString, classOf[FunctionRef], s"externally triggered watch from $watcher to $watchee is illegal on FunctionRef")) + } else { + publish(Logging.Error(path.toString, classOf[FunctionRef], s"BUG: illegal Watch($watchee,$watcher) for $this")) + } + } + + @tailrec private def remWatcher(watchee: ActorRef, watcher: ActorRef): Unit = { + _watchedBy.get() match { + case null ⇒ // do nothing... + case watchedBy ⇒ + val watcheeSelf = watchee == this + val watcherSelf = watcher == this + + if (watcheeSelf && !watcherSelf) { + if (watchedBy.contains(watcher)) + if (!_watchedBy.compareAndSet(watchedBy, watchedBy - watcher)) + remWatcher(watchee, watcher) // try again + } else if (!watcheeSelf && watcherSelf) { + publish(Logging.Warning(path.toString, classOf[FunctionRef], s"externally triggered unwatch from $watcher to $watchee is illegal on FunctionRef")) + } else { + publish(Logging.Error(path.toString, classOf[FunctionRef], s"BUG: illegal Unwatch($watchee,$watcher) for $this")) + } + } + } + + private def publish(e: Logging.LogEvent): Unit = try eventStream.publish(e) catch { case NonFatal(_) ⇒ } + + /** + * Have this FunctionRef watch the given Actor. This method must not be + * called concurrently from different threads, it should only be called by + * its parent Actor. + * + * Upon receiving the Terminated message, unwatch() must be called from a + * safe context (i.e. normally from the parent Actor). + */ + def watch(actorRef: ActorRef): Unit = { + watching += actorRef + actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Watch(actorRef.asInstanceOf[InternalActorRef], this)) + } + + /** + * Have this FunctionRef unwatch the given Actor. This method must not be + * called concurrently from different threads, it should only be called by + * its parent Actor. + */ + def unwatch(actorRef: ActorRef): Unit = { + watching -= actorRef + actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(actorRef.asInstanceOf[InternalActorRef], this)) + } + + /** + * Query whether this FunctionRef is currently watching the given Actor. This + * method must not be called concurrently from different threads, it should + * only be called by its parent Actor. + */ + def isWatching(actorRef: ActorRef): Boolean = watching.contains(actorRef) +} diff --git a/akka-actor/src/main/scala/akka/actor/RepointableActorRef.scala b/akka-actor/src/main/scala/akka/actor/RepointableActorRef.scala index 7b24c015e2..0e561a75c4 100644 --- a/akka-actor/src/main/scala/akka/actor/RepointableActorRef.scala +++ b/akka-actor/src/main/scala/akka/actor/RepointableActorRef.scala @@ -150,7 +150,10 @@ private[akka] class RepointableActorRef( lookup.getChildByName(childName) match { case Some(crs: ChildRestartStats) if uid == ActorCell.undefinedUid || uid == crs.uid ⇒ crs.child.asInstanceOf[InternalActorRef].getChild(name) - case _ ⇒ Nobody + case _ ⇒ lookup match { + case ac: ActorCell ⇒ ac.getFunctionRefOrNobody(childName, uid) + case _ ⇒ Nobody + } } } } else this diff --git a/akka-actor/src/main/scala/akka/actor/dungeon/Children.scala b/akka-actor/src/main/scala/akka/actor/dungeon/Children.scala index f14313de84..8698c88d78 100644 --- a/akka-actor/src/main/scala/akka/actor/dungeon/Children.scala +++ b/akka-actor/src/main/scala/akka/actor/dungeon/Children.scala @@ -12,6 +12,10 @@ import akka.serialization.SerializationExtension import akka.util.{ Unsafe, Helpers } import akka.serialization.SerializerWithStringManifest +private[akka] object Children { + val GetNobody = () ⇒ Nobody +} + private[akka] trait Children { this: ActorCell ⇒ import ChildrenContainer._ @@ -41,14 +45,63 @@ private[akka] trait Children { this: ActorCell ⇒ private[akka] def attachChild(props: Props, name: String, systemService: Boolean): ActorRef = makeChild(this, props, checkName(name), async = true, systemService = systemService) - @volatile private var _nextNameDoNotCallMeDirectly = 0L - final protected def randomName(): String = { - @tailrec def inc(): Long = { - val current = Unsafe.instance.getLongVolatile(this, AbstractActorCell.nextNameOffset) - if (Unsafe.instance.compareAndSwapLong(this, AbstractActorCell.nextNameOffset, current, current + 1)) current - else inc() + @volatile private var _functionRefsDoNotCallMeDirectly = Map.empty[String, FunctionRef] + private def functionRefs: Map[String, FunctionRef] = + Unsafe.instance.getObjectVolatile(this, AbstractActorCell.functionRefsOffset).asInstanceOf[Map[String, FunctionRef]] + + private[akka] def getFunctionRefOrNobody(name: String, uid: Int = ActorCell.undefinedUid): InternalActorRef = + functionRefs.getOrElse(name, Children.GetNobody()) match { + case f: FunctionRef ⇒ + if (uid == ActorCell.undefinedUid || f.path.uid == uid) f else Nobody + case other ⇒ + other } - Helpers.base64(inc()) + + private[akka] def addFunctionRef(f: (ActorRef, Any) ⇒ Unit): FunctionRef = { + val childPath = new ChildActorPath(self.path, randomName(new java.lang.StringBuilder("$$")), ActorCell.newUid()) + val ref = new FunctionRef(childPath, provider, system.eventStream, f) + + @tailrec def rec(): Unit = { + val old = functionRefs + val added = old.updated(childPath.name, ref) + if (!Unsafe.instance.compareAndSwapObject(this, AbstractActorCell.functionRefsOffset, old, added)) rec() + } + rec() + + ref + } + + private[akka] def removeFunctionRef(ref: FunctionRef): Boolean = { + require(ref.path.parent eq self.path, "trying to remove FunctionRef from wrong ActorCell") + val name = ref.path.name + @tailrec def rec(): Boolean = { + val old = functionRefs + if (!old.contains(name)) false + else { + val removed = old - name + if (!Unsafe.instance.compareAndSwapObject(this, AbstractActorCell.functionRefsOffset, old, removed)) rec() + else { + ref.stop() + true + } + } + } + rec() + } + + protected def stopFunctionRefs(): Unit = { + val refs = Unsafe.instance.getAndSetObject(this, AbstractActorCell.functionRefsOffset, Map.empty).asInstanceOf[Map[String, FunctionRef]] + refs.valuesIterator.foreach(_.stop()) + } + + @volatile private var _nextNameDoNotCallMeDirectly = 0L + final protected def randomName(sb: java.lang.StringBuilder): String = { + val num = Unsafe.instance.getAndAddLong(this, AbstractActorCell.nextNameOffset, 1) + Helpers.base64(num, sb) + } + final protected def randomName(): String = { + val num = Unsafe.instance.getAndAddLong(this, AbstractActorCell.nextNameOffset, 1) + Helpers.base64(num) } final def stop(actor: ActorRef): Unit = { @@ -140,14 +193,14 @@ private[akka] trait Children { this: ActorCell ⇒ // optimization for the non-uid case getChildByName(name) match { case Some(crs: ChildRestartStats) ⇒ crs.child.asInstanceOf[InternalActorRef] - case _ ⇒ Nobody + case _ ⇒ getFunctionRefOrNobody(name) } } else { val (childName, uid) = ActorCell.splitNameAndUid(name) getChildByName(childName) match { case Some(crs: ChildRestartStats) if uid == ActorCell.undefinedUid || uid == crs.uid ⇒ crs.child.asInstanceOf[InternalActorRef] - case _ ⇒ Nobody + case _ ⇒ getFunctionRefOrNobody(childName, uid) } } diff --git a/akka-actor/src/main/scala/akka/actor/dungeon/FaultHandling.scala b/akka-actor/src/main/scala/akka/actor/dungeon/FaultHandling.scala index 1bc158a01c..b06c57fc3c 100644 --- a/akka-actor/src/main/scala/akka/actor/dungeon/FaultHandling.scala +++ b/akka-actor/src/main/scala/akka/actor/dungeon/FaultHandling.scala @@ -211,6 +211,7 @@ private[akka] trait FaultHandling { this: ActorCell ⇒ catch handleNonFatalOrInterruptedException { e ⇒ publish(Error(e, self.path.toString, clazz(a), e.getMessage)) } finally try dispatcher.detach(this) finally try parent.sendSystemMessage(DeathWatchNotification(self, existenceConfirmed = true, addressTerminated = false)) + finally try stopFunctionRefs() finally try tellWatchersWeDied() finally try unwatchWatchedActors(a) // stay here as we expect an emergency stop from handleInvokeFailure finally { diff --git a/akka-actor/src/main/scala/akka/dispatch/AbstractDispatcher.scala b/akka-actor/src/main/scala/akka/dispatch/AbstractDispatcher.scala index 9a5a0d4bb4..5cda6f9fb3 100644 --- a/akka-actor/src/main/scala/akka/dispatch/AbstractDispatcher.scala +++ b/akka-actor/src/main/scala/akka/dispatch/AbstractDispatcher.scala @@ -92,17 +92,17 @@ abstract class MessageDispatcher(val configurator: MessageDispatcherConfigurator @volatile private[this] var _inhabitantsDoNotCallMeDirectly: Long = _ // DO NOT TOUCH! @volatile private[this] var _shutdownScheduleDoNotCallMeDirectly: Int = _ // DO NOT TOUCH! - @tailrec private final def addInhabitants(add: Long): Long = { - val c = inhabitants - val r = c + add - if (r < 0) { + private final def addInhabitants(add: Long): Long = { + val old = Unsafe.instance.getAndAddLong(this, inhabitantsOffset, add) + val ret = old + add + if (ret < 0) { // We haven't succeeded in decreasing the inhabitants yet but the simple fact that we're trying to // go below zero means that there is an imbalance and we might as well throw the exception val e = new IllegalStateException("ACTOR SYSTEM CORRUPTED!!! A dispatcher can't have less than 0 inhabitants!") reportFailure(e) throw e } - if (Unsafe.instance.compareAndSwapLong(this, inhabitantsOffset, c, r)) r else addInhabitants(add) + ret } final def inhabitants: Long = Unsafe.instance.getLongVolatile(this, inhabitantsOffset) diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala index 9bdc027770..73fe59c984 100644 --- a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala @@ -22,7 +22,7 @@ class StreamTestKitSpec extends AkkaSpec { } "#toStrict with failing source" in { - val msg = intercept[AssertionError] { + val error = intercept[AssertionError] { Source.fromIterator(() ⇒ new Iterator[Int] { var i = 0 override def hasNext: Boolean = true @@ -35,10 +35,10 @@ class StreamTestKitSpec extends AkkaSpec { } }).runWith(TestSink.probe) .toStrict(300.millis) - }.getMessage + } - msg should include("Boom!") - msg should include("List(1, 2)") + error.getCause.getMessage should include("Boom!") + error.getMessage should include("List(1, 2)") } "#toStrict when subscription was already obtained" in { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StageActorRefSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StageActorRefSpec.scala index 3d866e7221..669a28d125 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StageActorRefSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StageActorRefSpec.scala @@ -186,23 +186,23 @@ object StageActorRefSpec { val p: Promise[Int] = Promise() val logic = new GraphStageLogic(shape) { - implicit def self = stageActorRef // must be a `def`, we want self to be the sender for our replies + implicit def self = stageActor.ref // must be a `def`; we want self to be the sender for our replies var sum: Int = 0 override def preStart(): Unit = { pull(in) - probe ! getStageActorRef(behaviour) + probe ! getStageActor(behaviour).ref } def behaviour(m: (ActorRef, Any)): Unit = { m match { case (sender, Add(n)) ⇒ sum += n case (sender, PullNow) ⇒ pull(in) - case (sender, CallInitStageActorRef) ⇒ sender ! getStageActorRef(behaviour) + case (sender, CallInitStageActorRef) ⇒ sender ! getStageActor(behaviour).ref case (sender, BecomeStringEcho) ⇒ - getStageActorRef({ + getStageActor { case (theSender, msg) ⇒ theSender ! msg.toString - }) + } case (sender, StopNow) ⇒ p.trySuccess(sum) completeStage() @@ -235,4 +235,4 @@ object StageActorRefSpec { } } -} \ No newline at end of file +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala index 8e7aa10f32..4d6087eeff 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorRefBackpressureSinkStage.scala @@ -7,7 +7,7 @@ import java.util import akka.actor._ import akka.dispatch.sysmsg.{ DeathWatchNotification, SystemMessage, Watch } -import akka.stream.stage.GraphStageLogic.StageActorRef +import akka.stream.stage.GraphStageLogic.StageActor import akka.stream.{ Inlet, SinkShape, ActorMaterializer, Attributes } import akka.stream.Attributes.InputBuffer import akka.stream.stage._ @@ -28,7 +28,7 @@ private[akka] class ActorRefBackpressureSinkStage[In](ref: ActorRef, onInitMessa override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - implicit var self: StageActorRef = _ + implicit def self: ActorRef = stageActor.ref val buffer: util.Deque[In] = new util.ArrayDeque[In]() var acknowledgementReceived = false @@ -46,8 +46,7 @@ private[akka] class ActorRefBackpressureSinkStage[In](ref: ActorRef, onInitMessa override def preStart() = { setKeepGoing(true) - self = getStageActorRef(receive) - self.watch(ref) + getStageActor(receive).watch(ref) ref ! onInitMessage pull(in) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index 48c7de3a34..18b324244a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -16,7 +16,7 @@ import akka.stream.impl.ReactiveStreamsCompliance import akka.stream.impl.fusing.GraphStages.detacher import akka.stream.scaladsl.Tcp.{ OutgoingConnection, ServerBinding } import akka.stream.scaladsl.{ BidiFlow, Flow, Tcp ⇒ StreamTcp } -import akka.stream.stage.GraphStageLogic.StageActorRef +import akka.stream.stage.GraphStageLogic.StageActor import akka.stream.stage._ import akka.util.ByteString @@ -48,12 +48,12 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, val bindingPromise = Promise[ServerBinding] val logic = new TimerGraphStageLogic(shape) { - implicit var self: StageActorRef = _ + implicit def self: ActorRef = stageActor.ref var listener: ActorRef = _ var unbindPromise = Promise[Unit]() override def preStart(): Unit = { - self = getStageActorRef(receive) + getStageActor(receive) tcpManager ! Tcp.Bind(self, endpoint, backlog, options, pullMode = true) } @@ -63,7 +63,7 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, msg match { case Bound(localAddress) ⇒ listener = sender - self.watch(listener) + stageActor.watch(listener) if (isAvailable(out)) listener ! ResumeAccepting(1) val target = self bindingPromise.success(ServerBinding(localAddress)(() ⇒ { target ! Unbind; unbindPromise.future })) @@ -118,7 +118,7 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, private def tryUnbind(): Unit = { if (listener ne null) { - self.unwatch(listener) + stageActor.unwatch(listener) setKeepGoing(true) listener ! Unbind } @@ -169,7 +169,7 @@ private[stream] object TcpConnectionStage { * easier to maintain and understand. */ class TcpStreamLogic(val shape: FlowShape[ByteString, ByteString], val role: TcpRole) extends GraphStageLogic(shape) { - implicit private var self: StageActorRef = _ + implicit def self: ActorRef = stageActor.ref private def bytesIn = shape.in private def bytesOut = shape.out @@ -185,14 +185,12 @@ private[stream] object TcpConnectionStage { role match { case Inbound(conn, _) ⇒ setHandler(bytesOut, readHandler) - self = getStageActorRef(connected) connection = conn - self.watch(connection) + getStageActor(connected).watch(connection) connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) pull(bytesIn) case ob @ Outbound(manager, cmd, _, _) ⇒ - self = getStageActorRef(connecting(ob)) - self.watch(manager) + getStageActor(connecting(ob)).watch(manager) manager ! cmd } } @@ -207,9 +205,9 @@ private[stream] object TcpConnectionStage { role.asInstanceOf[Outbound].localAddressPromise.success(c.localAddress) connection = sender setHandler(bytesOut, readHandler) - self.unwatch(ob.manager) - self = getStageActorRef(connected) - self.watch(connection) + stageActor.unwatch(ob.manager) + stageActor.become(connected) + stageActor.watch(connection) connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) if (isAvailable(bytesOut)) connection ! ResumeReading pull(bytesIn) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index ebb250c696..b314a05989 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -128,14 +128,34 @@ object GraphStageLogic { /** * Minimal actor to work with other actors and watch them in a synchronous ways */ - final class StageActorRef(val provider: ActorRefProvider, val log: LoggingAdapter, - getAsyncCallback: StageActorRef.Receive ⇒ AsyncCallback[(ActorRef, Any)], - initialReceive: StageActorRef.Receive, - override val path: ActorPath) extends akka.actor.MinimalActorRef { - import StageActorRef._ + final class StageActor(materializer: ActorMaterializer, + getAsyncCallback: StageActorRef.Receive ⇒ AsyncCallback[(ActorRef, Any)], + initialReceive: StageActorRef.Receive) { private val callback = getAsyncCallback(internalReceive) + private val functionRef: FunctionRef = { + val cell = materializer.supervisor match { + case ref: LocalActorRef ⇒ ref.underlying + case ref: RepointableActorRef if ref.isStarted ⇒ ref.underlying.asInstanceOf[ActorCell] + case unknown ⇒ + throw new IllegalStateException(s"Stream supervisor must be a local actor, was [${unknown.getClass.getName}]") + } + cell.addFunctionRef { + case (_, m @ (PoisonPill | Kill)) ⇒ + materializer.logger.warning("{} message sent to StageActor({}) will be ignored, since it is not a real Actor." + + "Use a custom message type to communicate with it instead.", m, functionRef.path) + case pair ⇒ callback.invoke(pair) + } + } + + /** + * The ActorRef by which this StageActor can be contacted from the outside. + * This is a full-fledged ActorRef that supports watching and being watched + * as well as location transparent (remote) communication. + */ + def ref: ActorRef = functionRef + @volatile private[this] var behaviour = initialReceive @@ -143,32 +163,14 @@ object GraphStageLogic { private[akka] def internalReceive(pack: (ActorRef, Any)): Unit = { pack._2 match { case Terminated(ref) ⇒ - if (watching contains ref) { - watching -= ref + if (functionRef.isWatching(ref)) { + functionRef.unwatch(ref) behaviour(pack) } case _ ⇒ behaviour(pack) } } - override def !(message: Any)(implicit sender: ActorRef = Actor.noSender): Unit = { - message match { - case m @ (PoisonPill | Kill) ⇒ - log.warning("{} message sent to StageActorRef({}) will be ignored, since it is not a real Actor." + - "Use a custom message type to communicate with it instead.", m, path) - case _ ⇒ - callback.invoke((sender, message)) - } - } - - override def sendSystemMessage(message: SystemMessage): Unit = message match { - case w: Watch ⇒ addWatcher(w.watchee, w.watcher) - case u: Unwatch ⇒ remWatcher(u.watchee, u.watcher) - case DeathWatchNotification(actorRef, _, _) ⇒ - this.!(Terminated(actorRef)(existenceConfirmed = true, addressTerminated = false)) - case _ ⇒ //ignore all other messages - } - /** * Special `become` allowing to swap the behaviour of this StageActorRef. * Unbecome is not available. @@ -177,92 +179,14 @@ object GraphStageLogic { behaviour = receive } - private[this] var watching = ActorCell.emptyActorRefSet - private[this] val _watchedBy = new AtomicReference[Set[ActorRef]](ActorCell.emptyActorRefSet) + def stop(): Unit = functionRef.stop() - override def isTerminated = _watchedBy.get() == StageTerminatedTombstone + def watch(actorRef: ActorRef): Unit = functionRef.watch(actorRef) - //noinspection EmptyCheck - protected def sendTerminated(): Unit = { - val watchedBy = _watchedBy.getAndSet(StageTerminatedTombstone) - if (watchedBy != StageTerminatedTombstone) { - if (watchedBy.nonEmpty) { - watchedBy foreach sendTerminated(ifLocal = false) - watchedBy foreach sendTerminated(ifLocal = true) - } - if (watching.nonEmpty) { - watching foreach unwatchWatched - watching = Set.empty - } - } - } - - private def sendTerminated(ifLocal: Boolean)(watcher: ActorRef): Unit = - if (watcher.asInstanceOf[ActorRefScope].isLocal == ifLocal) - watcher.asInstanceOf[InternalActorRef].sendSystemMessage(DeathWatchNotification(this, existenceConfirmed = true, addressTerminated = false)) - - private def unwatchWatched(watched: ActorRef): Unit = - watched.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(watched, this)) - - override def stop(): Unit = sendTerminated() - - @tailrec final def addWatcher(watchee: ActorRef, watcher: ActorRef): Unit = - _watchedBy.get() match { - case StageTerminatedTombstone ⇒ - sendTerminated(ifLocal = true)(watcher) - sendTerminated(ifLocal = false)(watcher) - - case watchedBy ⇒ - val watcheeSelf = watchee == this - val watcherSelf = watcher == this - - if (watcheeSelf && !watcherSelf) { - if (!watchedBy.contains(watcher)) - if (!_watchedBy.compareAndSet(watchedBy, watchedBy + watcher)) - addWatcher(watchee, watcher) // try again - } else if (!watcheeSelf && watcherSelf) { - log.warning("externally triggered watch from {} to {} is illegal on StageActorRef", watcher, watchee) - } else { - log.error("BUG: illegal Watch(%s,%s) for %s".format(watchee, watcher, this)) - } - } - - @tailrec final def remWatcher(watchee: ActorRef, watcher: ActorRef): Unit = { - _watchedBy.get() match { - case StageTerminatedTombstone ⇒ // do nothing... - case watchedBy ⇒ - val watcheeSelf = watchee == this - val watcherSelf = watcher == this - - if (watcheeSelf && !watcherSelf) { - if (watchedBy.contains(watcher)) - if (!_watchedBy.compareAndSet(watchedBy, watchedBy - watcher)) - remWatcher(watchee, watcher) // try again - } else if (!watcheeSelf && watcherSelf) { - log.warning("externally triggered unwatch from {} to {} is illegal on StageActorRef", watcher, watchee) - } else { - log.error("BUG: illegal Unwatch(%s,%s) for %s".format(watchee, watcher, this)) - } - } - } - - def watch(actorRef: ActorRef): Unit = { - watching += actorRef - actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Watch(actorRef.asInstanceOf[InternalActorRef], this)) - } - - def unwatch(actorRef: ActorRef): Unit = { - watching -= actorRef - actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(actorRef.asInstanceOf[InternalActorRef], this)) - } + def unwatch(actorRef: ActorRef): Unit = functionRef.unwatch(actorRef) } object StageActorRef { type Receive = ((ActorRef, Any)) ⇒ Unit - - val StageTerminatedTombstone = null - - // globally sequential, one should not depend on these names in any case - val name = SeqActorName("StageActorRef") } } @@ -950,8 +874,8 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: final protected def createAsyncCallback[T](handler: Procedure[T]): AsyncCallback[T] = getAsyncCallback(handler.apply) - private var _stageActorRef: StageActorRef = _ - final def stageActorRef: ActorRef = _stageActorRef match { + private var _stageActor: StageActor = _ + final def stageActor: StageActor = _stageActor match { case null ⇒ throw StageActorRefNotInitializedException() case ref ⇒ ref } @@ -974,14 +898,12 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * @return minimal actor with watch method */ // FIXME: I don't like the Pair allocation :( - final protected def getStageActorRef(receive: ((ActorRef, Any)) ⇒ Unit): StageActorRef = { - _stageActorRef match { + final protected def getStageActor(receive: ((ActorRef, Any)) ⇒ Unit): StageActor = { + _stageActor match { case null ⇒ val actorMaterializer = ActorMaterializer.downcast(interpreter.materializer) - val provider = actorMaterializer.supervisor.asInstanceOf[InternalActorRef].provider - val path = actorMaterializer.supervisor.path / StageActorRef.name.next() - _stageActorRef = new StageActorRef(provider, actorMaterializer.logger, getAsyncCallback, receive, path) - _stageActorRef + _stageActor = new StageActor(actorMaterializer, getAsyncCallback, receive) + _stageActor case existing ⇒ existing.become(receive) existing @@ -995,9 +917,9 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // Internal hooks to avoid reliance on user calling super in postStop /** INTERNAL API */ protected[stream] def afterPostStop(): Unit = { - if (_stageActorRef ne null) { - _stageActorRef.stop() - _stageActorRef = null + if (_stageActor ne null) { + _stageActor.stop() + _stageActor = null } } diff --git a/akka-typed/src/main/scala/akka/typed/Impl.scala b/akka-typed/src/main/scala/akka/typed/Impl.scala index f35dadfa38..173fe51a95 100644 --- a/akka-typed/src/main/scala/akka/typed/Impl.scala +++ b/akka-typed/src/main/scala/akka/typed/Impl.scala @@ -82,12 +82,20 @@ private[typed] class ActorContextAdapter[T](ctx: akka.actor.ActorContext) extend def spawn[U](props: Props[U], name: String) = ctx.spawn(props, name) def actorOf(props: a.Props) = ctx.actorOf(props) def actorOf(props: a.Props, name: String) = ctx.actorOf(props, name) - def stop(child: ActorRef[Nothing]) = ctx.child(child.path.name) match { - case Some(ref) if ref == child.untypedRef ⇒ - ctx.stop(child.untypedRef) - true - case _ ⇒ false // none of our business - } + def stop(child: ActorRef[Nothing]) = + child.untypedRef match { + case f: akka.actor.FunctionRef ⇒ + val cell = ctx.asInstanceOf[akka.actor.ActorCell] + cell.removeFunctionRef(f) + case _ ⇒ + ctx.child(child.path.name) match { + case Some(ref) if ref == child.untypedRef ⇒ + ctx.stop(child.untypedRef) + true + case _ ⇒ + false // none of our business + } + } def watch[U](other: ActorRef[U]) = { ctx.watch(other.untypedRef); other } def watch(other: a.ActorRef) = { ctx.watch(other); other } def unwatch[U](other: ActorRef[U]) = { ctx.unwatch(other.untypedRef); other } @@ -98,7 +106,11 @@ private[typed] class ActorContextAdapter[T](ctx: akka.actor.ActorContext) extend import ctx.dispatcher ctx.system.scheduler.scheduleOnce(delay, target.untypedRef, msg) } - def spawnAdapter[U](f: U ⇒ T) = ActorRef[U](ctx.actorOf(akka.actor.Props(classOf[MessageWrapper], f))) + def spawnAdapter[U](f: U ⇒ T) = { + val cell = ctx.asInstanceOf[akka.actor.ActorCell] + val ref = cell.addFunctionRef((_, msg) ⇒ ctx.self ! f(msg.asInstanceOf[U])) + ActorRef[U](ref) + } } /** diff --git a/akka-typed/src/test/scala/akka/typed/ActorContextSpec.scala b/akka-typed/src/test/scala/akka/typed/ActorContextSpec.scala index 756a6dc2b1..708cfdc175 100644 --- a/akka-typed/src/test/scala/akka/typed/ActorContextSpec.scala +++ b/akka-typed/src/test/scala/akka/typed/ActorContextSpec.scala @@ -66,6 +66,9 @@ object ActorContextSpec { final case class BecomeCareless(replyTo: ActorRef[BecameCareless.type]) extends Command case object BecameCareless extends Event + final case class GetAdapter(replyTo: ActorRef[Adapter]) extends Command + final case class Adapter(a: ActorRef[Command]) extends Event + def subject(monitor: ActorRef[GotSignal]): Behavior[Command] = FullTotal { case Sig(ctx, signal) ⇒ @@ -142,6 +145,9 @@ object ActorContextSpec { monitor ! GotSignal(sig) Same } + case GetAdapter(replyTo) ⇒ + replyTo ! Adapter(ctx.spawnAdapter(identity)) + Same } } } @@ -503,6 +509,26 @@ class ActorContextSpec extends TypedSpec(ConfigFactory.parseString( msgs should ===(Scheduled :: Pong2 :: Nil) } }) + + def `40 must create a working adapter`(): Unit = sync(setup("ctx40") { (ctx, startWith) ⇒ + startWith.keep { subj ⇒ + subj ! GetAdapter(ctx.self) + }.expectMessage(500.millis) { (msg, subj) ⇒ + val Adapter(adapter) = msg + ctx.watch(adapter) + adapter ! Ping(ctx.self) + (subj, adapter) + }.expectMessage(500.millis) { + case (msg, (subj, adapter)) ⇒ + msg should ===(Pong1) + ctx.stop(subj) + adapter + }.expectMessageKeep(500.millis) { (msg, _) ⇒ + msg should ===(GotSignal(PostStop)) + }.expectTermination(500.millis) { (t, adapter) ⇒ + t.ref should ===(adapter) + } + }) } object `An ActorContext` extends Tests { diff --git a/project/MiMa.scala b/project/MiMa.scala index aae3626cef..8da1aadbfe 100644 --- a/project/MiMa.scala +++ b/project/MiMa.scala @@ -604,6 +604,9 @@ object MiMa extends AutoPlugin { ProblemFilters.exclude[MissingMethodProblem]("akka.pattern.BackoffSupervisor.akka$pattern$BackoffSupervisor$$restartCount"), ProblemFilters.exclude[MissingMethodProblem]("akka.pattern.BackoffSupervisor.akka$pattern$BackoffSupervisor$$restartCount_="), ProblemFilters.exclude[MissingMethodProblem]("akka.pattern.BackoffSupervisor.akka$pattern$BackoffSupervisor$$child") + ), + "2.4.1" -> Seq( + FilterAnyProblem("akka.actor.dungeon.Children") ) ) }