From d93143042bdbeb879a3e2725937ec791ef0de175 Mon Sep 17 00:00:00 2001 From: Viktor Klang Date: Mon, 23 Apr 2012 10:45:59 +0200 Subject: [PATCH] Adding TypedActor.context for the lifecycle methods --- .../scala/akka/actor/TypedActorSpec.scala | 21 ++++++++---- .../main/scala/akka/actor/TypedActor.scala | 32 ++++++++++++------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/akka-actor-tests/src/test/scala/akka/actor/TypedActorSpec.scala b/akka-actor-tests/src/test/scala/akka/actor/TypedActorSpec.scala index 26f510d08a..9440c18fc3 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/TypedActorSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/TypedActorSpec.scala @@ -19,6 +19,7 @@ import akka.dispatch.{ Await, Dispatchers, Future, Promise } import akka.pattern.ask import akka.serialization.JavaSerializer import akka.actor.TypedActor._ +import java.lang.IllegalStateException object TypedActorSpec { @@ -162,20 +163,26 @@ object TypedActorSpec { class LifeCyclesImpl(val latch: CountDownLatch) extends PreStart with PostStop with PreRestart with PostRestart with LifeCycles with Receiver { + private def ensureContextAvailable[T](f: ⇒ T): T = TypedActor.context match { + case null ⇒ throw new IllegalStateException("TypedActor.context is null!") + case some ⇒ f + } + override def crash(): Unit = throw new IllegalStateException("Crash!") - override def preStart(): Unit = latch.countDown() + override def preStart(): Unit = ensureContextAvailable(latch.countDown()) - override def postStop(): Unit = for (i ← 1 to 3) latch.countDown() + override def postStop(): Unit = ensureContextAvailable(for (i ← 1 to 3) latch.countDown()) - override def preRestart(reason: Throwable, message: Option[Any]): Unit = for (i ← 1 to 5) latch.countDown() + override def preRestart(reason: Throwable, message: Option[Any]): Unit = ensureContextAvailable(for (i ← 1 to 5) latch.countDown()) - override def postRestart(reason: Throwable): Unit = for (i ← 1 to 7) latch.countDown() + override def postRestart(reason: Throwable): Unit = ensureContextAvailable(for (i ← 1 to 7) latch.countDown()) override def onReceive(msg: Any, sender: ActorRef): Unit = { - msg match { - case "pigdog" ⇒ sender ! "dogpig" - } + ensureContextAvailable( + msg match { + case "pigdog" ⇒ sender ! "dogpig" + }) } } } diff --git a/akka-actor/src/main/scala/akka/actor/TypedActor.scala b/akka-actor/src/main/scala/akka/actor/TypedActor.scala index 319bd10a50..f775042566 100644 --- a/akka-actor/src/main/scala/akka/actor/TypedActor.scala +++ b/akka-actor/src/main/scala/akka/actor/TypedActor.scala @@ -227,15 +227,19 @@ object TypedActor extends ExtensionId[TypedActorExtension] with ExtensionIdProvi case _ ⇒ super.supervisorStrategy } - override def preStart(): Unit = me match { - case l: PreStart ⇒ l.preStart() - case _ ⇒ super.preStart() + override def preStart(): Unit = withContext { + me match { + case l: PreStart ⇒ l.preStart() + case _ ⇒ super.preStart() + } } override def postStop(): Unit = try { - me match { - case l: PostStop ⇒ l.postStop() - case _ ⇒ super.postStop() + withContext { + me match { + case l: PostStop ⇒ l.postStop() + case _ ⇒ super.postStop() + } } } finally { TypedActor(context.system).invocationHandlerFor(proxyVar.get) match { @@ -246,14 +250,18 @@ object TypedActor extends ExtensionId[TypedActorExtension] with ExtensionIdProvi } } - override def preRestart(reason: Throwable, message: Option[Any]): Unit = me match { - case l: PreRestart ⇒ l.preRestart(reason, message) - case _ ⇒ super.preRestart(reason, message) + override def preRestart(reason: Throwable, message: Option[Any]): Unit = withContext { + me match { + case l: PreRestart ⇒ l.preRestart(reason, message) + case _ ⇒ super.preRestart(reason, message) + } } - override def postRestart(reason: Throwable): Unit = me match { - case l: PostRestart ⇒ l.postRestart(reason) - case _ ⇒ super.postRestart(reason) + override def postRestart(reason: Throwable): Unit = withContext { + me match { + case l: PostRestart ⇒ l.postRestart(reason) + case _ ⇒ super.postRestart(reason) + } } protected def withContext[T](unitOfWork: ⇒ T): T = {