Provide cancellation for CoordinatedShutdown tasks #27335

This commit is contained in:
Matthew Smedberg 2019-10-15 05:01:13 -06:00 committed by Johan Andrén
parent b5eb18a033
commit 3e71b8a8b8
5 changed files with 441 additions and 75 deletions

View file

@ -5,22 +5,19 @@
package akka.actor
import java.util
import scala.concurrent.duration._
import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
import akka.Done
import akka.testkit.{ AkkaSpec, EventFilter, TestKit, TestProbe }
import com.typesafe.config.{ Config, ConfigFactory }
import akka.actor.CoordinatedShutdown.Phase
import akka.actor.CoordinatedShutdown.UnknownReason
import akka.util.ccompat.JavaConverters._
import scala.concurrent.Promise
import java.util.concurrent.TimeoutException
import java.util.concurrent.{ Executors, TimeoutException }
import akka.ConfigurationException
import akka.dispatch.ExecutionContexts
class CoordinatedShutdownSpec
extends AkkaSpec(ConfigFactory.parseString("""
@ -150,8 +147,203 @@ class CoordinatedShutdownSpec
testActor ! "C"
Future.successful(Done)
}
Await.result(co.run(UnknownReason), remainingOrDefault)
receiveN(4) should ===(List("A", "B", "B", "C"))
whenReady(co.run(UnknownReason)) { _ =>
receiveN(4) should ===(List("A", "B", "B", "C"))
}
}
"cancel shutdown tasks" in {
import system.dispatcher
val phases = Map("a" -> emptyPhase)
val co = new CoordinatedShutdown(extSys, phases)
val probe = TestProbe()
def createTask(message: String): () => Future[Done] =
() =>
Future {
probe.ref ! message
Done
}
val task1 = co.addCancellableTask("a", "copy1")(createTask("copy1"))
val task2 = co.addCancellableTask("a", "copy2")(createTask("copy2"))
val task3 = co.addCancellableTask("a", "copy3")(createTask("copy3"))
assert(!task1.isCancelled)
assert(!task2.isCancelled)
assert(!task3.isCancelled)
task2.cancel()
assert(task2.isCancelled)
val messagesFut = Future {
probe.receiveN(2, 3.seconds).map(_.toString)
}
whenReady(co.run(UnknownReason).flatMap(_ => messagesFut), timeout(250.milliseconds)) { messages =>
messages.distinct.size shouldEqual 2
messages.foreach {
case "copy1" | "copy3" => // OK
case other => fail(s"Unexpected probe message ${other}!")
}
}
}
"re-register the same task if requested" in {
import system.dispatcher
val phases = Map("a" -> emptyPhase)
val co = new CoordinatedShutdown(extSys, phases)
val testProbe = TestProbe()
val taskName = "labor"
val task: () => Future[Done] = () =>
Future {
testProbe.ref ! taskName
Done
}
val task1 = co.addCancellableTask("a", taskName)(task)
val task2 = co.addCancellableTask("a", taskName)(task)
val task3 = co.addCancellableTask("a", taskName)(task)
List(task1, task2, task3).foreach { t =>
assert(!t.isCancelled)
}
task1.cancel()
assert(task1.isCancelled)
val messagesFut = Future {
testProbe.receiveN(2, 3.seconds).map(_.toString)
}
whenReady(co.run(UnknownReason).flatMap(_ => messagesFut), timeout(250.milliseconds)) { messages =>
messages.distinct.size shouldEqual 1
messages.head shouldEqual taskName
}
}
"honor registration and cancellation in later phases" in {
import system.dispatcher
val phases = Map("a" -> emptyPhase, "b" -> phase("a"))
val co = new CoordinatedShutdown(extSys, phases)
val testProbe = TestProbe()
object TaskAB {
val taskA: Cancellable = co.addCancellableTask("a", "taskA") { () =>
Future {
taskB.cancel()
testProbe.ref ! "A cancels B"
Done
}
}
val taskB: Cancellable = co.addCancellableTask("b", "taskB") { () =>
Future {
taskA.cancel()
testProbe.ref ! "B cancels A"
Done
}
}
}
co.addCancellableTask("a", "taskA") { () =>
Future {
co.addCancellableTask("b", "dependentTaskB") { () =>
Future {
testProbe.ref ! "A adds B"
Done
}
}
Done
}
}
co.addCancellableTask("a", "taskA") { () =>
Future {
co.addCancellableTask("a", "dependentTaskA") { () =>
Future {
testProbe.ref ! "A adds A"
Done
}
}
Done
}
}
co.addCancellableTask("b", "taskB") { () =>
Future {
co.addCancellableTask("a", "dependentTaskA") { () =>
Future {
testProbe.ref ! "B adds A"
Done
}
}
Done
}
}
List(TaskAB.taskA, TaskAB.taskB).foreach { t =>
t.isCancelled shouldBe false
}
val messagesFut = Future {
testProbe.receiveN(2, 3.seconds).map(_.toString)
}
whenReady(co.run(UnknownReason).flatMap(_ => messagesFut), timeout(250.milliseconds)) { messages =>
messages.toSet shouldEqual Set("A adds B", "A cancels B")
}
}
"cancel tasks across threads" in {
val phases = Map("a" -> emptyPhase, "b" -> phase("a"))
val co = new CoordinatedShutdown(extSys, phases)
val testProbe = TestProbe()
val executor = Executors.newFixedThreadPool(25)
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor)
case class BMessage(content: String)
val messageA = "concurrentA"
val task: () => Future[Done] = () =>
Future {
testProbe.ref ! messageA
co.addCancellableTask("b", "concurrentB") { () =>
Future {
testProbe.ref ! BMessage("concurrentB")
Done
}(ExecutionContexts.sameThreadExecutionContext)
}
Done
}(ExecutionContexts.sameThreadExecutionContext)
val cancellationFut: Future[Done] = {
val cancellables = (0 until 20).map { _ =>
co.addCancellableTask("a", "concurrentTaskA")(task)
}
val shouldBeCancelled = cancellables.zipWithIndex.collect {
case (c, i) if i % 2 == 0 => c
}
val cancelFutures = for {
_ <- cancellables
c <- shouldBeCancelled
} yield Future {
c.cancel() shouldBe true
Done
}
cancelFutures.foldLeft(Future.successful(Done)) {
case (acc, fut) =>
acc.flatMap(_ => fut)
}
}
Await.result(cancellationFut, 250.milliseconds)
val messagesFut = Future {
testProbe.receiveN(20, 3.seconds).map(_.toString)
}
whenReady(co.run(UnknownReason).flatMap(_ => messagesFut), timeout(250.milliseconds)) { messages =>
messages.length shouldEqual 20
messages.toSet shouldEqual Set(messageA, "BMessage(concurrentB)")
}
executor.shutdown()
}
"run from a given phase" in {

View file

@ -9,23 +9,20 @@ import scala.compat.java8.FutureConverters._
import scala.compat.java8.OptionConverters._
import java.util.concurrent._
import java.util.concurrent.TimeUnit.MILLISECONDS
import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
import scala.concurrent.Future
import scala.concurrent.Promise
import akka.Done
import com.typesafe.config.Config
import scala.concurrent.duration.FiniteDuration
import scala.annotation.tailrec
import com.typesafe.config.ConfigFactory
import akka.pattern.after
import scala.util.control.NonFatal
import akka.event.Logging
import akka.dispatch.ExecutionContexts
import scala.util.Try
import scala.concurrent.Await
import java.util.concurrent.atomic.AtomicReference
import java.util.function.Supplier
import java.util.Optional
@ -354,8 +351,7 @@ final class CoordinatedShutdown private[akka] (
system: ExtendedActorSystem,
phases: Map[String, CoordinatedShutdown.Phase])
extends Extension {
import CoordinatedShutdown.Reason
import CoordinatedShutdown.UnknownReason
import CoordinatedShutdown.{ Reason, UnknownReason }
/** INTERNAL API */
private[akka] val log = Logging(system, getClass)
@ -363,7 +359,141 @@ final class CoordinatedShutdown private[akka] (
/** INTERNAL API */
private[akka] val orderedPhases = CoordinatedShutdown.topologicalSort(phases)
private val tasks = new ConcurrentHashMap[String, Vector[(String, () => Future[Done])]]
private trait PhaseDefinition {
def size: Int
def run(recoverEnabled: Boolean)(implicit ec: ExecutionContext): Future[Done]
}
private object tasks {
private val registeredPhases = new ConcurrentHashMap[String, StrictPhaseDefinition]()
private trait TaskDefinition extends Cancellable {
private[tasks] def run(recoverEnabled: Boolean)(implicit ec: ExecutionContext): Future[Done]
}
private object TaskDefinition {
def apply(phaseName: String, task: () => Future[Done], name: String): TaskDefinition = new TaskDefinition {
// This is a vanilla class instead of a case class to avoid default implementations of .hashCode() and .equals
// Different code paths could register the same task under the same name multiple times; in this case we want to
// run that task as many times as it was registered (minus the number of times those were cancelled), so they must
// be distinct in a Set[TaskDefinition].
private sealed trait TaskState
private case object Pending extends TaskState
private case object Cancelled extends TaskState
private case class Running(job: Promise[Done]) extends TaskState
private val taskState = new AtomicReference[TaskState](Pending)
@tailrec
override private[tasks] def run(recoverEnabled: Boolean)(implicit ec: ExecutionContext): Future[Done] = {
val job = Promise[Done]()
val nextTaskState = taskState.updateAndGet {
case Pending => Running(job)
case Cancelled => Cancelled
case Running(otherJob) => Running(otherJob)
}
nextTaskState match {
case Running(runningJob) if runningJob == job =>
// only start the job if atomic update succeeds and we were the winner of any race
if (log.isDebugEnabled) {
log.debug("Performing task [{}] in CoordinatedShutdown phase [{}]", name, phaseName)
}
job.completeWith(try {
task.apply().recover {
case NonFatal(exc) if recoverEnabled =>
log.warning("Task [{}] failed in phase [{}]: {}", name, phaseName, exc.getMessage)
Done
}
} catch { // in case task.apply() throws
case NonFatal(exc) if recoverEnabled =>
log.warning(
"Task [{}] in phase [{}] threw an exception before its future could be constructed: {}",
name,
phaseName,
exc.getMessage)
Future.successful(Done)
case NonFatal(exc) =>
Future.failed(exc)
})
job.future
case Running(otherJob) =>
log.warning("Task [{}] in phase [{}] was invoked multiple times and deduplicated.", name, phaseName)
otherJob.future
case Cancelled =>
Future.successful(Done)
case Pending =>
log.error("Atomic update produced an impossible value; this should never happen!")
run(recoverEnabled)
}
}
override def cancel(): Boolean = {
val nextTaskState = taskState.updateAndGet {
case Pending => Cancelled
case other => other
}
nextTaskState match {
case Cancelled =>
registeredPhases
.merge(phaseName, StrictPhaseDefinition.empty, (previous, incoming) => previous.merge(incoming))
if (log.isDebugEnabled) {
log.debug("Successfully cancelled CoordinatedShutdown task [{}] from phase [{}].", name, phaseName)
}
true
case _ =>
false
}
}
// must be side-effect free
override def isCancelled: Boolean = {
taskState.get() == Cancelled
}
}
}
private case class StrictPhaseDefinition(tasks: Set[TaskDefinition]) extends PhaseDefinition {
// This is a case class so that the update methods on ConcurrentHashMap can correctly deal with equality
override val size: Int = tasks.size
override def run(recoverEnabled: Boolean)(implicit ec: ExecutionContext): Future[Done] = {
Future.sequence(tasks.map(_.run(recoverEnabled))).map(_ => Done)(ExecutionContexts.sameThreadExecutionContext)
}
// This method may be run multiple times during the compare-and-set loop of ConcurrentHashMap, so it must be side-effect-free
def merge(other: StrictPhaseDefinition): StrictPhaseDefinition = {
val nextTasks = (tasks ++ other.tasks).filterNot(_.isCancelled)
copy(tasks = nextTasks)
}
}
private object StrictPhaseDefinition {
def single(taskDefinition: TaskDefinition): StrictPhaseDefinition = {
StrictPhaseDefinition(Set(taskDefinition))
}
val empty: StrictPhaseDefinition = StrictPhaseDefinition(Set.empty)
}
def get(phaseName: String): Option[PhaseDefinition] = Option(registeredPhases.get(phaseName))
def totalDuration(): FiniteDuration = {
import akka.util.ccompat.JavaConverters._
registeredPhases.keySet.asScala.foldLeft(Duration.Zero) {
case (acc, phase) =>
acc + timeout(phase)
}
}
def register(phaseName: String, task: () => Future[Done], name: String): Cancellable = {
val cancellable: TaskDefinition = TaskDefinition(phaseName, task, name)
registeredPhases.merge(
phaseName,
StrictPhaseDefinition.single(cancellable),
(previous, incoming) => previous.merge(incoming))
cancellable
}
}
private val runStarted = new AtomicReference[Option[Reason]](None)
private val runPromise = Promise[Done]()
@ -381,6 +511,49 @@ final class CoordinatedShutdown private[akka] (
*/
private[akka] def jvmHooksLatch: CountDownLatch = _jvmHooksLatch.get
/**
* Scala API: Add a task to a phase, returning an object which will cancel it
* on demand and remove it from the task pool (so long as the same task has not
* been added elsewhere). Tasks in a phase are run concurrently, with no ordering
* assumed.
*
* Adding a task to a phase does not remove any other tasks from the phase.
*
* If the same task is added multiple times, each addition will be run unless cancelled.
*
* Tasks should typically be registered as early as possible -- once coordinated
* shutdown begins, tasks may be added without ever being run. A task may add tasks
* to a later stage with confidence that they will be run.
*/
def addCancellableTask(phase: String, taskName: String)(task: () => Future[Done]): Cancellable = {
require(
knownPhases(phase),
s"Unknown phase [$phase], known phases [$knownPhases]. All phases (along with their optional dependencies) must be defined in configuration")
require(
taskName.nonEmpty,
"Set a task name when adding tasks to the Coordinated Shutdown. " +
"Try to use unique, self-explanatory names.")
tasks.register(phase, task, taskName)
}
/**
* Java API: Add a task to a phase, returning an object which will cancel it
* on demand and remove it from the task pool (so long as the same task has not
* been added elsewhere). Tasks in a phase are run concurrently, with no ordering
* assumed.
*
* Adding a task to a phase does not remove any other tasks from the phase.
*
* If the same task is added multiple times, each addition will be run unless cancelled.
*
* Tasks should typically be registered as early as possible -- once coordinated
* shutdown begins, tasks may be added without ever being run. A task may add tasks
* to a later stage with confidence that they will be run.
*/
def addCancellableTask(phase: String, taskName: String, task: Supplier[CompletionStage[Done]]): Cancellable = {
addCancellableTask(phase, taskName)(() => task.get().toScala)
}
/**
* Scala API: Add a task to a phase. It doesn't remove previously added tasks.
* Tasks added to the same phase are executed in parallel without any
@ -393,7 +566,7 @@ final class CoordinatedShutdown private[akka] (
* It is possible to add a task to a later phase by a task in an earlier phase
* and it will be performed.
*/
@tailrec def addTask(phase: String, taskName: String)(task: () => Future[Done]): Unit = {
def addTask(phase: String, taskName: String)(task: () => Future[Done]): Unit = {
require(
knownPhases(phase),
s"Unknown phase [$phase], known phases [$knownPhases]. " +
@ -402,14 +575,7 @@ final class CoordinatedShutdown private[akka] (
taskName.nonEmpty,
"Set a task name when adding tasks to the Coordinated Shutdown. " +
"Try to use unique, self-explanatory names.")
val current = tasks.get(phase)
if (current == null) {
if (tasks.putIfAbsent(phase, Vector(taskName -> task)) != null)
addTask(phase, taskName)(task) // CAS failed, retry
} else {
if (!tasks.replace(phase, current, current :+ (taskName -> task)))
addTask(phase, taskName)(task) // CAS failed, retry
}
tasks.register(phase, task, taskName)
}
/**
@ -516,69 +682,43 @@ final class CoordinatedShutdown private[akka] (
*/
def run(reason: Reason, fromPhase: Option[String]): Future[Done] = {
if (runStarted.compareAndSet(None, Some(reason))) {
implicit val ec = system.dispatchers.internalDispatcher
implicit val ec: ExecutionContext = system.dispatchers.internalDispatcher
val debugEnabled = log.isDebugEnabled
log.debug("Running CoordinatedShutdown with reason [{}]", reason)
def loop(remainingPhases: List[String]): Future[Done] = {
remainingPhases match {
case Nil => Future.successful(Done)
case phase :: remaining if !phases(phase).enabled =>
tasks.get(phase) match {
case null => // This pretty much is ok as there are no tasks
case tasks =>
log.info("Phase [{}] disabled through configuration, skipping [{}] tasks", phase, tasks.size)
case phaseName :: remaining if !phases(phaseName).enabled =>
tasks.get(phaseName).foreach { phaseDef =>
log.info(s"Phase [{}] disabled through configuration, skipping [{}] tasks.", phaseName, phaseDef.size)
}
loop(remaining)
case phase :: remaining =>
val phaseResult = tasks.get(phase) match {
case null =>
if (debugEnabled) log.debug("Performing phase [{}] with [0] tasks", phase)
case phaseName :: remaining =>
val phaseResult = tasks.get(phaseName) match {
case None =>
if (debugEnabled) log.debug("Performing phase [{}] with [0] tasks", phaseName)
Future.successful(Done)
case tasks =>
if (debugEnabled)
log.debug(
"Performing phase [{}] with [{}] tasks: [{}]",
phase,
tasks.size,
tasks.map { case (taskName, _) => taskName }.mkString(", "))
// note that tasks within same phase are performed in parallel
val recoverEnabled = phases(phase).recover
val result = Future
.sequence(tasks.map {
case (taskName, task) =>
try {
val r = task.apply()
if (recoverEnabled) r.recover {
case NonFatal(e) =>
log.warning("Task [{}] failed in phase [{}]: {}", taskName, phase, e.getMessage)
Done
} else r
} catch {
case NonFatal(e) =>
// in case task.apply throws
if (recoverEnabled) {
log.warning("Task [{}] failed in phase [{}]: {}", taskName, phase, e.getMessage)
Future.successful(Done)
} else
Future.failed(e)
}
})
.map(_ => Done)(ExecutionContexts.sameThreadExecutionContext)
val timeout = phases(phase).timeout
case Some(phaseDef) =>
if (debugEnabled) {
log.debug("Performing phase [{}] with [{}] tasks.", phaseName, phaseDef.size)
}
val recoverEnabled = phases(phaseName).recover
val result = phaseDef.run(recoverEnabled)
val timeout = phases(phaseName).timeout
val deadline = Deadline.now + timeout
val timeoutFut = try {
after(timeout, system.scheduler) {
if (phase == CoordinatedShutdown.PhaseActorSystemTerminate && deadline.hasTimeLeft) {
if (phaseName == CoordinatedShutdown.PhaseActorSystemTerminate && deadline.hasTimeLeft) {
// too early, i.e. triggered by system termination
result
} else if (result.isCompleted)
Future.successful(Done)
else if (recoverEnabled) {
log.warning("Coordinated shutdown phase [{}] timed out after {}", phase, timeout)
log.warning("Coordinated shutdown phase [{}] timed out after {}", phaseName, timeout)
Future.successful(Done)
} else
Future.failed(
new TimeoutException(s"Coordinated shutdown phase [$phase] timed out after $timeout"))
new TimeoutException(s"Coordinated shutdown phase [$phaseName] timed out after $timeout"))
}
} catch {
case _: IllegalStateException =>
@ -638,10 +778,7 @@ final class CoordinatedShutdown private[akka] (
* Sum of timeouts of all phases that have some task.
*/
def totalTimeout(): FiniteDuration = {
import akka.util.ccompat.JavaConverters._
tasks.keySet.asScala.foldLeft(Duration.Zero) {
case (acc, phase) => acc + timeout(phase)
}
tasks.totalDuration()
}
/**

View file

@ -25,6 +25,14 @@ Scala
Java
: @@snip [ActorDocTest.java](/akka-docs/src/test/java/jdocs/actor/ActorDocTest.java) { #coordinated-shutdown-addTask }
If cancellation of previously added tasks is required:
Scala
: @@snip [ActorDocSpec.scala](/akka-docs/src/test/scala/docs/actor/ActorDocSpec.scala) { #coordinated-shutdown-cancellable }
Java
: @@snip [ActorDocTest.java](/akka-docs/src/test/java/jdocs/actor/ActorDocTest.java) { #coordinated-shutdown-cancellable }
The returned @scala[`Future[Done]`] @java[`CompletionStage<Done>`] should be completed when the task is completed. The task name parameter
is only used for debugging/logging.

View file

@ -13,7 +13,6 @@ import static jdocs.actor.Messages.Swap.Swap;
import static jdocs.actor.Messages.*;
import akka.actor.CoordinatedShutdown;
import akka.util.Timeout;
import akka.Done;
import java.util.Optional;
@ -848,6 +847,10 @@ public class ActorDocTest extends AbstractJavaTest {
};
}
private CompletionStage<Done> cleanup() {
return null;
}
@Test
public void coordinatedShutdown() {
final ActorRef someActor = system.actorOf(Props.create(FirstActor.class));
@ -862,6 +865,15 @@ public class ActorDocTest extends AbstractJavaTest {
});
// #coordinated-shutdown-addTask
// #coordinated-shutdown-cancellable
Cancellable cancellable =
CoordinatedShutdown.get(system)
.addCancellableTask(
CoordinatedShutdown.PhaseBeforeServiceUnbind(), "someTaskCleanup", () -> cleanup());
// much later...
cancellable.cancel();
// #coordinated-shutdown-cancellable
// #coordinated-shutdown-jvm-hook
CoordinatedShutdown.get(system)
.addJvmShutdownHook(() -> System.out.println("custom JVM shutdown hook..."));

View file

@ -735,6 +735,23 @@ class ActorDocSpec extends AkkaSpec("""
}
//#coordinated-shutdown-addTask
{
def cleanup(): Unit = {}
import system.dispatcher
//#coordinated-shutdown-cancellable
val c = CoordinatedShutdown(system).addCancellableTask(CoordinatedShutdown.PhaseBeforeServiceUnbind, "cleanup") {
() =>
Future {
cleanup()
Done
}
}
// much later...
c.cancel()
//#coordinated-shutdown-cancellable
}
{
val someActor = system.actorOf(Props(classOf[Replier], this))
someActor ! PoisonPill