diff --git a/akka-actor-tests/src/test/scala/akka/actor/DeathWatchSpec.scala b/akka-actor-tests/src/test/scala/akka/actor/DeathWatchSpec.scala index 2f9808744b..4506ef36ba 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/DeathWatchSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/DeathWatchSpec.scala @@ -18,6 +18,10 @@ import akka.testkit._ class LocalDeathWatchSpec extends AkkaSpec with ImplicitSender with DefaultTimeout with DeathWatchSpec object DeathWatchSpec { + object Watcher { + def props(target: ActorRef, testActor: ActorRef) = + Props(classOf[Watcher], target, testActor) + } class Watcher(target: ActorRef, testActor: ActorRef) extends Actor { context.watch(target) def receive = { @@ -26,9 +30,6 @@ object DeathWatchSpec { } } - def props(target: ActorRef, testActor: ActorRef) = - Props(classOf[Watcher], target, testActor) - class EmptyWatcher(target: ActorRef) extends Actor { context.watch(target) def receive = Actor.emptyBehavior @@ -70,6 +71,40 @@ object DeathWatchSpec { final case class FF(fail: Failed) final case class Latches(t1: TestLatch, t2: TestLatch) extends NoSerializationVerificationNeeded + + object WatchWithVerifier { + case class WatchThis(ref: ActorRef) + case object Watching + case class CustomWatchMsg(ref: ActorRef) + case class StartStashing(numberOfMessagesToStash: Int) + case object StashingStarted + + def props(probe: ActorRef) = Props(new WatchWithVerifier(probe)) + } + class WatchWithVerifier(probe: ActorRef) extends Actor with Stash { + import WatchWithVerifier._ + private var stashing = false + private var stashNMessages = 0 + + override def receive: Receive = { + case StartStashing(messagesToStash) => + stashing = true + stashNMessages = messagesToStash + sender() ! StashingStarted + case WatchThis(ref) => + context.watchWith(ref, CustomWatchMsg(ref)) + sender() ! Watching + case _ if stashing => + stash() + stashNMessages -= 1 + if (stashNMessages == 0) { + stashing = false + unstashAll() + } + case msg: CustomWatchMsg => + probe ! msg + } + } } @silent @@ -79,7 +114,8 @@ trait DeathWatchSpec { this: AkkaSpec with ImplicitSender with DefaultTimeout => lazy val supervisor = system.actorOf(Props(classOf[Supervisor], SupervisorStrategy.defaultStrategy), "watchers") - def startWatching(target: ActorRef) = Await.result((supervisor ? props(target, testActor)).mapTo[ActorRef], 3 seconds) + def startWatching(target: ActorRef) = + Await.result((supervisor ? Watcher.props(target, testActor)).mapTo[ActorRef], 3 seconds) "The Death Watch" must { def expectTerminationOf(actorRef: ActorRef) = @@ -244,6 +280,31 @@ trait DeathWatchSpec { this: AkkaSpec with ImplicitSender with DefaultTimeout => w ! Identify(()) expectMsg(ActorIdentity((), Some(w))) } + + "watch with custom message" in { + val verifierProbe = TestProbe() + val verifier = system.actorOf(WatchWithVerifier.props(verifierProbe.ref)) + val subject = system.actorOf(Props[EmptyActor]()) + verifier ! WatchWithVerifier.WatchThis(subject) + expectMsg(WatchWithVerifier.Watching) + + subject ! PoisonPill + verifierProbe.expectMsg(WatchWithVerifier.CustomWatchMsg(subject)) + } + + // Coverage for #29101 + "stash watchWith termination message correctly" in { + val verifierProbe = TestProbe() + val verifier = system.actorOf(WatchWithVerifier.props(verifierProbe.ref)) + val subject = system.actorOf(Props[EmptyActor]()) + verifier ! WatchWithVerifier.WatchThis(subject) + expectMsg(WatchWithVerifier.Watching) + verifier ! 
WatchWithVerifier.StartStashing(numberOfMessagesToStash = 1) + expectMsg(WatchWithVerifier.StashingStarted) + + subject ! PoisonPill + verifierProbe.expectMsg(WatchWithVerifier.CustomWatchMsg(subject)) + } } } diff --git a/akka-actor-tests/src/test/scala/akka/pattern/BackoffOnRestartSupervisorSpec.scala b/akka-actor-tests/src/test/scala/akka/pattern/BackoffOnRestartSupervisorSpec.scala index 4be5f55f34..e6fa3b5a78 100644 --- a/akka-actor-tests/src/test/scala/akka/pattern/BackoffOnRestartSupervisorSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/pattern/BackoffOnRestartSupervisorSpec.scala @@ -31,7 +31,7 @@ class TestActor(probe: ActorRef) extends Actor { probe ! "STARTED" - def receive = { + def receive: Receive = { case "DIE" => context.stop(self) case "THROW" => throw new TestActor.NormalException case "THROW_STOPPING_EXCEPTION" => throw new TestActor.StoppingException @@ -46,9 +46,9 @@ object TestParentActor { } class TestParentActor(probe: ActorRef, supervisorProps: Props) extends Actor { - val supervisor = context.actorOf(supervisorProps) + val supervisor: ActorRef = context.actorOf(supervisorProps) - def receive = { + def receive: Receive = { case other => probe.forward(other) } } @@ -58,10 +58,10 @@ class BackoffOnRestartSupervisorSpec extends AkkaSpec(""" akka.loggers = ["akka.testkit.SilenceAllTestEventListener"] """) with WithLogCapturing with ImplicitSender { - @silent def supervisorProps(probeRef: ActorRef) = { - val options = Backoff - .onFailure(TestActor.props(probeRef), "someChildName", 200 millis, 10 seconds, 0.0, maxNrOfRetries = -1) + val options = BackoffOpts + .onFailure(TestActor.props(probeRef), "someChildName", 200 millis, 10 seconds, 0.0) + .withMaxNrOfRetries(-1) .withSupervisorStrategy(OneForOneStrategy(maxNrOfRetries = 5, withinTimeRange = 30 seconds) { case _: TestActor.StoppingException => SupervisorStrategy.Stop }) @@ -69,16 +69,16 @@ class BackoffOnRestartSupervisorSpec extends AkkaSpec(""" } trait Setup { - val probe = TestProbe() - val supervisor = system.actorOf(supervisorProps(probe.ref)) + val probe: TestProbe = TestProbe() + val supervisor: ActorRef = system.actorOf(supervisorProps(probe.ref)) probe.expectMsg("STARTED") } trait Setup2 { - val probe = TestProbe() - val parent = system.actorOf(TestParentActor.props(probe.ref, supervisorProps(probe.ref))) + val probe: TestProbe = TestProbe() + val parent: ActorRef = system.actorOf(TestParentActor.props(probe.ref, supervisorProps(probe.ref))) probe.expectMsg("STARTED") - val child = probe.lastSender + val child: ActorRef = probe.lastSender } "BackoffOnRestartSupervisor" must { @@ -139,7 +139,7 @@ class BackoffOnRestartSupervisorSpec extends AkkaSpec(""" } class SlowlyFailingActor(latch: CountDownLatch) extends Actor { - def receive = { + def receive: Receive = { case "THROW" => sender ! 
"THROWN" throw new NormalException @@ -155,18 +155,12 @@ class BackoffOnRestartSupervisorSpec extends AkkaSpec(""" "accept commands while child is terminating" in { val postStopLatch = new CountDownLatch(1) @silent - val options = Backoff - .onFailure( - Props(new SlowlyFailingActor(postStopLatch)), - "someChildName", - 1 nanos, - 1 nanos, - 0.0, - maxNrOfRetries = -1) + val options = BackoffOpts + .onFailure(Props(new SlowlyFailingActor(postStopLatch)), "someChildName", 1 nanos, 1 nanos, 0.0) + .withMaxNrOfRetries(-1) .withSupervisorStrategy(OneForOneStrategy(loggingEnabled = false) { case _: TestActor.StoppingException => SupervisorStrategy.Stop }) - @silent val supervisor = system.actorOf(BackoffSupervisor.props(options)) supervisor ! BackoffSupervisor.GetCurrentChild @@ -221,13 +215,12 @@ class BackoffOnRestartSupervisorSpec extends AkkaSpec(""" // withinTimeRange indicates the time range in which maxNrOfRetries will cause the child to // stop. IE: If we restart more than maxNrOfRetries in a time range longer than withinTimeRange // that is acceptable. - @silent - val options = Backoff - .onFailure(TestActor.props(probe.ref), "someChildName", 300.millis, 10.seconds, 0.0, maxNrOfRetries = -1) + val options = BackoffOpts + .onFailure(TestActor.props(probe.ref), "someChildName", 300.millis, 10.seconds, 0.0) + .withMaxNrOfRetries(-1) .withSupervisorStrategy(OneForOneStrategy(withinTimeRange = 1 seconds, maxNrOfRetries = 3) { case _: TestActor.StoppingException => SupervisorStrategy.Stop }) - @silent val supervisor = system.actorOf(BackoffSupervisor.props(options)) probe.expectMsg("STARTED") filterException[TestActor.TestException] { diff --git a/akka-actor-tests/src/test/scala/akka/pattern/BackoffSupervisorSpec.scala b/akka-actor-tests/src/test/scala/akka/pattern/BackoffSupervisorSpec.scala index 48039c90fe..dcccd99b1b 100644 --- a/akka-actor-tests/src/test/scala/akka/pattern/BackoffSupervisorSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/pattern/BackoffSupervisorSpec.scala @@ -7,7 +7,6 @@ package akka.pattern import scala.concurrent.duration._ import scala.util.control.NoStackTrace -import com.github.ghik.silencer.silent import org.scalatest.concurrent.Eventually import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -24,7 +23,7 @@ object BackoffSupervisorSpec { } class Child(probe: ActorRef) extends Actor { - def receive = { + def receive: Receive = { case "boom" => throw new TestException case msg => probe ! msg } @@ -36,7 +35,7 @@ object BackoffSupervisorSpec { } class ManualChild(probe: ActorRef) extends Actor { - def receive = { + def receive: Receive = { case "boom" => throw new TestException case msg => probe ! 
msg @@ -48,14 +47,13 @@ object BackoffSupervisorSpec { class BackoffSupervisorSpec extends AkkaSpec with ImplicitSender with Eventually { import BackoffSupervisorSpec._ - @silent("deprecated") - def onStopOptions(props: Props = Child.props(testActor), maxNrOfRetries: Int = -1) = - Backoff.onStop(props, "c1", 100.millis, 3.seconds, 0.2, maxNrOfRetries) - @silent("deprecated") - def onFailureOptions(props: Props = Child.props(testActor), maxNrOfRetries: Int = -1) = - Backoff.onFailure(props, "c1", 100.millis, 3.seconds, 0.2, maxNrOfRetries) - @silent("deprecated") - def create(options: BackoffOptions) = system.actorOf(BackoffSupervisor.props(options)) + def onStopOptions(props: Props = Child.props(testActor), maxNrOfRetries: Int = -1): BackoffOnStopOptions = + BackoffOpts.onStop(props, "c1", 100.millis, 3.seconds, 0.2).withMaxNrOfRetries(maxNrOfRetries) + def onFailureOptions(props: Props = Child.props(testActor), maxNrOfRetries: Int = -1): BackoffOnFailureOptions = + BackoffOpts.onFailure(props, "c1", 100.millis, 3.seconds, 0.2).withMaxNrOfRetries(maxNrOfRetries) + + def create(options: BackoffOnStopOptions): ActorRef = system.actorOf(BackoffSupervisor.props(options)) + def create(options: BackoffOnFailureOptions): ActorRef = system.actorOf(BackoffSupervisor.props(options)) "BackoffSupervisor" must { "start child again when it stops when using `Backoff.onStop`" in { @@ -179,10 +177,10 @@ class BackoffSupervisorSpec extends AkkaSpec with ImplicitSender with Eventually "reply to sender if replyWhileStopped is specified" in { filterException[TestException] { - @silent("deprecated") val supervisor = create( - Backoff - .onFailure(Child.props(testActor), "c1", 100.seconds, 300.seconds, 0.2, maxNrOfRetries = -1) + BackoffOpts + .onFailure(Child.props(testActor), "c1", 100.seconds, 300.seconds, 0.2) + .withMaxNrOfRetries(-1) .withReplyWhileStopped("child was stopped")) supervisor ! BackoffSupervisor.GetCurrentChild val c1 = expectMsgType[BackoffSupervisor.CurrentChild].ref.get @@ -203,11 +201,43 @@ class BackoffSupervisorSpec extends AkkaSpec with ImplicitSender with Eventually } } + "use provided actor while stopped and withHandlerWhileStopped is specified" in { + val handler = system.actorOf(Props(new Actor { + override def receive: Receive = { + case "still there?" => + sender() ! "not here!" + } + })) + filterException[TestException] { + val supervisor = create( + BackoffOpts + .onFailure(Child.props(testActor), "c1", 100.seconds, 300.seconds, 0.2) + .withMaxNrOfRetries(-1) + .withHandlerWhileStopped(handler)) + supervisor ! BackoffSupervisor.GetCurrentChild + val c1 = expectMsgType[BackoffSupervisor.CurrentChild].ref.get + watch(c1) + supervisor ! BackoffSupervisor.GetRestartCount + expectMsg(BackoffSupervisor.RestartCount(0)) + + c1 ! "boom" + expectTerminated(c1) + + awaitAssert { + supervisor ! BackoffSupervisor.GetRestartCount + expectMsg(BackoffSupervisor.RestartCount(1)) + } + + supervisor ! "still there?" + expectMsg("not here!") + } + } + "not reply to sender if replyWhileStopped is NOT specified" in { filterException[TestException] { - @silent("deprecated") val supervisor = - create(Backoff.onFailure(Child.props(testActor), "c1", 100.seconds, 300.seconds, 0.2, maxNrOfRetries = -1)) + create( + BackoffOpts.onFailure(Child.props(testActor), "c1", 100.seconds, 300.seconds, 0.2).withMaxNrOfRetries(-1)) supervisor ! 
BackoffSupervisor.GetCurrentChild val c1 = expectMsgType[BackoffSupervisor.CurrentChild].ref.get watch(c1) @@ -382,7 +412,7 @@ class BackoffSupervisorSpec extends AkkaSpec with ImplicitSender with Eventually c1 ! PoisonPill expectTerminated(c1) // since actor stopped we can expect the two messages to end up in dead letters - EventFilter.warning(pattern = ".*(ping|stop).*", occurrences = 2).intercept { + EventFilter.warning(pattern = ".*(ping|stop).*", occurrences = 1).intercept { supervisor ! "ping" supervisorWatcher.expectNoMessage(20.millis) // supervisor must not terminate diff --git a/akka-actor-tests/src/test/scala/akka/pattern/RetrySpec.scala b/akka-actor-tests/src/test/scala/akka/pattern/RetrySpec.scala index f079edfb9f..cc0fc56d67 100644 --- a/akka-actor-tests/src/test/scala/akka/pattern/RetrySpec.scala +++ b/akka-actor-tests/src/test/scala/akka/pattern/RetrySpec.scala @@ -124,6 +124,23 @@ class RetrySpec extends AkkaSpec with RetrySupport { elapse <= 100 shouldBe true } } + + "handle thrown exceptions in same way as failed Future" in { + @volatile var failCount = 0 + + def attempt() = { + if (failCount < 5) { + failCount += 1 + throw new IllegalStateException(failCount.toString) + } else Future.successful(5) + } + + val retried = retry(() => attempt(), 10, 100 milliseconds) + + within(3 seconds) { + Await.result(retried, remaining) should ===(5) + } + } } } diff --git a/akka-actor-typed-tests/src/test/scala/akka/actor/typed/LocalActorRefProviderLogMessagesSpec.scala b/akka-actor-typed-tests/src/test/scala/akka/actor/typed/LocalActorRefProviderLogMessagesSpec.scala new file mode 100644 index 0000000000..48baf19193 --- /dev/null +++ b/akka-actor-typed-tests/src/test/scala/akka/actor/typed/LocalActorRefProviderLogMessagesSpec.scala @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2020 Lightbend Inc. 
+ */ + +package akka.actor.typed + +import akka.actor.testkit.typed.scaladsl.{ ActorTestKit, LogCapturing, LoggingTestKit, ScalaTestWithActorTestKit } +import akka.actor.typed.internal.adapter.ActorSystemAdapter +import org.scalatest.wordspec.AnyWordSpecLike + +object LocalActorRefProviderLogMessagesSpec { + val config = """ + akka { + loglevel = DEBUG # test verifies debug + log-dead-letters = on + actor { + debug.unhandled = on + } + } + """ +} + +class LocalActorRefProviderLogMessagesSpec + extends ScalaTestWithActorTestKit(LocalActorRefProviderLogMessagesSpec.config) + with AnyWordSpecLike + with LogCapturing { + + "A LocalActorRefProvider" must { + + "log on dedicated 'serialization' logger for unknown path" in { + val provider = system.asInstanceOf[ActorSystemAdapter[_]].provider + + LoggingTestKit + .debug("of unknown (invalid) path [dummy/path]") + .withLoggerName("akka.actor.LocalActorRefProvider.Deserialization") + .expect { + provider.resolveActorRef("dummy/path") + } + } + + "log on dedicated 'serialization' logger when path doesn't match an existing actor" in { + val provider = system.asInstanceOf[ActorSystemAdapter[_]].provider + val invalidPath = provider.rootPath / "user" / "invalid" + + LoggingTestKit + .debug("Resolve (deserialization) of path [user/invalid] doesn't match an active actor.") + .withLoggerName("akka.actor.LocalActorRefProvider.Deserialization") + .expect { + provider.resolveActorRef(invalidPath) + } + } + + "log on dedicated 'serialization' logger for foreign path" in { + + val otherSystem = ActorTestKit("otherSystem").system.asInstanceOf[ActorSystemAdapter[_]] + val invalidPath = otherSystem.provider.rootPath / "user" / "foo" + + val provider = system.asInstanceOf[ActorSystemAdapter[_]].provider + try { + LoggingTestKit + .debug("Resolve (deserialization) of foreign path [akka://otherSystem/user/foo]") + .withLoggerName("akka.actor.LocalActorRefProvider.Deserialization") + .expect { + provider.resolveActorRef(invalidPath) + } + } finally { + ActorTestKit.shutdown(otherSystem) + } + } + } +} diff --git a/akka-actor/src/main/mima-filters/2.6.5.backwards.excludes/29082-backoff-reply.excludes b/akka-actor/src/main/mima-filters/2.6.5.backwards.excludes/29082-backoff-reply.excludes new file mode 100644 index 0000000000..a653bde9fb --- /dev/null +++ b/akka-actor/src/main/mima-filters/2.6.5.backwards.excludes/29082-backoff-reply.excludes @@ -0,0 +1,27 @@ +# Internals changed +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.pattern.ExtendedBackoffOptions.withHandlerWhileStopped") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.apply") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.apply$default$8") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.apply") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.apply$default$8") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.$default$8") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.apply$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.apply")
+ProblemFilters.exclude[DirectMissingMethodProblem]("akka.pattern.BackoffOnFailureOptionsImpl.replyWhileStopped") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.copy") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.copy$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnFailureOptionsImpl.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.$default$8") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.apply$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.apply") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.pattern.BackoffOnStopOptionsImpl.replyWhileStopped") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.copy") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.copy$default$8") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.BackoffOnStopOptionsImpl.this") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.internal.BackoffOnRestartSupervisor.this") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.pattern.internal.BackoffOnStopSupervisor.this") +ProblemFilters.exclude[IncompatibleSignatureProblem]("akka.pattern.BackoffOnFailureOptionsImpl.unapply") +ProblemFilters.exclude[IncompatibleSignatureProblem]("akka.pattern.BackoffOnStopOptionsImpl.unapply") +ProblemFilters.exclude[IncompatibleSignatureProblem]("akka.pattern.BackoffOnFailureOptionsImpl.unapply") \ No newline at end of file diff --git a/akka-actor/src/main/scala/akka/actor/ActorRefProvider.scala b/akka-actor/src/main/scala/akka/actor/ActorRefProvider.scala index fcc0c3a850..3eaf85f406 100644 --- a/akka-actor/src/main/scala/akka/actor/ActorRefProvider.scala +++ b/akka-actor/src/main/scala/akka/actor/ActorRefProvider.scala @@ -387,7 +387,14 @@ private[akka] class LocalActorRefProvider private[akka] ( override val rootPath: ActorPath = RootActorPath(Address("akka", _systemName)) private[akka] val log: MarkerLoggingAdapter = - Logging.withMarker(eventStream, getClass.getName + "(" + rootPath.address + ")") + Logging.withMarker(eventStream, getClass) + + /* + * This dedicated logger is used whenever a deserialization failure occurs + * and can therefore be disabled/enabled independently + */ + private[akka] val logDeser: MarkerLoggingAdapter = + Logging.withMarker(eventStream, getClass.getName + ".Deserialization") override val deadLetters: InternalActorRef = _deadLetters @@ -587,14 +594,14 @@ private[akka] class LocalActorRefProvider private[akka] ( def resolveActorRef(path: String): ActorRef = path match { case ActorPathExtractor(address, elems) if address == rootPath.address => resolveActorRef(rootGuardian, elems) case _ => - log.debug("Resolve (deserialization) of unknown (invalid) path [{}], using deadLetters.", path) + logDeser.debug("Resolve (deserialization) of unknown (invalid) path [{}], using deadLetters.", path) deadLetters } def resolveActorRef(path: ActorPath): ActorRef = { if (path.root == rootPath) resolveActorRef(rootGuardian, path.elements) else { - log.debug( + logDeser.debug( "Resolve (deserialization) of foreign path [{}] doesn't match root path [{}], using deadLetters.", path, rootPath) @@ -607,13 +614,13 @@ private[akka] class LocalActorRefProvider private[akka] ( */ private[akka] def 
resolveActorRef(ref: InternalActorRef, pathElements: Iterable[String]): InternalActorRef = if (pathElements.isEmpty) { - log.debug("Resolve (deserialization) of empty path doesn't match an active actor, using deadLetters.") + logDeser.debug("Resolve (deserialization) of empty path doesn't match an active actor, using deadLetters.") deadLetters } else ref.getChild(pathElements.iterator) match { case Nobody => if (log.isDebugEnabled) - log.debug( + logDeser.debug( "Resolve (deserialization) of path [{}] doesn't match an active actor. " + "It has probably been stopped, using deadLetters.", pathElements.mkString("/")) diff --git a/akka-actor/src/main/scala/akka/actor/dungeon/DeathWatch.scala b/akka-actor/src/main/scala/akka/actor/dungeon/DeathWatch.scala index cd15ce390f..468e5865d9 100644 --- a/akka-actor/src/main/scala/akka/actor/dungeon/DeathWatch.scala +++ b/akka-actor/src/main/scala/akka/actor/dungeon/DeathWatch.scala @@ -63,7 +63,15 @@ private[akka] trait DeathWatch { this: ActorCell => protected def receivedTerminated(t: Terminated): Unit = terminatedQueued.get(t.actor).foreach { optionalMessage => terminatedQueued -= t.actor // here we know that it is the SAME ref which was put in - receiveMessage(optionalMessage.getOrElse(t)) + optionalMessage match { + case Some(customTermination) => + // needed for stashing of custom watch messages to work (or stash will stash the Terminated message instead) + currentMessage = currentMessage.copy(message = customTermination) + receiveMessage(customTermination) + + case None => + receiveMessage(t) + } } /** diff --git a/akka-actor/src/main/scala/akka/pattern/Backoff.scala b/akka-actor/src/main/scala/akka/pattern/Backoff.scala index 45f628959c..6cc78aeed6 100644 --- a/akka-actor/src/main/scala/akka/pattern/Backoff.scala +++ b/akka-actor/src/main/scala/akka/pattern/Backoff.scala @@ -617,7 +617,7 @@ private final case class BackoffOptionsImpl( backoffReset, randomFactor, supervisorStrategy, - replyWhileStopped)) + replyWhileStopped.map(msg => ReplyWith(msg)).getOrElse(ForwardDeathLetters))) //onStop method in companion object case StopImpliesFailure => Props( @@ -629,7 +629,7 @@ private final case class BackoffOptionsImpl( backoffReset, randomFactor, supervisorStrategy, - replyWhileStopped, + replyWhileStopped.map(msg => ReplyWith(msg)).getOrElse(ForwardDeathLetters), finalStopMessage)) } } diff --git a/akka-actor/src/main/scala/akka/pattern/BackoffOptions.scala b/akka-actor/src/main/scala/akka/pattern/BackoffOptions.scala index 6e893978ac..8b2c5e61e8 100644 --- a/akka-actor/src/main/scala/akka/pattern/BackoffOptions.scala +++ b/akka-actor/src/main/scala/akka/pattern/BackoffOptions.scala @@ -5,9 +5,8 @@ package akka.pattern import scala.concurrent.duration.{ Duration, FiniteDuration } - -import akka.actor.{ OneForOneStrategy, Props, SupervisorStrategy } -import akka.annotation.DoNotInherit +import akka.actor.{ ActorRef, OneForOneStrategy, Props, SupervisorStrategy } +import akka.annotation.{ DoNotInherit, InternalApi } import akka.pattern.internal.{ BackoffOnRestartSupervisor, BackoffOnStopSupervisor } import akka.util.JavaDurationConverters._ @@ -299,6 +298,15 @@ private[akka] sealed trait ExtendedBackoffOptions[T <: ExtendedBackoffOptions[T] */ def withReplyWhileStopped(replyWhileStopped: Any): T + /** + * Returns a new BackoffOptions with a custom handler for messages that the supervisor receives while its child is stopped. + * By default, a message received while the child is stopped is forwarded to `deadLetters`. 
+ * Essentially, this handler replaces `deadLetters` allowing you to implement custom handling instead of a static reply. + * + * @param handler the `ActorRef` to which messages are forwarded while the child is stopped + */ + def withHandlerWhileStopped(handler: ActorRef): T + /** * Returns the props to create the back-off supervisor. */ @@ -334,7 +342,7 @@ private final case class BackoffOnStopOptionsImpl[T]( randomFactor: Double, reset: Option[BackoffReset] = None, supervisorStrategy: OneForOneStrategy = OneForOneStrategy()(SupervisorStrategy.defaultStrategy.decider), - replyWhileStopped: Option[Any] = None, + handlingWhileStopped: HandlingWhileStopped = ForwardDeathLetters, finalStopMessage: Option[Any => Boolean] = None) extends BackoffOnStopOptions { @@ -344,7 +352,9 @@ private final case class BackoffOnStopOptionsImpl[T]( def withAutoReset(resetBackoff: FiniteDuration) = copy(reset = Some(AutoReset(resetBackoff))) def withManualReset = copy(reset = Some(ManualReset)) def withSupervisorStrategy(supervisorStrategy: OneForOneStrategy) = copy(supervisorStrategy = supervisorStrategy) - def withReplyWhileStopped(replyWhileStopped: Any) = copy(replyWhileStopped = Some(replyWhileStopped)) + def withReplyWhileStopped(replyWhileStopped: Any) = copy(handlingWhileStopped = ReplyWith(replyWhileStopped)) + def withHandlerWhileStopped(handlerWhileStopped: ActorRef) = + copy(handlingWhileStopped = ForwardTo(handlerWhileStopped)) def withMaxNrOfRetries(maxNrOfRetries: Int) = copy(supervisorStrategy = supervisorStrategy.withMaxNrOfRetries(maxNrOfRetries)) @@ -374,7 +384,7 @@ private final case class BackoffOnStopOptionsImpl[T]( backoffReset, randomFactor, supervisorStrategy, - replyWhileStopped, + handlingWhileStopped, finalStopMessage)) } } @@ -387,7 +397,7 @@ private final case class BackoffOnFailureOptionsImpl[T]( randomFactor: Double, reset: Option[BackoffReset] = None, supervisorStrategy: OneForOneStrategy = OneForOneStrategy()(SupervisorStrategy.defaultStrategy.decider), - replyWhileStopped: Option[Any] = None) + handlingWhileStopped: HandlingWhileStopped = ForwardDeathLetters) extends BackoffOnFailureOptions { private val backoffReset = reset.getOrElse(AutoReset(minBackoff)) @@ -396,7 +406,9 @@ private final case class BackoffOnFailureOptionsImpl[T]( def withAutoReset(resetBackoff: FiniteDuration) = copy(reset = Some(AutoReset(resetBackoff))) def withManualReset = copy(reset = Some(ManualReset)) def withSupervisorStrategy(supervisorStrategy: OneForOneStrategy) = copy(supervisorStrategy = supervisorStrategy) - def withReplyWhileStopped(replyWhileStopped: Any) = copy(replyWhileStopped = Some(replyWhileStopped)) + def withReplyWhileStopped(replyWhileStopped: Any) = copy(handlingWhileStopped = ReplyWith(replyWhileStopped)) + def withHandlerWhileStopped(handlerWhileStopped: ActorRef) = + copy(handlingWhileStopped = ForwardTo(handlerWhileStopped)) def withMaxNrOfRetries(maxNrOfRetries: Int) = copy(supervisorStrategy = supervisorStrategy.withMaxNrOfRetries(maxNrOfRetries)) @@ -419,10 +431,17 @@ private final case class BackoffOnFailureOptionsImpl[T]( backoffReset, randomFactor, supervisorStrategy, - replyWhileStopped)) + handlingWhileStopped)) } } +@InternalApi sealed trait BackoffReset private[akka] case object ManualReset extends BackoffReset private[akka] final case class AutoReset(resetBackoff: FiniteDuration) extends BackoffReset + +@InternalApi +private[akka] sealed trait HandlingWhileStopped +private[akka] case object ForwardDeathLetters extends HandlingWhileStopped +private[akka] case class 
ForwardTo(handler: ActorRef) extends HandlingWhileStopped +private[akka] case class ReplyWith(msg: Any) extends HandlingWhileStopped diff --git a/akka-actor/src/main/scala/akka/pattern/BackoffSupervisor.scala b/akka-actor/src/main/scala/akka/pattern/BackoffSupervisor.scala index e85fe90dbe..112406f47a 100644 --- a/akka-actor/src/main/scala/akka/pattern/BackoffSupervisor.scala +++ b/akka-actor/src/main/scala/akka/pattern/BackoffSupervisor.scala @@ -184,7 +184,7 @@ object BackoffSupervisor { AutoReset(minBackoff), randomFactor, strategy, - None, + ForwardDeathLetters, None)) } @@ -341,7 +341,7 @@ final class BackoffSupervisor @deprecated("Use `BackoffSupervisor.props` method reset, randomFactor, strategy, - replyWhileStopped, + replyWhileStopped.map(msg => ReplyWith(msg)).getOrElse(ForwardDeathLetters), finalStopMessage) { // for binary compatibility with 2.5.18 diff --git a/akka-actor/src/main/scala/akka/pattern/RetrySupport.scala b/akka-actor/src/main/scala/akka/pattern/RetrySupport.scala index 0380f2ba0b..7c9a111481 100644 --- a/akka-actor/src/main/scala/akka/pattern/RetrySupport.scala +++ b/akka-actor/src/main/scala/akka/pattern/RetrySupport.scala @@ -153,39 +153,44 @@ object RetrySupport extends RetrySupport { maxAttempts: Int, delayFunction: Int => Option[FiniteDuration], attempted: Int)(implicit ec: ExecutionContext, scheduler: Scheduler): Future[T] = { - try { - require(maxAttempts >= 0, "Parameter maxAttempts must >= 0.") - require(attempt != null, "Parameter attempt should not be null.") - if (maxAttempts - attempted > 0) { - val result = attempt() - if (result eq null) - result - else { - val nextAttempt = attempted + 1 - result.recoverWith { - case NonFatal(_) => - delayFunction(nextAttempt) match { - case Some(delay) => - if (delay.length < 1) - retry(attempt, maxAttempts, delayFunction, nextAttempt) - else - after(delay, scheduler) { - retry(attempt, maxAttempts, delayFunction, nextAttempt) - } - case None => - retry(attempt, maxAttempts, delayFunction, nextAttempt) - case _ => - Future.failed(new IllegalArgumentException("The delayFunction of retry should not return null.")) - } - } - } - - } else { + def tryAttempt(): Future[T] = { + try { attempt() + } catch { + case NonFatal(exc) => Future.failed(exc) // in case the `attempt` function throws } - } catch { - case NonFatal(error) => Future.failed(error) + } + + require(maxAttempts >= 0, "Parameter maxAttempts must >= 0.") + require(attempt != null, "Parameter attempt should not be null.") + if (maxAttempts - attempted > 0) { + val result = tryAttempt() + if (result eq null) + result + else { + val nextAttempt = attempted + 1 + result.recoverWith { + case NonFatal(_) => + delayFunction(nextAttempt) match { + case Some(delay) => + if (delay.length < 1) + retry(attempt, maxAttempts, delayFunction, nextAttempt) + else + after(delay, scheduler) { + retry(attempt, maxAttempts, delayFunction, nextAttempt) + } + case None => + retry(attempt, maxAttempts, delayFunction, nextAttempt) + case _ => + Future.failed(new IllegalArgumentException("The delayFunction of retry should not return null.")) + } + + } + } + + } else { + tryAttempt() } } } diff --git a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnRestartSupervisor.scala b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnRestartSupervisor.scala index 3f799dc361..8fbd330e79 100644 --- a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnRestartSupervisor.scala +++ b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnRestartSupervisor.scala @@ -9,7 +9,15 
@@ import scala.concurrent.duration._ import akka.actor.{ OneForOneStrategy, _ } import akka.actor.SupervisorStrategy._ import akka.annotation.InternalApi -import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } +import akka.pattern.{ + BackoffReset, + BackoffSupervisor, + ForwardDeathLetters, + ForwardTo, + HandleBackoff, + HandlingWhileStopped, + ReplyWith +} /** * INTERNAL API @@ -26,7 +34,7 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } val reset: BackoffReset, randomFactor: Double, strategy: OneForOneStrategy, - replyWhileStopped: Option[Any]) + handlingWhileStopped: HandlingWhileStopped) extends Actor with HandleBackoff with ActorLogging { @@ -34,7 +42,7 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } import BackoffSupervisor._ import context._ - override val supervisorStrategy = + override val supervisorStrategy: OneForOneStrategy = OneForOneStrategy(strategy.maxNrOfRetries, strategy.withinTimeRange, strategy.loggingEnabled) { case ex => val defaultDirective: Directive = @@ -94,9 +102,10 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } case Some(c) => c.forward(msg) case None => - replyWhileStopped match { - case None => context.system.deadLetters.forward(msg) - case Some(r) => sender() ! r + handlingWhileStopped match { + case ForwardDeathLetters => context.system.deadLetters.forward(msg) + case ForwardTo(h) => h.forward(msg) + case ReplyWith(r) => sender() ! r } } } diff --git a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala index af94d4fa57..a2069de45b 100644 --- a/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala +++ b/akka-actor/src/main/scala/akka/pattern/internal/BackoffOnStopSupervisor.scala @@ -9,7 +9,15 @@ import scala.concurrent.duration.FiniteDuration import akka.actor.{ Actor, ActorLogging, OneForOneStrategy, Props, SupervisorStrategy, Terminated } import akka.actor.SupervisorStrategy.{ Directive, Escalate } import akka.annotation.InternalApi -import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } +import akka.pattern.{ + BackoffReset, + BackoffSupervisor, + ForwardDeathLetters, + ForwardTo, + HandleBackoff, + HandlingWhileStopped, + ReplyWith +} /** * INTERNAL API @@ -26,7 +34,7 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } val reset: BackoffReset, randomFactor: Double, strategy: SupervisorStrategy, - replyWhileStopped: Option[Any], + handlingWhileStopped: HandlingWhileStopped, finalStopMessage: Option[Any => Boolean]) extends Actor with HandleBackoff @@ -35,7 +43,7 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } import BackoffSupervisor._ import context.dispatcher - override val supervisorStrategy = strategy match { + override val supervisorStrategy: SupervisorStrategy = strategy match { case oneForOne: OneForOneStrategy => OneForOneStrategy(oneForOne.maxNrOfRetries, oneForOne.withinTimeRange, oneForOne.loggingEnabled) { case ex => @@ -84,13 +92,14 @@ import akka.pattern.{ BackoffReset, BackoffSupervisor, HandleBackoff } case None => } case None => - replyWhileStopped match { - case Some(r) => sender() ! 
r - case None => context.system.deadLetters.forward(msg) - } finalStopMessage match { case Some(fsm) if fsm(msg) => context.stop(self) - case _ => + case _ => + handlingWhileStopped match { + case ForwardDeathLetters => context.system.deadLetters.forward(msg) + case ForwardTo(h) => h.forward(msg) + case ReplyWith(r) => sender() ! r + } } } } diff --git a/akka-cluster-sharding-typed/src/main/scala/akka/cluster/sharding/typed/internal/ShardedDaemonProcessImpl.scala b/akka-cluster-sharding-typed/src/main/scala/akka/cluster/sharding/typed/internal/ShardedDaemonProcessImpl.scala index 5946f3d3b4..6459c4b3c3 100644 --- a/akka-cluster-sharding-typed/src/main/scala/akka/cluster/sharding/typed/internal/ShardedDaemonProcessImpl.scala +++ b/akka-cluster-sharding-typed/src/main/scala/akka/cluster/sharding/typed/internal/ShardedDaemonProcessImpl.scala @@ -16,15 +16,13 @@ import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.LoggerOps import akka.annotation.InternalApi import akka.cluster.sharding.ShardRegion.EntityId -import akka.cluster.sharding.typed.{ - javadsl, - scaladsl, - ClusterShardingSettings, - ShardedDaemonProcessSettings, - ShardingEnvelope, - ShardingMessageExtractor -} +import akka.cluster.sharding.typed.ClusterShardingSettings import akka.cluster.sharding.typed.ClusterShardingSettings.{ RememberEntitiesStoreModeDData, StateStoreModeDData } +import akka.cluster.sharding.typed.ShardedDaemonProcessSettings +import akka.cluster.sharding.typed.ShardingEnvelope +import akka.cluster.sharding.typed.ShardingMessageExtractor +import akka.cluster.sharding.typed.javadsl +import akka.cluster.sharding.typed.scaladsl import akka.cluster.sharding.typed.scaladsl.ClusterSharding import akka.cluster.sharding.typed.scaladsl.Entity import akka.cluster.sharding.typed.scaladsl.EntityTypeKey diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala index a90d0cf96e..3557e69ac5 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala @@ -14,10 +14,15 @@ import akka.actor.DeadLetterSuppression import akka.annotation.InternalApi import akka.cluster.Cluster import akka.cluster.ClusterEvent -import akka.cluster.ClusterEvent.{ ClusterShuttingDown, InitialStateAsEvents } +import akka.cluster.ClusterEvent._ +import akka.cluster.ddata.GSet +import akka.cluster.ddata.GSetKey +import akka.cluster.ddata.Key +import akka.cluster.ddata.LWWRegister +import akka.cluster.ddata.LWWRegisterKey +import akka.cluster.ddata.ReplicatedData import akka.cluster.ddata.Replicator._ -import akka.cluster.ddata.{ LWWRegister, LWWRegisterKey, SelfUniqueAddress } -import akka.cluster.sharding.DDataShardCoordinator.RememberEntitiesLoadTimeout +import akka.cluster.ddata.SelfUniqueAddress import akka.cluster.sharding.ShardRegion.ShardId import akka.cluster.sharding.internal.EventSourcedRememberShards.MigrationMarker import akka.cluster.sharding.internal.{ @@ -617,7 +622,7 @@ abstract class ShardCoordinator( case GetShardHome(shard) => if (!handleGetShardHome(shard)) { // location not know, yet - val activeRegions = state.regions -- gracefulShutdownInProgress + val activeRegions = (state.regions -- gracefulShutdownInProgress) -- regionTerminationInProgress if (activeRegions.nonEmpty) { val getShardHomeSender = sender() val regionFuture = 
allocationStrategy.allocateShard(getShardHomeSender, shard, activeRegions) @@ -929,7 +934,8 @@ abstract class ShardCoordinator( state.shards.get(shard) match { case Some(ref) => getShardHomeSender ! ShardHome(shard, ref) case None => - if (state.regions.contains(region) && !gracefulShutdownInProgress.contains(region)) { + if (state.regions.contains(region) && !gracefulShutdownInProgress.contains(region) && !regionTerminationInProgress + .contains(region)) { update(ShardHomeAllocated(shard, region)) { evt => state = state.updated(evt) log.debug( diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala index 302c004206..4572cd71a6 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala @@ -32,16 +32,6 @@ import akka.util.MessageBufferMap import akka.util.PrettyDuration import akka.util.Timeout -import scala.annotation.tailrec -import scala.collection.immutable -import scala.concurrent.Future -import scala.concurrent.Promise -import scala.concurrent.duration._ -import scala.reflect.ClassTag -import scala.runtime.AbstractFunction1 -import scala.util.Failure -import scala.util.Success - /** * @see [[ClusterSharding$ ClusterSharding extension]] */ diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala index 6bb328e5d7..bcd285558c 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/protobuf/ClusterShardingMessageSerializer.scala @@ -13,6 +13,7 @@ import scala.annotation.tailrec import scala.concurrent.duration._ import akka.util.ccompat.JavaConverters._ import scala.collection.immutable +import scala.concurrent.duration._ import akka.actor.ActorRef import akka.actor.Address @@ -21,20 +22,17 @@ import akka.cluster.sharding.Shard import akka.cluster.sharding.ShardCoordinator import akka.cluster.sharding.ShardRegion._ import akka.cluster.sharding.protobuf.msg.{ ClusterShardingMessages => sm } +import akka.cluster.sharding.protobuf.msg.ClusterShardingMessages import akka.cluster.sharding.internal.EventSourcedRememberShards.{ MigrationMarker, State => RememberShardsState } import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStore.{ State => EntityState } -import akka.cluster.sharding.protobuf.msg.ClusterShardingMessages +import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStore.{ EntitiesStarted, EntitiesStopped } import akka.protobufv3.internal.MessageLite import akka.serialization.BaseSerializer import akka.serialization.Serialization import akka.serialization.SerializerWithStringManifest import akka.util.ccompat._ -import java.io.NotSerializableException -import akka.actor.Address -import akka.cluster.sharding.ShardRegion._ -import akka.cluster.sharding.internal.EventSourcedRememberEntitiesStore.{ EntitiesStarted, EntitiesStopped } -import akka.cluster.sharding.protobuf.msg.ClusterShardingMessages +import akka.util.ccompat.JavaConverters._ /** * INTERNAL API: Protobuf serializer of ClusterSharding messages. 
diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowning2Spec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowning2Spec.scala new file mode 100644 index 0000000000..bea4817d86 --- /dev/null +++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowning2Spec.scala @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package akka.cluster.sharding + +import scala.concurrent.duration._ + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.cluster.MemberStatus +import akka.remote.transport.ThrottlerTransportAdapter.Direction +import akka.serialization.jackson.CborSerializable +import akka.testkit._ +import akka.util.ccompat._ + +@ccompatUsedUntil213 +object ClusterShardCoordinatorDowning2Spec { + case class Ping(id: String) extends CborSerializable + + class Entity extends Actor { + def receive = { + case Ping(_) => sender() ! self + } + } + + case object GetLocations extends CborSerializable + case class Locations(locations: Map[String, ActorRef]) extends CborSerializable + + class ShardLocations extends Actor { + var locations: Locations = _ + def receive = { + case GetLocations => sender() ! locations + case l: Locations => locations = l + } + } + + val extractEntityId: ShardRegion.ExtractEntityId = { + case m @ Ping(id) => (id, m) + } + + val extractShardId: ShardRegion.ExtractShardId = { + case Ping(id: String) => id.charAt(0).toString + } +} + +abstract class ClusterShardCoordinatorDowning2SpecConfig(mode: String) + extends MultiNodeClusterShardingConfig( + mode, + loglevel = "INFO", + additionalConfig = """ + akka.cluster.sharding.rebalance-interval = 120 s + # setting down-removal-margin, for testing of issue #29131 + akka.cluster.down-removal-margin = 3 s + akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 3s + """) { + val first = role("first") + val second = role("second") + + testTransport(on = true) + +} + +object PersistentClusterShardCoordinatorDowning2SpecConfig + extends ClusterShardCoordinatorDowning2SpecConfig(ClusterShardingSettings.StateStoreModePersistence) +object DDataClusterShardCoordinatorDowning2SpecConfig + extends ClusterShardCoordinatorDowning2SpecConfig(ClusterShardingSettings.StateStoreModeDData) + +class PersistentClusterShardCoordinatorDowning2Spec + extends ClusterShardCoordinatorDowning2Spec(PersistentClusterShardCoordinatorDowning2SpecConfig) +class DDataClusterShardCoordinatorDowning2Spec + extends ClusterShardCoordinatorDowning2Spec(DDataClusterShardCoordinatorDowning2SpecConfig) + +class PersistentClusterShardCoordinatorDowning2MultiJvmNode1 extends PersistentClusterShardCoordinatorDowning2Spec +class PersistentClusterShardCoordinatorDowning2MultiJvmNode2 extends PersistentClusterShardCoordinatorDowning2Spec + +class DDataClusterShardCoordinatorDowning2MultiJvmNode1 extends DDataClusterShardCoordinatorDowning2Spec +class DDataClusterShardCoordinatorDowning2MultiJvmNode2 extends DDataClusterShardCoordinatorDowning2Spec + +abstract class ClusterShardCoordinatorDowning2Spec(multiNodeConfig: ClusterShardCoordinatorDowning2SpecConfig) + extends MultiNodeClusterShardingSpec(multiNodeConfig) + with ImplicitSender { + import multiNodeConfig._ + + import ClusterShardCoordinatorDowning2Spec._ + + def startSharding(): Unit = { + startSharding( + system, + typeName = "Entity", + entityProps = Props[Entity](), + extractEntityId = extractEntityId, + extractShardId = 
extractShardId) + } + + lazy val region = ClusterSharding(system).shardRegion("Entity") + + s"Cluster sharding ($mode) with down member, scenario 2" must { + + "join cluster" in within(20.seconds) { + startPersistenceIfNotDdataMode(startOn = first, setStoreOn = Seq(first, second)) + + join(first, first, onJoinedRunOnFrom = startSharding()) + join(second, first, onJoinedRunOnFrom = startSharding(), assertNodeUp = false) + + // all Up, everywhere before continuing + runOn(first, second) { + awaitAssert { + cluster.state.members.size should ===(2) + cluster.state.members.unsorted.map(_.status) should ===(Set(MemberStatus.Up)) + } + } + + enterBarrier("after-2") + } + + "initialize shards" in { + runOn(first) { + val shardLocations = system.actorOf(Props[ShardLocations](), "shardLocations") + val locations = (for (n <- 1 to 4) yield { + val id = n.toString + region ! Ping(id) + id -> expectMsgType[ActorRef] + }).toMap + shardLocations ! Locations(locations) + system.log.debug("Original locations: {}", locations) + } + enterBarrier("after-3") + } + + "recover after downing other node (not coordinator)" in within(20.seconds) { + val secondAddress = address(second) + + runOn(first) { + testConductor.blackhole(first, second, Direction.Both).await + } + + Thread.sleep(3000) + + runOn(first) { + cluster.down(second) + awaitAssert { + cluster.state.members.size should ===(1) + } + + // start a few more new shards, could be allocated to second but should notice that it's terminated + val additionalLocations = + awaitAssert { + val probe = TestProbe() + (for (n <- 5 to 8) yield { + val id = n.toString + region.tell(Ping(id), probe.ref) + id -> probe.expectMsgType[ActorRef](1.second) + }).toMap + } + system.log.debug("Additional locations: {}", additionalLocations) + + system.actorSelection(node(first) / "user" / "shardLocations") ! GetLocations + val Locations(originalLocations) = expectMsgType[Locations] + + awaitAssert { + val probe = TestProbe() + (originalLocations ++ additionalLocations).foreach { + case (id, ref) => + region.tell(Ping(id), probe.ref) + if (ref.path.address == secondAddress) { + val newRef = probe.expectMsgType[ActorRef](1.second) + newRef should not be (ref) + system.log.debug("Moved [{}] from [{}] to [{}]", id, ref, newRef) + } else + probe.expectMsg(1.second, ref) // should not move + } + } + } + + enterBarrier("after-4") + } + + } +} diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala new file mode 100644 index 0000000000..e50cfcf511 --- /dev/null +++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package akka.cluster.sharding + +import scala.concurrent.duration._ + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.cluster.MemberStatus +import akka.remote.transport.ThrottlerTransportAdapter.Direction +import akka.serialization.jackson.CborSerializable +import akka.testkit._ +import akka.util.ccompat._ + +@ccompatUsedUntil213 +object ClusterShardCoordinatorDowningSpec { + case class Ping(id: String) extends CborSerializable + + class Entity extends Actor { + def receive = { + case Ping(_) => sender() ! 
self + } + } + + case object GetLocations extends CborSerializable + case class Locations(locations: Map[String, ActorRef]) extends CborSerializable + + class ShardLocations extends Actor { + var locations: Locations = _ + def receive = { + case GetLocations => sender() ! locations + case l: Locations => locations = l + } + } + + val extractEntityId: ShardRegion.ExtractEntityId = { + case m @ Ping(id) => (id, m) + } + + val extractShardId: ShardRegion.ExtractShardId = { + case Ping(id: String) => id.charAt(0).toString + } +} + +abstract class ClusterShardCoordinatorDowningSpecConfig(mode: String) + extends MultiNodeClusterShardingConfig( + mode, + loglevel = "INFO", + additionalConfig = """ + akka.cluster.sharding.rebalance-interval = 120 s + # setting down-removal-margin, for testing of issue #29131 + akka.cluster.down-removal-margin = 3 s + akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 3s + """) { + val controller = role("controller") + val first = role("first") + val second = role("second") + + testTransport(on = true) + +} + +object PersistentClusterShardCoordinatorDowningSpecConfig + extends ClusterShardCoordinatorDowningSpecConfig(ClusterShardingSettings.StateStoreModePersistence) +object DDataClusterShardCoordinatorDowningSpecConfig + extends ClusterShardCoordinatorDowningSpecConfig(ClusterShardingSettings.StateStoreModeDData) + +class PersistentClusterShardCoordinatorDowningSpec + extends ClusterShardCoordinatorDowningSpec(PersistentClusterShardCoordinatorDowningSpecConfig) +class DDataClusterShardCoordinatorDowningSpec + extends ClusterShardCoordinatorDowningSpec(DDataClusterShardCoordinatorDowningSpecConfig) + +class PersistentClusterShardCoordinatorDowningMultiJvmNode1 extends PersistentClusterShardCoordinatorDowningSpec +class PersistentClusterShardCoordinatorDowningMultiJvmNode2 extends PersistentClusterShardCoordinatorDowningSpec +class PersistentClusterShardCoordinatorDowningMultiJvmNode3 extends PersistentClusterShardCoordinatorDowningSpec + +class DDataClusterShardCoordinatorDowningMultiJvmNode1 extends DDataClusterShardCoordinatorDowningSpec +class DDataClusterShardCoordinatorDowningMultiJvmNode2 extends DDataClusterShardCoordinatorDowningSpec +class DDataClusterShardCoordinatorDowningMultiJvmNode3 extends DDataClusterShardCoordinatorDowningSpec + +abstract class ClusterShardCoordinatorDowningSpec(multiNodeConfig: ClusterShardCoordinatorDowningSpecConfig) + extends MultiNodeClusterShardingSpec(multiNodeConfig) + with ImplicitSender { + import multiNodeConfig._ + + import ClusterShardCoordinatorDowningSpec._ + + def startSharding(): Unit = { + startSharding( + system, + typeName = "Entity", + entityProps = Props[Entity](), + extractEntityId = extractEntityId, + extractShardId = extractShardId) + } + + lazy val region = ClusterSharding(system).shardRegion("Entity") + + s"Cluster sharding ($mode) with down member, scenario 1" must { + + "join cluster" in within(20.seconds) { + startPersistenceIfNotDdataMode(startOn = controller, setStoreOn = Seq(first, second)) + + join(first, first, onJoinedRunOnFrom = startSharding()) + join(second, first, onJoinedRunOnFrom = startSharding(), assertNodeUp = false) + + // all Up, everywhere before continuing + runOn(first, second) { + awaitAssert { + cluster.state.members.size should ===(2) + cluster.state.members.unsorted.map(_.status) should ===(Set(MemberStatus.Up)) + } + } + + enterBarrier("after-2") + } + + "initialize shards" in { + runOn(first) { + val shardLocations = system.actorOf(Props[ShardLocations](), 
"shardLocations") + val locations = (for (n <- 1 to 4) yield { + val id = n.toString + region ! Ping(id) + id -> expectMsgType[ActorRef] + }).toMap + shardLocations ! Locations(locations) + system.log.debug("Original locations: {}", locations) + } + enterBarrier("after-3") + } + + "recover after downing coordinator node" in within(20.seconds) { + val firstAddress = address(first) + system.actorSelection(node(first) / "user" / "shardLocations") ! GetLocations + val Locations(originalLocations) = expectMsgType[Locations] + + runOn(controller) { + testConductor.blackhole(first, second, Direction.Both).await + } + + Thread.sleep(3000) + + runOn(second) { + cluster.down(first) + awaitAssert { + cluster.state.members.size should ===(1) + } + + // start a few more new shards, could be allocated to first but should notice that it's terminated + val additionalLocations = + awaitAssert { + val probe = TestProbe() + (for (n <- 5 to 8) yield { + val id = n.toString + region.tell(Ping(id), probe.ref) + id -> probe.expectMsgType[ActorRef](1.second) + }).toMap + } + system.log.debug("Additional locations: {}", additionalLocations) + + awaitAssert { + val probe = TestProbe() + (originalLocations ++ additionalLocations).foreach { + case (id, ref) => + region.tell(Ping(id), probe.ref) + if (ref.path.address == firstAddress) { + val newRef = probe.expectMsgType[ActorRef](1.second) + newRef should not be (ref) + system.log.debug("Moved [{}] from [{}] to [{}]", id, ref, newRef) + } else + probe.expectMsg(1.second, ref) // should not move + } + } + } + + enterBarrier("after-4") + } + + } +} diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingRememberEntitiesPerfSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingRememberEntitiesPerfSpec.scala new file mode 100644 index 0000000000..e69de29bb2 diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala index 5d9070356c..319c7e8cdc 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ClusterShardingLeaseSpec.scala @@ -3,25 +3,16 @@ */ package akka.cluster.sharding -import akka.actor.Props -import akka.cluster.{ Cluster, MemberStatus } -import akka.testkit.TestActors.EchoActor -import akka.testkit.WithLogCapturing -import akka.testkit.{ AkkaSpec, ImplicitSender } -import com.typesafe.config.{ Config, ConfigFactory } - import scala.concurrent.Future import scala.concurrent.duration._ import scala.util.Success import scala.util.control.NoStackTrace - import com.typesafe.config.{ Config, ConfigFactory } - import akka.actor.Props import akka.cluster.{ Cluster, MemberStatus } import akka.coordination.lease.TestLease import akka.coordination.lease.TestLeaseExt -import akka.testkit.{ AkkaSpec, ImplicitSender } +import akka.testkit.{ AkkaSpec, ImplicitSender, WithLogCapturing } import akka.testkit.TestActors.EchoActor object ClusterShardingLeaseSpec { diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala index 0d8ac63417..8083a5d010 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala +++ 
b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ConstantRateEntityRecoveryStrategySpec.scala @@ -4,13 +4,13 @@ package akka.cluster.sharding -import akka.cluster.sharding.ShardRegion.EntityId -import akka.cluster.sharding.internal.EntityRecoveryStrategy -import akka.testkit.{ AkkaSpec, TimingTest } - import scala.concurrent.{ Await, Future } import scala.concurrent.duration._ +import akka.cluster.sharding.internal.EntityRecoveryStrategy +import akka.cluster.sharding.ShardRegion.EntityId +import akka.testkit.{ AkkaSpec, TimingTest } + class ConstantRateEntityRecoveryStrategySpec extends AkkaSpec { val strategy = EntityRecoveryStrategy.constantStrategy(system, 1.second, 2) diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala index b71e5f64d7..eebcee64ab 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/JoinConfigCompatCheckShardingSpec.scala @@ -13,10 +13,6 @@ import akka.actor.ActorSystem import akka.cluster.{ Cluster, ClusterReadView } import akka.testkit.WithLogCapturing import akka.testkit.{ AkkaSpec, LongRunningTest } -import com.typesafe.config.{ Config, ConfigFactory } - -import scala.concurrent.duration._ -import scala.collection.{ immutable => im } class JoinConfigCompatCheckShardingSpec extends AkkaSpec() with WithLogCapturing { diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/PersistentShardSpec.scala new file mode 100644 index 0000000000..e69de29bb2 diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala index 0dd0c7357f..9cd30e7f7a 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ProxyShardingSpec.scala @@ -8,8 +8,6 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.concurrent.duration.FiniteDuration -import scala.concurrent.duration.FiniteDuration - import akka.actor.ActorRef import akka.testkit.AkkaSpec import akka.testkit.TestActors diff --git a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala index 8af5fcc363..d86ed62667 100644 --- a/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala +++ b/akka-cluster-sharding/src/test/scala/akka/cluster/sharding/ShardRegionSpec.scala @@ -8,14 +8,10 @@ import java.io.File import com.typesafe.config.ConfigFactory import org.apache.commons.io.FileUtils - import akka.actor.{ Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Props } import akka.cluster.{ Cluster, MemberStatus } import akka.cluster.ClusterEvent.CurrentClusterState -import akka.cluster.{ Cluster, MemberStatus } -import akka.testkit.TestEvent.Mute -import akka.testkit.WithLogCapturing -import akka.testkit.{ AkkaSpec, DeadLettersFilter, TestProbe } +import akka.testkit.{ AkkaSpec, DeadLettersFilter, TestProbe, WithLogCapturing } import akka.testkit.TestEvent.Mute object ShardRegionSpec { diff --git a/akka-cluster/src/main/scala/akka/cluster/ClusterLogMarker.scala 
b/akka-cluster/src/main/scala/akka/cluster/ClusterLogMarker.scala
index c904589c3d..1a0c3ad784 100644
--- a/akka-cluster/src/main/scala/akka/cluster/ClusterLogMarker.scala
+++ b/akka-cluster/src/main/scala/akka/cluster/ClusterLogMarker.scala
@@ -7,6 +7,7 @@ package akka.cluster
 import akka.actor.Address
 import akka.annotation.ApiMayChange
 import akka.annotation.InternalApi
+import akka.cluster.sbr.DowningStrategy
 import akka.event.LogMarker
 
 /**
@@ -22,6 +23,7 @@ object ClusterLogMarker {
    */
   @InternalApi private[akka] object Properties {
     val MemberStatus = "akkaMemberStatus"
+    val SbrDecision = "akkaSbrDecision"
   }
 
   /**
@@ -91,4 +93,53 @@ object ClusterLogMarker {
   val singletonTerminated: LogMarker =
     LogMarker("akkaClusterSingletonTerminated")
 
+  /**
+   * Marker "akkaSbrDowning" of log event when Split Brain Resolver has made a downing decision. Followed
+   * by [[ClusterLogMarker.sbrDowningNode]] for each node that is downed.
+   * @param decision The downing decision. Included as property "akkaSbrDecision".
+   */
+  def sbrDowning(decision: DowningStrategy.Decision): LogMarker =
+    LogMarker("akkaSbrDowning", Map(Properties.SbrDecision -> decision))
+
+  /**
+   * Marker "akkaSbrDowningNode" of log event when a member is downed by Split Brain Resolver.
+   * @param node The address of the node that is downed. Included as property "akkaRemoteAddress"
+   *             and "akkaRemoteAddressUid".
+   * @param decision The downing decision. Included as property "akkaSbrDecision".
+   */
+  def sbrDowningNode(node: UniqueAddress, decision: DowningStrategy.Decision): LogMarker =
+    LogMarker(
+      "akkaSbrDowningNode",
+      Map(
+        LogMarker.Properties.RemoteAddress -> node.address,
+        LogMarker.Properties.RemoteAddressUid -> node.longUid,
+        Properties.SbrDecision -> decision))
+
+  /**
+   * Marker "akkaSbrInstability" of log event when Split Brain Resolver has detected too much instability
+   * and will down all nodes.
+   */
+  val sbrInstability: LogMarker =
+    LogMarker("akkaSbrInstability")
+
+  /**
+   * Marker "akkaSbrLeaseAcquired" of log event when Split Brain Resolver has acquired the lease.
+   * @param decision The downing decision. Included as property "akkaSbrDecision".
+   */
+  def sbrLeaseAcquired(decision: DowningStrategy.Decision): LogMarker =
+    LogMarker("akkaSbrLeaseAcquired", Map(Properties.SbrDecision -> decision))
+
+  /**
+   * Marker "akkaSbrLeaseDenied" of log event when Split Brain Resolver could not acquire the lease.
+   * @param reverseDecision The (reverse) downing decision. Included as property "akkaSbrDecision".
+   */
+  def sbrLeaseDenied(reverseDecision: DowningStrategy.Decision): LogMarker =
+    LogMarker("akkaSbrLeaseDenied", Map(Properties.SbrDecision -> reverseDecision))
+
+  /**
+   * Marker "akkaSbrLeaseReleased" of log event when Split Brain Resolver has released the lease.
+   */
+  val sbrLeaseReleased: LogMarker =
+    LogMarker("akkaSbrLeaseReleased")
+
 }
diff --git a/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala b/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
index 2958e8e172..fc7a7770ca 100644
--- a/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
+++ b/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
@@ -70,6 +70,9 @@ private[cluster] class ClusterRemoteWatcher(
 
   override val log = Logging(context.system, ActorWithLogClass(this, ClusterLogClass.ClusterCore))
 
+  // allowed to watch even though address not in cluster membership, i.e. remote watch
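+  // e.g. a shard region at akka://MySystem/system/sharding/MyType (illustrative path):
+  // watchee.path.elements.take(2).mkString("/", "/", "/") gives "/system/sharding/",
+  // matching the whitelist below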
+  private val watchPathWhitelist = Set("/system/sharding/")
+
   private var pendingDelayedQuarantine: Set[UniqueAddress] = Set.empty
 
   var clusterNodes: Set[Address] = Set.empty
 
@@ -164,7 +167,19 @@ private[cluster] class ClusterRemoteWatcher(
       if (!clusterNodes(watchee.path.address)) super.watchNode(watchee)
 
   override protected def shouldWatch(watchee: InternalActorRef): Boolean =
-    clusterNodes(watchee.path.address) || super.shouldWatch(watchee)
+    clusterNodes(watchee.path.address) || super.shouldWatch(watchee) || isWatchOutsideClusterAllowed(watchee)
+
+  /**
+   * Allowed to watch some paths even though the address is not in the cluster membership, i.e. remote watch.
+   * Needed for the ShardCoordinator, which has to watch old incarnations of region ActorRefs from the
+   * recovered state.
+   */
+  private def isWatchOutsideClusterAllowed(watchee: InternalActorRef): Boolean = {
+    context.system.name == watchee.path.address.system && {
+      val pathPrefix = watchee.path.elements.take(2).mkString("/", "/", "/")
+      watchPathWhitelist.contains(pathPrefix)
+    }
+  }
 
   /**
    * When a cluster node is added this class takes over the
diff --git a/akka-cluster/src/main/scala/akka/cluster/sbr/DowningStrategy.scala b/akka-cluster/src/main/scala/akka/cluster/sbr/DowningStrategy.scala
index 72028b7a2e..bf370a0ea4 100644
--- a/akka-cluster/src/main/scala/akka/cluster/sbr/DowningStrategy.scala
+++ b/akka-cluster/src/main/scala/akka/cluster/sbr/DowningStrategy.scala
@@ -10,6 +10,7 @@ import scala.concurrent.duration.FiniteDuration
 
 import akka.actor.Address
 import akka.annotation.InternalApi
+import akka.annotation.InternalStableApi
 import akka.cluster.ClusterSettings.DataCenter
 import akka.cluster.Member
 import akka.cluster.MemberStatus
@@ -20,7 +21,7 @@ import akka.coordination.lease.scaladsl.Lease
 /**
  * INTERNAL API
  */
-@InternalApi private[sbr] object DowningStrategy {
+@InternalApi private[akka] object DowningStrategy {
   sealed trait Decision {
     def isIndirectlyConnected: Boolean
   }
@@ -53,12 +54,13 @@ import akka.coordination.lease.scaladsl.Lease
 /**
  * INTERNAL API
  */
-@InternalApi private[sbr] abstract class DowningStrategy(val selfDc: DataCenter) {
+@InternalApi private[akka] abstract class DowningStrategy(val selfDc: DataCenter) {
   import DowningStrategy._
 
   // may contain Joining and WeaklyUp
   private var _unreachable: Set[UniqueAddress] = Set.empty[UniqueAddress]
 
+  @InternalStableApi
   def unreachable: Set[UniqueAddress] = _unreachable
 
   def unreachable(m: Member): Boolean = _unreachable(m.uniqueAddress)
@@ -79,11 +81,13 @@ import akka.coordination.lease.scaladsl.Lease
     _allMembers.filter(m => m.status == MemberStatus.Joining || m.status == MemberStatus.WeaklyUp)
 
   // all members in self DC, both joining and up.
+  @InternalStableApi
   def allMembersInDC: immutable.SortedSet[Member] = _allMembers
 
   /**
    * All members in self DC, but doesn't contain Joining, WeaklyUp, Down and Exiting.
*/ + @InternalStableApi def members: immutable.SortedSet[Member] = members(includingPossiblyUp = false, excludingPossiblyExiting = false) @@ -193,6 +197,7 @@ import akka.coordination.lease.scaladsl.Lease } } + @InternalStableApi def reachability: Reachability = _reachability diff --git a/akka-cluster/src/main/scala/akka/cluster/sbr/SplitBrainResolver.scala b/akka-cluster/src/main/scala/akka/cluster/sbr/SplitBrainResolver.scala index 4cdb9137b8..4545836d2f 100644 --- a/akka-cluster/src/main/scala/akka/cluster/sbr/SplitBrainResolver.scala +++ b/akka-cluster/src/main/scala/akka/cluster/sbr/SplitBrainResolver.scala @@ -11,20 +11,24 @@ import scala.concurrent.ExecutionContext import scala.concurrent.duration._ import akka.actor.Actor -import akka.actor.ActorLogging import akka.actor.Address import akka.actor.ExtendedActorSystem import akka.actor.Props import akka.actor.Stash import akka.actor.Timers import akka.annotation.InternalApi +import akka.annotation.InternalStableApi import akka.cluster.Cluster import akka.cluster.ClusterEvent import akka.cluster.ClusterEvent._ +import akka.cluster.ClusterLogMarker import akka.cluster.ClusterSettings.DataCenter import akka.cluster.Member import akka.cluster.Reachability import akka.cluster.UniqueAddress +import akka.cluster.sbr.DowningStrategy.Decision +import akka.event.DiagnosticMarkerBusLoggingAdapter +import akka.event.Logging import akka.pattern.pipe /** @@ -97,7 +101,7 @@ import akka.pattern.pipe log.info( "SBR started. Config: stableAfter: {} ms, strategy: {}, selfUniqueAddress: {}, selfDc: {}", stableAfter.toMillis, - strategy.getClass.getSimpleName, + Logging.simpleName(strategy.getClass), selfUniqueAddress, selfDc) @@ -114,8 +118,9 @@ import akka.pattern.pipe super.postStop() } - override def down(node: Address): Unit = { - cluster.down(node) + override def down(node: UniqueAddress, decision: Decision): Unit = { + log.info(ClusterLogMarker.sbrDowningNode(node, decision), "SBR is downing [{}]", node) + cluster.down(node.address) } } @@ -126,9 +131,8 @@ import akka.pattern.pipe * The implementation is split into two classes SplitBrainResolver and SplitBrainResolverBase to be * able to unit test the logic without running cluster. 
*/ -@InternalApi private[sbr] abstract class SplitBrainResolverBase(stableAfter: FiniteDuration, strategy: DowningStrategy) +@InternalApi private[sbr] abstract class SplitBrainResolverBase(stableAfter: FiniteDuration, _strategy: DowningStrategy) extends Actor - with ActorLogging with Stash with Timers { @@ -136,11 +140,17 @@ import akka.pattern.pipe import SplitBrainResolver.ReleaseLeaseCondition.NoLease import SplitBrainResolver._ + val log: DiagnosticMarkerBusLoggingAdapter = Logging.withMarker(this) + + @InternalStableApi + def strategy: DowningStrategy = _strategy + + @InternalStableApi def selfUniqueAddress: UniqueAddress def selfDc: DataCenter - def down(node: Address): Unit + def down(node: UniqueAddress, decision: Decision): Unit // would be better as constructor parameter, but don't want to break Cinnamon instrumentation private val settings = new SplitBrainResolverSettings(context.system.settings.config) @@ -289,7 +299,10 @@ import akka.pattern.pipe resetReachabilityChangedStats() } else if (downAllWhenUnstable > Duration.Zero && durationSinceFirstChange > (stableAfter + downAllWhenUnstable)) { - log.warning("SBR detected instability and will down all nodes: {}", reachabilityChangedStats) + log.warning( + ClusterLogMarker.sbrInstability, + "SBR detected instability and will down all nodes: {}", + reachabilityChangedStats) actOnDecision(DownAll) } } @@ -300,7 +313,10 @@ import akka.pattern.pipe strategy.lease match { case Some(lease) => if (lease.checkLease()) { - log.info("SBR has acquired lease for decision [{}]", decision) + log.info( + ClusterLogMarker.sbrLeaseAcquired(decision), + "SBR has acquired lease for decision [{}]", + decision) actOnDecision(decision) } else { if (decision.acquireDelay == Duration.Zero) @@ -349,7 +365,7 @@ import akka.pattern.pipe case AcquireLeaseResult(holdingLease) => if (holdingLease) { - log.info("SBR acquired lease for decision [{}]", decision) + log.info(ClusterLogMarker.sbrLeaseAcquired(decision), "SBR acquired lease for decision [{}]", decision) val downedNodes = actOnDecision(decision) releaseLeaseCondition = releaseLeaseCondition match { case ReleaseLeaseCondition.WhenMembersRemoved(nodes) => @@ -362,7 +378,11 @@ import akka.pattern.pipe } } else { val reverseDecision = strategy.reverseDecision(decision) - log.info("SBR couldn't acquire lease, reverse decision [{}] to [{}]", decision, reverseDecision) + log.info( + ClusterLogMarker.sbrLeaseDenied(reverseDecision), + "SBR couldn't acquire lease, reverse decision [{}] to [{}]", + decision, + reverseDecision) actOnDecision(reverseDecision) releaseLeaseCondition = NoLease } @@ -379,8 +399,10 @@ import akka.pattern.pipe private def releaseLeaseResult(released: Boolean): Unit = { releaseLeaseCondition match { case ReleaseLeaseCondition.WhenTimeElapsed(deadline) => - if (released && deadline.isOverdue()) + if (released && deadline.isOverdue()) { + log.info(ClusterLogMarker.sbrLeaseReleased, "SBR released lease.") releaseLeaseCondition = NoLease // released successfully + } case _ => // no lease or first waiting for downed nodes to be removed } @@ -399,6 +421,27 @@ import akka.pattern.pipe strategy.nodesToDown(DownAll) } + observeDecision(decision, nodesToDown, unreachableDataCenters) + + if (nodesToDown.nonEmpty) { + val downMyself = nodesToDown.contains(selfUniqueAddress) + // downing is idempotent, and we also avoid calling down on nodes with status Down + // down selfAddress last, since it may shutdown itself if down alone + nodesToDown.foreach(uniqueAddress => if (uniqueAddress != 
selfUniqueAddress) down(uniqueAddress, decision)) + if (downMyself) + down(selfUniqueAddress, decision) + + resetReachabilityChangedStats() + resetStableDeadline() + } + nodesToDown + } + + @InternalStableApi + def observeDecision( + decision: Decision, + nodesToDown: Set[UniqueAddress], + unreachableDataCenters: Set[DataCenter]): Unit = { val downMyself = nodesToDown.contains(selfUniqueAddress) val indirectlyConnectedLogMessage = @@ -411,24 +454,13 @@ import akka.pattern.pipe else "" log.warning( + ClusterLogMarker.sbrDowning(decision), s"SBR took decision $decision and is downing [${nodesToDown.map(_.address).mkString(", ")}]${if (downMyself) " including myself," else ""}, " + s"[${strategy.unreachable.size}] unreachable of [${strategy.members.size}] members" + indirectlyConnectedLogMessage + s", all members in DC [${strategy.allMembersInDC.mkString(", ")}], full reachability status: ${strategy.reachability}" + unreachableDataCentersLogMessage) - - if (nodesToDown.nonEmpty) { - // downing is idempotent, and we also avoid calling down on nodes with status Down - // down selfAddress last, since it may shutdown itself if down alone - nodesToDown.foreach(uniqueAddress => if (uniqueAddress != selfUniqueAddress) down(uniqueAddress.address)) - if (downMyself) - down(selfUniqueAddress.address) - - resetReachabilityChangedStats() - resetStableDeadline() - } - nodesToDown } def isResponsible: Boolean = leader && selfMemberAdded @@ -484,7 +516,7 @@ import akka.pattern.pipe def reachableDataCenter(dc: DataCenter): Unit = { unreachableDataCenters -= dc - log.info("Data center [] observed as reachable again", dc) + log.info("Data center [{}] observed as reachable again", dc) } def seenChanged(seenBy: Set[Address]): Unit = { @@ -569,7 +601,7 @@ import akka.pattern.pipe implicit val ec: ExecutionContext = internalDispatcher strategy.lease.foreach { l => if (releaseLeaseCondition != NoLease) { - log.info("SBR releasing lease") + log.debug("SBR releasing lease") l.release().recover { case _ => false }.map(ReleaseLeaseResult.apply).pipeTo(self) } } diff --git a/akka-cluster/src/test/scala/akka/cluster/sbr/SplitBrainResolverSpec.scala b/akka-cluster/src/test/scala/akka/cluster/sbr/SplitBrainResolverSpec.scala index a60cde4113..cae903c8ef 100644 --- a/akka-cluster/src/test/scala/akka/cluster/sbr/SplitBrainResolverSpec.scala +++ b/akka-cluster/src/test/scala/akka/cluster/sbr/SplitBrainResolverSpec.scala @@ -75,10 +75,10 @@ object SplitBrainResolverSpec { var downed = Set.empty[Address] - override def down(node: Address): Unit = { - if (leader && !downed(node)) { - downed += node - probe ! DownCalled(node) + override def down(node: UniqueAddress, decision: DowningStrategy.Decision): Unit = { + if (leader && !downed(node.address)) { + downed += node.address + probe ! DownCalled(node.address) } else if (!leader) probe ! 
"down must only be done by leader" } diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ExternalInteractions.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ExternalInteractions.scala index abe744434d..4b8298f366 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ExternalInteractions.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ExternalInteractions.scala @@ -99,7 +99,7 @@ private[akka] trait JournalInteractions[C, E, S] { @unused repr: immutable.Seq[PersistentRepr]): Unit = () protected def replayEvents(fromSeqNr: Long, toSeqNr: Long): Unit = { - setup.log.debug2("Replaying messages: from: {}, to: {}", fromSeqNr, toSeqNr) + setup.log.debug2("Replaying events: from: {}, to: {}", fromSeqNr, toSeqNr) setup.journal.tell( ReplayMessages(fromSeqNr, toSeqNr, setup.recovery.replayMax, setup.persistenceId.id, setup.selfClassic), setup.selfClassic) diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala index da06fb1294..c61da7ce27 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala @@ -116,10 +116,8 @@ private[akka] final class ReplayingEvents[C, E, S]( def handleEvent(event: E): Unit = { eventForErrorReporting = OptionVal.Some(event) - state = state.copy( - seqNr = repr.sequenceNr, - state = setup.eventHandler(state.state, event), - eventSeenInInterval = true) + state = state.copy(seqNr = repr.sequenceNr) + state = state.copy(state = setup.eventHandler(state.state, event), eventSeenInInterval = true) } eventSeq match { @@ -247,5 +245,6 @@ private[akka] final class ReplayingEvents[C, E, S]( setup.cancelRecoveryTimer() } - override def currentSequenceNumber: Long = state.seqNr + override def currentSequenceNumber: Long = + state.seqNr } diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala index a9174e5440..c1f4173801 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala @@ -97,10 +97,15 @@ private[akka] object Running { import InternalProtocol._ import Running.RunningState + // Needed for WithSeqNrAccessible, when unstashing + private var _currentSequenceNumber = 0L + final class HandlingCommands(state: RunningState[S]) extends AbstractBehavior[InternalProtocol](setup.context) with WithSeqNrAccessible { + _currentSequenceNumber = state.seqNr + def onMessage(msg: InternalProtocol): Behavior[InternalProtocol] = msg match { case IncomingCommand(c: C @unchecked) => onCommand(state, c) case JournalResponse(r) => onDeleteEventsJournalResponse(r, state.state) @@ -150,6 +155,7 @@ private[akka] object Running { // apply the event before persist so that validation exception is handled before persisting // the invalid event, in case such validation is implemented in the event handler. 
// also, ensure that there is an event handler for each single event + _currentSequenceNumber = state.seqNr + 1 val newState = state.applyEvent(setup, event) val eventToPersist = adaptEvent(event) @@ -166,12 +172,13 @@ private[akka] object Running { // apply the event before persist so that validation exception is handled before persisting // the invalid event, in case such validation is implemented in the event handler. // also, ensure that there is an event handler for each single event - var seqNr = state.seqNr + _currentSequenceNumber = state.seqNr val (newState, shouldSnapshotAfterPersist) = events.foldLeft((state, NoSnapshot: SnapshotAfterPersist)) { case ((currentState, snapshot), event) => - seqNr += 1 + _currentSequenceNumber += 1 val shouldSnapshot = - if (snapshot == NoSnapshot) setup.shouldSnapshot(currentState.state, event, seqNr) else snapshot + if (snapshot == NoSnapshot) setup.shouldSnapshot(currentState.state, event, _currentSequenceNumber) + else snapshot (currentState.applyEvent(setup, event), shouldSnapshot) } @@ -212,7 +219,8 @@ private[akka] object Running { setup.setMdcPhase(PersistenceMdc.RunningCmds) - override def currentSequenceNumber: Long = state.seqNr + override def currentSequenceNumber: Long = + _currentSequenceNumber } // =============================================== @@ -335,7 +343,9 @@ private[akka] object Running { else Behaviors.unhandled } - override def currentSequenceNumber: Long = visibleState.seqNr + override def currentSequenceNumber: Long = { + _currentSequenceNumber + } } // =============================================== @@ -430,7 +440,8 @@ private[akka] object Running { Behaviors.unhandled } - override def currentSequenceNumber: Long = state.seqNr + override def currentSequenceNumber: Long = + _currentSequenceNumber } // -------------------------- diff --git a/akka-persistence-typed/src/test/java/akka/persistence/typed/javadsl/PersistentActorJavaDslTest.java b/akka-persistence-typed/src/test/java/akka/persistence/typed/javadsl/PersistentActorJavaDslTest.java index 45a57a3df4..57e28b85ff 100644 --- a/akka-persistence-typed/src/test/java/akka/persistence/typed/javadsl/PersistentActorJavaDslTest.java +++ b/akka-persistence-typed/src/test/java/akka/persistence/typed/javadsl/PersistentActorJavaDslTest.java @@ -773,7 +773,7 @@ public class PersistentActorJavaDslTest extends JUnitSuite { probe.expectMessage("0 onRecoveryCompleted"); ref.tell("cmd"); probe.expectMessage("0 onCommand"); - probe.expectMessage("0 applyEvent"); + probe.expectMessage("1 applyEvent"); probe.expectMessage("1 thenRun"); } } diff --git a/akka-persistence-typed/src/test/scala/akka/persistence/typed/scaladsl/EventSourcedSequenceNumberSpec.scala b/akka-persistence-typed/src/test/scala/akka/persistence/typed/scaladsl/EventSourcedSequenceNumberSpec.scala index 3c3947ab30..222f25fe97 100644 --- a/akka-persistence-typed/src/test/scala/akka/persistence/typed/scaladsl/EventSourcedSequenceNumberSpec.scala +++ b/akka-persistence-typed/src/test/scala/akka/persistence/typed/scaladsl/EventSourcedSequenceNumberSpec.scala @@ -49,7 +49,12 @@ class EventSourcedSequenceNumberSpec case "cmd" => probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} onCommand" Effect - .persist(command) + .persist("evt") + .thenRun(_ => probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} thenRun") + case "cmd3" => + probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} onCommand" + Effect + .persist("evt1", "evt2", "evt3") .thenRun(_ => probe ! 
s"${EventSourcedBehavior.lastSequenceNumber(ctx)} thenRun") case "stash" => probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} stash" @@ -59,7 +64,7 @@ class EventSourcedSequenceNumberSpec } } }, { (_, evt) => - probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} eventHandler" + probe ! s"${EventSourcedBehavior.lastSequenceNumber(ctx)} eventHandler $evt" evt }).snapshotWhen((_, event, _) => event == "snapshot").receiveSignal { case (_, RecoveryCompleted) => @@ -75,11 +80,40 @@ class EventSourcedSequenceNumberSpec ref ! "cmd" probe.expectMessage("0 onCommand") - probe.expectMessage("0 eventHandler") + probe.expectMessage("1 eventHandler evt") probe.expectMessage("1 thenRun") + + ref ! "cmd" + probe.expectMessage("1 onCommand") + probe.expectMessage("2 eventHandler evt") + probe.expectMessage("2 thenRun") + + ref ! "cmd3" + probe.expectMessage("2 onCommand") + probe.expectMessage("3 eventHandler evt1") + probe.expectMessage("4 eventHandler evt2") + probe.expectMessage("5 eventHandler evt3") + probe.expectMessage("5 thenRun") + + testKit.stop(ref) + probe.expectTerminated(ref) + + // and during replay + val ref2 = spawn(behavior(PersistenceId.ofUniqueId("ess-1"), probe.ref)) + probe.expectMessage("1 eventHandler evt") + probe.expectMessage("2 eventHandler evt") + probe.expectMessage("3 eventHandler evt1") + probe.expectMessage("4 eventHandler evt2") + probe.expectMessage("5 eventHandler evt3") + probe.expectMessage("5 onRecoveryComplete") + + ref2 ! "cmd" + probe.expectMessage("5 onCommand") + probe.expectMessage("6 eventHandler evt") + probe.expectMessage("6 thenRun") } - "be available while replaying stash" in { + "be available while unstashing" in { val probe = TestProbe[String]() val ref = spawn(behavior(PersistenceId.ofUniqueId("ess-2"), probe.ref)) probe.expectMessage("0 onRecoveryComplete") @@ -87,18 +121,23 @@ class EventSourcedSequenceNumberSpec ref ! "stash" ref ! "cmd" ref ! "cmd" - ref ! "cmd" + ref ! "cmd3" ref ! "unstash" probe.expectMessage("0 stash") - probe.expectMessage("0 eventHandler") + probe.expectMessage("1 eventHandler stashing") probe.expectMessage("1 unstash") - probe.expectMessage("1 eventHandler") + probe.expectMessage("2 eventHandler normal") probe.expectMessage("2 onCommand") - probe.expectMessage("2 eventHandler") + probe.expectMessage("3 eventHandler evt") probe.expectMessage("3 thenRun") probe.expectMessage("3 onCommand") - probe.expectMessage("3 eventHandler") + probe.expectMessage("4 eventHandler evt") probe.expectMessage("4 thenRun") + probe.expectMessage("4 onCommand") // cmd3 + probe.expectMessage("5 eventHandler evt1") + probe.expectMessage("6 eventHandler evt2") + probe.expectMessage("7 eventHandler evt3") + probe.expectMessage("7 thenRun") } // reproducer for #27935 @@ -112,11 +151,11 @@ class EventSourcedSequenceNumberSpec ref ! 
"cmd" probe.expectMessage("0 onCommand") // first command - probe.expectMessage("0 eventHandler") + probe.expectMessage("1 eventHandler evt") probe.expectMessage("1 thenRun") - probe.expectMessage("1 eventHandler") // snapshot + probe.expectMessage("2 eventHandler snapshot") probe.expectMessage("2 onCommand") // second command - probe.expectMessage("2 eventHandler") + probe.expectMessage("3 eventHandler evt") probe.expectMessage("3 thenRun") } } diff --git a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala index b83ab570d5..889c37a174 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala @@ -715,7 +715,7 @@ private[akka] class RemoteActorRef private[akka] ( else if (provider.remoteWatcher.isDefined) remote.send(message, OptionVal.None, this) else - provider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "remote Watch") + provider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Watch") //Unwatch has a different signature, need to pattern match arguments against InternalActorRef case Unwatch(watchee: InternalActorRef, watcher: InternalActorRef) => diff --git a/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala b/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala index a727ef21aa..5eaf45a864 100644 --- a/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala +++ b/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala @@ -221,7 +221,7 @@ private[akka] class RemoteWatcher( // add watch from self, this will actually send a Watch to the target when necessary context.watch(watchee) - } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watcher, watchee, "Watch") + } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Watch") } def watchNode(watchee: InternalActorRef): Unit = { @@ -250,7 +250,7 @@ private[akka] class RemoteWatcher( } case None => } - } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watcher, watchee, "Unwatch") + } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Unwatch") } def removeWatchee(watchee: InternalActorRef): Unit = { diff --git a/akka-remote/src/main/scala/akka/remote/artery/RemotingFlightRecorder.scala b/akka-remote/src/main/scala/akka/remote/artery/RemotingFlightRecorder.scala index 9363f3f7f0..41fefdf3cf 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/RemotingFlightRecorder.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/RemotingFlightRecorder.scala @@ -6,7 +6,11 @@ package akka.remote.artery import java.net.InetSocketAddress -import akka.actor.{ Address, ExtendedActorSystem, Extension, ExtensionId, ExtensionIdProvider } +import akka.actor.Address +import akka.actor.ExtendedActorSystem +import akka.actor.Extension +import akka.actor.ExtensionId +import akka.actor.ExtensionIdProvider import akka.annotation.InternalApi import akka.remote.UniqueAddress import akka.util.FlightRecorderLoader diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala index 2c737b085c..742a29d92f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -33,10 +33,10 @@ object TlsSpec { val rnd = new Random - val SSLEnabledAlgorithms: Set[String] = Set("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", 
"TLS_RSA_WITH_AES_128_CBC_SHA") - val SSLProtocol: String = "TLSv1.2" + val TLS12Ciphers: Set[String] = Set("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA") + val TLS13Ciphers: Set[String] = Set("TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384") - def initWithTrust(trustPath: String): SSLContext = { + def initWithTrust(trustPath: String, protocol: String): SSLContext = { val password = "changeme" val keyStore = KeyStore.getInstance(KeyStore.getDefaultType) @@ -51,12 +51,12 @@ object TlsSpec { val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) trustManagerFactory.init(trustStore) - val context = SSLContext.getInstance(SSLProtocol) + val context = SSLContext.getInstance(protocol) context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom) context } - def initSslContext(): SSLContext = initWithTrust("/truststore") + def initSslContext(protocol: String): SSLContext = initWithTrust("/truststore", protocol) /** * This is an operator that fires a TimeoutException failure 2 seconds after it was started, @@ -102,469 +102,486 @@ class TlsSpec extends StreamSpec(TlsSpec.configOverrides) with WithLogCapturing import system.dispatcher "SslTls" must { + "work for TLSv1.2" must { workFor("TLSv1.2", TLS12Ciphers) } - val sslContext = initSslContext() + if (JavaVersion.majorVersion >= 11) + "work for TLSv1.3" must { workFor("TLSv1.3", TLS13Ciphers) } - val debug = Flow[SslTlsInbound].map { x => - x match { - case SessionTruncated => system.log.debug(s" ----------- truncated ") - case SessionBytes(_, b) => system.log.debug(s" ----------- (${b.size}) ${b.take(32).utf8String}") - } - x - } + def workFor(protocol: String, ciphers: Set[String]): Unit = { + val sslContext = initSslContext(protocol) - def createSSLEngine(context: SSLContext, role: TLSRole): SSLEngine = - createSSLEngine2(context, role, hostnameVerification = false, hostInfo = None) - - def createSSLEngine2( - context: SSLContext, - role: TLSRole, - hostnameVerification: Boolean, - hostInfo: Option[(String, Int)]): SSLEngine = { - - val engine = hostInfo match { - case None => - if (hostnameVerification) - throw new IllegalArgumentException("hostInfo must be defined for hostnameVerification to work.") - context.createSSLEngine() - case Some((hostname, port)) => context.createSSLEngine(hostname, port) - } - - if (hostnameVerification && role == akka.stream.Client) { - val sslParams = sslContext.getDefaultSSLParameters - sslParams.setEndpointIdentificationAlgorithm("HTTPS") - engine.setSSLParameters(sslParams) - } - - engine.setUseClientMode(role == akka.stream.Client) - engine.setEnabledCipherSuites(SSLEnabledAlgorithms.toArray) - engine.setEnabledProtocols(Array(SSLProtocol)) - - engine - } - - def clientTls(closing: TLSClosing) = - TLS(() => createSSLEngine(sslContext, Client), closing) - def badClientTls(closing: TLSClosing) = - TLS(() => createSSLEngine(initWithTrust("/badtruststore"), Client), closing) - def serverTls(closing: TLSClosing) = - TLS(() => createSSLEngine(sslContext, Server), closing) - - trait Named { - def name: String = - getClass.getName.reverse.dropWhile(c => "$0123456789".indexOf(c) != -1).takeWhile(_ != '$').reverse - } - - trait CommunicationSetup extends Named { - def decorateFlow( - leftClosing: TLSClosing, - rightClosing: TLSClosing, - rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]): Flow[SslTlsOutbound, SslTlsInbound, NotUsed] - def cleanup(): Unit = () - } - - object ClientInitiates extends 
CommunicationSetup { - def decorateFlow( - leftClosing: TLSClosing, - rightClosing: TLSClosing, - rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = - clientTls(leftClosing).atop(serverTls(rightClosing).reversed).join(rhs) - } - - object ServerInitiates extends CommunicationSetup { - def decorateFlow( - leftClosing: TLSClosing, - rightClosing: TLSClosing, - rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = - serverTls(leftClosing).atop(clientTls(rightClosing).reversed).join(rhs) - } - - def server(flow: Flow[ByteString, ByteString, Any]) = { - val server = Tcp().bind("localhost", 0).to(Sink.foreach(c => c.flow.join(flow).run())).run() - Await.result(server, 2.seconds) - } - - object ClientInitiatesViaTcp extends CommunicationSetup { - var binding: Tcp.ServerBinding = null - def decorateFlow( - leftClosing: TLSClosing, - rightClosing: TLSClosing, - rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { - binding = server(serverTls(rightClosing).reversed.join(rhs)) - clientTls(leftClosing).join(Tcp().outgoingConnection(binding.localAddress)) - } - override def cleanup(): Unit = binding.unbind() - } - - object ServerInitiatesViaTcp extends CommunicationSetup { - var binding: Tcp.ServerBinding = null - def decorateFlow( - leftClosing: TLSClosing, - rightClosing: TLSClosing, - rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { - binding = server(clientTls(rightClosing).reversed.join(rhs)) - serverTls(leftClosing).join(Tcp().outgoingConnection(binding.localAddress)) - } - override def cleanup(): Unit = binding.unbind() - } - - val communicationPatterns = - Seq(ClientInitiates, ServerInitiates, ClientInitiatesViaTcp, ServerInitiatesViaTcp) - - trait PayloadScenario extends Named { - def flow: Flow[SslTlsInbound, SslTlsOutbound, Any] = - Flow[SslTlsInbound].map { - var session: SSLSession = null - def setSession(s: SSLSession) = { - session = s - system.log.debug(s"new session: $session (${session.getId.mkString(",")})") - } - - { - case SessionTruncated => SendBytes(ByteString("TRUNCATED")) - case SessionBytes(s, b) if session == null => - setSession(s) - SendBytes(b) - case SessionBytes(s, b) if s != session => - setSession(s) - SendBytes(ByteString("NEWSESSION") ++ b) - case SessionBytes(_, b) => SendBytes(b) - } + val debug = Flow[SslTlsInbound].map { x => + x match { + case SessionTruncated => system.log.debug(s" ----------- truncated ") + case SessionBytes(_, b) => system.log.debug(s" ----------- (${b.size}) ${b.take(32).utf8String}") } - def leftClosing: TLSClosing = IgnoreComplete - def rightClosing: TLSClosing = IgnoreComplete + x + } - def inputs: immutable.Seq[SslTlsOutbound] - def output: ByteString + def createSSLEngine(context: SSLContext, role: TLSRole): SSLEngine = + createSSLEngine2(context, role, hostnameVerification = false, hostInfo = None) - protected def send(str: String) = SendBytes(ByteString(str)) - protected def send(ch: Char) = SendBytes(ByteString(ch.toByte)) - } + def createSSLEngine2( + context: SSLContext, + role: TLSRole, + hostnameVerification: Boolean, + hostInfo: Option[(String, Int)]): SSLEngine = { - object SingleBytes extends PayloadScenario { - val str = "0123456789" - def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) - def output = ByteString(str) - } + val engine = hostInfo match { + case None => + if (hostnameVerification) + throw new IllegalArgumentException("hostInfo must be defined for hostnameVerification to work.") + context.createSSLEngine() + case Some((hostname, port)) => context.createSSLEngine(hostname, port) + } - object MediumMessages 
extends PayloadScenario { - val strs = "0123456789".map(d => d.toString * (rnd.nextInt(9000) + 1000)) - def inputs = strs.map(s => SendBytes(ByteString(s))) - def output = ByteString(strs.foldRight("")(_ ++ _)) - } + if (hostnameVerification && role == akka.stream.Client) { + val sslParams = sslContext.getDefaultSSLParameters + sslParams.setEndpointIdentificationAlgorithm("HTTPS") + engine.setSSLParameters(sslParams) + } - object LargeMessages extends PayloadScenario { - // TLS max packet size is 16384 bytes - val strs = "0123456789".map(d => d.toString * (rnd.nextInt(9000) + 17000)) - def inputs = strs.map(s => SendBytes(ByteString(s))) - def output = ByteString(strs.foldRight("")(_ ++ _)) - } + engine.setUseClientMode(role == akka.stream.Client) + engine.setEnabledCipherSuites(ciphers.toArray) + engine.setEnabledProtocols(Array(protocol)) - object EmptyBytesFirst extends PayloadScenario { - def inputs = List(ByteString.empty, ByteString("hello")).map(SendBytes) - def output = ByteString("hello") - } + engine + } - object EmptyBytesInTheMiddle extends PayloadScenario { - def inputs = List(ByteString("hello"), ByteString.empty, ByteString(" world")).map(SendBytes) - def output = ByteString("hello world") - } + def clientTls(closing: TLSClosing) = + TLS(() => createSSLEngine(sslContext, Client), closing) - object EmptyBytesLast extends PayloadScenario { - def inputs = List(ByteString("hello"), ByteString.empty).map(SendBytes) - def output = ByteString("hello") - } + def badClientTls(closing: TLSClosing) = + TLS(() => createSSLEngine(initWithTrust("/badtruststore", protocol), Client), closing) - object CompletedImmediately extends PayloadScenario { - override def inputs: immutable.Seq[SslTlsOutbound] = Nil - override def output = ByteString.empty + def serverTls(closing: TLSClosing) = + TLS(() => createSSLEngine(sslContext, Server), closing) - override def leftClosing: TLSClosing = EagerClose - override def rightClosing: TLSClosing = EagerClose - } + trait Named { + def name: String = + getClass.getName.reverse.dropWhile(c => "$0123456789".indexOf(c) != -1).takeWhile(_ != '$').reverse + } - // this demonstrates that cancellation is ignored so that the five results make it back - object CancellingRHS extends PayloadScenario { - override def flow = - Flow[SslTlsInbound] - .mapConcat { - case SessionTruncated => SessionTruncated :: Nil - case SessionBytes(s, bytes) => bytes.map(b => SessionBytes(s, ByteString(b))) + trait CommunicationSetup extends Named { + def decorateFlow( + leftClosing: TLSClosing, + rightClosing: TLSClosing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]): Flow[SslTlsOutbound, SslTlsInbound, NotUsed] + def cleanup(): Unit = () + } + + object ClientInitiates extends CommunicationSetup { + def decorateFlow( + leftClosing: TLSClosing, + rightClosing: TLSClosing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = + clientTls(leftClosing).atop(serverTls(rightClosing).reversed).join(rhs) + } + + object ServerInitiates extends CommunicationSetup { + def decorateFlow( + leftClosing: TLSClosing, + rightClosing: TLSClosing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = + serverTls(leftClosing).atop(clientTls(rightClosing).reversed).join(rhs) + } + + def server(flow: Flow[ByteString, ByteString, Any]) = { + val server = Tcp().bind("localhost", 0).to(Sink.foreach(c => c.flow.join(flow).run())).run() + Await.result(server, 2.seconds) + } + + object ClientInitiatesViaTcp extends CommunicationSetup { + var binding: Tcp.ServerBinding = null + def decorateFlow( + leftClosing: 
TLSClosing, + rightClosing: TLSClosing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { + binding = server(serverTls(rightClosing).reversed.join(rhs)) + clientTls(leftClosing).join(Tcp().outgoingConnection(binding.localAddress)) + } + override def cleanup(): Unit = binding.unbind() + } + + object ServerInitiatesViaTcp extends CommunicationSetup { + var binding: Tcp.ServerBinding = null + def decorateFlow( + leftClosing: TLSClosing, + rightClosing: TLSClosing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { + binding = server(clientTls(rightClosing).reversed.join(rhs)) + serverTls(leftClosing).join(Tcp().outgoingConnection(binding.localAddress)) + } + override def cleanup(): Unit = binding.unbind() + } + + val communicationPatterns = + Seq(ClientInitiates, ServerInitiates, ClientInitiatesViaTcp, ServerInitiatesViaTcp) + + trait PayloadScenario extends Named { + def flow: Flow[SslTlsInbound, SslTlsOutbound, Any] = + Flow[SslTlsInbound].map { + var session: SSLSession = null + def setSession(s: SSLSession) = { + session = s + system.log.debug(s"new session: $session (${session.getId.mkString(",")})") + } + + { + case SessionTruncated => SendBytes(ByteString("TRUNCATED")) + case SessionBytes(s, b) if session == null => + setSession(s) + SendBytes(b) + case SessionBytes(s, b) if s != session => + setSession(s) + SendBytes(ByteString("NEWSESSION") ++ b) + case SessionBytes(_, b) => SendBytes(b) + } } - .take(5) - .mapAsync(5)(x => later(500.millis, system.scheduler)(Future.successful(x))) - .via(super.flow) - override def rightClosing = IgnoreCancel + def leftClosing: TLSClosing = IgnoreComplete + def rightClosing: TLSClosing = IgnoreComplete - val str = "abcdef" * 100 - def inputs = str.map(send) - def output = ByteString(str.take(5)) - } + def inputs: immutable.Seq[SslTlsOutbound] + def output: ByteString - object CancellingRHSIgnoresBoth extends PayloadScenario { - override def flow = - Flow[SslTlsInbound] - .mapConcat { - case SessionTruncated => SessionTruncated :: Nil - case SessionBytes(s, bytes) => bytes.map(b => SessionBytes(s, ByteString(b))) - } - .take(5) - .mapAsync(5)(x => later(500.millis, system.scheduler)(Future.successful(x))) - .via(super.flow) - override def rightClosing = IgnoreBoth - - val str = "abcdef" * 100 - def inputs = str.map(send) - def output = ByteString(str.take(5)) - } - - object LHSIgnoresBoth extends PayloadScenario { - override def leftClosing = IgnoreBoth - val str = "0123456789" - def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) - def output = ByteString(str) - } - - object BothSidesIgnoreBoth extends PayloadScenario { - override def leftClosing = IgnoreBoth - override def rightClosing = IgnoreBoth - val str = "0123456789" - def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) - def output = ByteString(str) - } - - object SessionRenegotiationBySender extends PayloadScenario { - def inputs = List(send("hello"), NegotiateNewSession, send("world")) - def output = ByteString("helloNEWSESSIONworld") - } - - // difference is that the RHS engine will now receive the handshake while trying to send - object SessionRenegotiationByReceiver extends PayloadScenario { - val str = "abcdef" * 100 - def inputs = str.map(send) ++ Seq(NegotiateNewSession) ++ "hello world".map(send) - def output = ByteString(str + "NEWSESSIONhello world") - } - - val logCipherSuite = Flow[SslTlsInbound].map { - var session: SSLSession = null - def setSession(s: SSLSession) = { - session = s - system.log.debug(s"new session: $session 
(${session.getId.mkString(",")})") + protected def send(str: String) = SendBytes(ByteString(str)) + protected def send(ch: Char) = SendBytes(ByteString(ch.toByte)) } - { - case SessionTruncated => SendBytes(ByteString("TRUNCATED")) - case SessionBytes(s, b) if s != session => - setSession(s) - SendBytes(ByteString(s.getCipherSuite) ++ b) - case SessionBytes(_, b) => SendBytes(b) - } - } - - object SessionRenegotiationFirstOne extends PayloadScenario { - override def flow = logCipherSuite - def inputs = NegotiateNewSession.withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil - def output = ByteString("TLS_RSA_WITH_AES_128_CBC_SHAhello") - } - - object SessionRenegotiationFirstTwo extends PayloadScenario { - override def flow = logCipherSuite - def inputs = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil - def output = ByteString("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHAhello") - } - - val scenarios = - Seq( - SingleBytes, - MediumMessages, - LargeMessages, - EmptyBytesFirst, - EmptyBytesInTheMiddle, - EmptyBytesLast, - CompletedImmediately, - CancellingRHS, - CancellingRHSIgnoresBoth, - LHSIgnoresBoth, - BothSidesIgnoreBoth, - SessionRenegotiationBySender, - SessionRenegotiationByReceiver, - SessionRenegotiationFirstOne, - SessionRenegotiationFirstTwo) - - for { - commPattern <- communicationPatterns - scenario <- scenarios - } { - s"work in mode ${commPattern.name} while sending ${scenario.name}" in assertAllStagesStopped { - val onRHS = debug.via(scenario.flow) - val output = - Source(scenario.inputs) - .via(commPattern.decorateFlow(scenario.leftClosing, scenario.rightClosing, onRHS)) - .via(new SimpleLinearGraphStage[SslTlsInbound] { - override def createLogic(inheritedAttributes: Attributes) = - new GraphStageLogic(shape) with InHandler with OutHandler { - setHandlers(in, out, this) - - override def onPush() = push(out, grab(in)) - override def onPull() = pull(in) - - override def onDownstreamFinish(cause: Throwable) = { - system.log.debug(s"me cancelled, cause {}", cause) - completeStage() - } - } - }) - .via(debug) - .collect { case SessionBytes(_, b) => b } - .scan(ByteString.empty)(_ ++ _) - .filter(_.nonEmpty) - .via(new Timeout(10.seconds)) - .dropWhile(_.size < scenario.output.size) - .runWith(Sink.headOption) - - Await.result(output, 12.seconds).getOrElse(ByteString.empty).utf8String should be(scenario.output.utf8String) - - commPattern.cleanup() - } - } - - "emit an error if the TLS handshake fails certificate checks" in assertAllStagesStopped { - val getError = Flow[SslTlsInbound] - .map[Either[SslTlsInbound, SSLException]](i => Left(i)) - .recover { case e: SSLException => Right(e) } - .collect { case Right(e) => e } - .toMat(Sink.head)(Keep.right) - - val simple = Flow.fromSinkAndSourceMat(getError, Source.maybe[SslTlsOutbound])(Keep.left) - - // The creation of actual TCP connections is necessary. It is the easiest way to decouple the client and server - // under error conditions, and has the bonus of matching most actual SSL deployments. 
- val (server, serverErr) = Tcp() - .bind("localhost", 0) - .mapAsync(1)(c => c.flow.joinMat(serverTls(IgnoreBoth).reversed.joinMat(simple)(Keep.right))(Keep.right).run()) - .toMat(Sink.head)(Keep.both) - .run() - - val clientErr = simple - .join(badClientTls(IgnoreBoth)) - .join(Tcp().outgoingConnection(Await.result(server, 1.second).localAddress)) - .run() - - Await.result(serverErr, 1.second).getMessage should include("certificate_unknown") - val clientErrText = Await.result(clientErr, 1.second).getMessage - if (JavaVersion.majorVersion >= 11) - clientErrText should include("unable to find valid certification path to requested target") - else - clientErrText should equal("General SSLEngine problem") - } - - "reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped { - val ex = new Exception("hello") - val (sub, out1, out2) = - RunnableGraph - .fromGraph( - GraphDSL.create(Source.asSubscriber[SslTlsOutbound], Sink.head[ByteString], Sink.head[SslTlsInbound])( - (_, _, _)) { implicit b => (s, o1, o2) => - val tls = b.add(clientTls(EagerClose)) - s ~> tls.in1; tls.out1 ~> o1 - o2 <~ tls.out2; tls.in2 <~ Source.failed(ex) - ClosedShape - }) - .run() - the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) - the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) - Thread.sleep(500) - val pub = TestPublisher.probe() - pub.subscribe(sub) - pub.expectSubscription().expectCancellation() - } - - "reliably cancel subscriptions when UserIn fails early" in assertAllStagesStopped { - val ex = new Exception("hello") - val (sub, out1, out2) = - RunnableGraph - .fromGraph(GraphDSL.create(Source.asSubscriber[ByteString], Sink.head[ByteString], Sink.head[SslTlsInbound])( - (_, _, _)) { implicit b => (s, o1, o2) => - val tls = b.add(clientTls(EagerClose)) - Source.failed[SslTlsOutbound](ex) ~> tls.in1; tls.out1 ~> o1 - o2 <~ tls.out2; tls.in2 <~ s - ClosedShape - }) - .run() - the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) - the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) - Thread.sleep(500) - val pub = TestPublisher.probe() - pub.subscribe(sub) - pub.expectSubscription().expectCancellation() - } - - "complete if TLS connection is truncated" in assertAllStagesStopped { - - val ks = KillSwitches.shared("ks") - - val scenario = SingleBytes - - val outFlow = { - val terminator = BidiFlow.fromFlows(Flow[ByteString], ks.flow[ByteString]) - clientTls(scenario.leftClosing) - .atop(terminator) - .atop(serverTls(scenario.rightClosing).reversed) - .join(debug.via(scenario.flow)) - .via(debug) + object SingleBytes extends PayloadScenario { + val str = "0123456789" + def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) } - val inFlow = Flow[SslTlsInbound] - .collect { case SessionBytes(_, b) => b } - .scan(ByteString.empty)(_ ++ _) - .via(new Timeout(6.seconds.dilated)) - .dropWhile(_.size < scenario.output.size) + object MediumMessages extends PayloadScenario { + val strs = "0123456789".map(d => d.toString * (rnd.nextInt(9000) + 1000)) + def inputs = strs.map(s => SendBytes(ByteString(s))) + def output = ByteString(strs.foldRight("")(_ ++ _)) + } - val f = - Source(scenario.inputs) - .via(outFlow) - .via(inFlow) - .map(result => { - ks.shutdown(); result - }) - .runWith(Sink.last) + object LargeMessages extends PayloadScenario { + // TLS max packet size is 16384 bytes + val strs = "0123456789".map(d => d.toString * (rnd.nextInt(9000) + 17000)) + def inputs = strs.map(s => 
SendBytes(ByteString(s))) + def output = ByteString(strs.foldRight("")(_ ++ _)) + } - Await.result(f, 8.second.dilated).utf8String should be(scenario.output.utf8String) - } + object EmptyBytesFirst extends PayloadScenario { + def inputs = List(ByteString.empty, ByteString("hello")).map(SendBytes) + def output = ByteString("hello") + } - "verify hostname" in assertAllStagesStopped { - def run(hostName: String): Future[akka.Done] = { - val rhs = Flow[SslTlsInbound].map { - case SessionTruncated => SendBytes(ByteString.empty) + object EmptyBytesInTheMiddle extends PayloadScenario { + def inputs = List(ByteString("hello"), ByteString.empty, ByteString(" world")).map(SendBytes) + def output = ByteString("hello world") + } + + object EmptyBytesLast extends PayloadScenario { + def inputs = List(ByteString("hello"), ByteString.empty).map(SendBytes) + def output = ByteString("hello") + } + + object CompletedImmediately extends PayloadScenario { + override def inputs: immutable.Seq[SslTlsOutbound] = Nil + override def output = ByteString.empty + + override def leftClosing: TLSClosing = EagerClose + override def rightClosing: TLSClosing = EagerClose + } + + // this demonstrates that cancellation is ignored so that the five results make it back + object CancellingRHS extends PayloadScenario { + override def flow = + Flow[SslTlsInbound] + .mapConcat { + case SessionTruncated => SessionTruncated :: Nil + case SessionBytes(s, bytes) => bytes.map(b => SessionBytes(s, ByteString(b))) + } + .take(5) + .mapAsync(5)(x => later(500.millis, system.scheduler)(Future.successful(x))) + .via(super.flow) + override def rightClosing = IgnoreCancel + + val str = "abcdef" * 100 + def inputs = str.map(send) + def output = ByteString(str.take(5)) + } + + object CancellingRHSIgnoresBoth extends PayloadScenario { + override def flow = + Flow[SslTlsInbound] + .mapConcat { + case SessionTruncated => SessionTruncated :: Nil + case SessionBytes(s, bytes) => bytes.map(b => SessionBytes(s, ByteString(b))) + } + .take(5) + .mapAsync(5)(x => later(500.millis, system.scheduler)(Future.successful(x))) + .via(super.flow) + override def rightClosing = IgnoreBoth + val str = "abcdef" * 100 + def inputs = str.map(send) + def output = ByteString(str.take(5)) + } + + object LHSIgnoresBoth extends PayloadScenario { + override def leftClosing = IgnoreBoth + val str = "0123456789" + def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) + } + + object BothSidesIgnoreBoth extends PayloadScenario { + override def leftClosing = IgnoreBoth + override def rightClosing = IgnoreBoth + val str = "0123456789" + def inputs = str.map(ch => SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) + } + + object SessionRenegotiationBySender extends PayloadScenario { + def inputs = List(send("hello"), NegotiateNewSession, send("world")) + def output = ByteString("helloNEWSESSIONworld") + } + + // difference is that the RHS engine will now receive the handshake while trying to send + object SessionRenegotiationByReceiver extends PayloadScenario { + val str = "abcdef" * 100 + def inputs = str.map(send) ++ Seq(NegotiateNewSession) ++ "hello world".map(send) + def output = ByteString(str + "NEWSESSIONhello world") + } + + val logCipherSuite = Flow[SslTlsInbound].map { + var session: SSLSession = null + def setSession(s: SSLSession) = { + session = s + system.log.debug(s"new session: $session (${session.getId.mkString(",")})") + } + + { + case SessionTruncated => SendBytes(ByteString("TRUNCATED")) + case 
SessionBytes(s, b) if s != session => + setSession(s) + SendBytes(ByteString(s.getCipherSuite) ++ b) case SessionBytes(_, b) => SendBytes(b) } - val clientTls = TLS( - () => createSSLEngine2(sslContext, Client, hostnameVerification = true, hostInfo = Some((hostName, 80))), - EagerClose) - - val flow = clientTls.atop(serverTls(EagerClose).reversed).join(rhs) - - Source.single(SendBytes(ByteString.empty)).via(flow).runWith(Sink.ignore) - } - Await.result(run("akka-remote"), 3.seconds) // CN=akka-remote - val cause = intercept[Exception] { - Await.result(run("unknown.example.org"), 3.seconds) } - val rootCause = - if (JavaVersion.majorVersion >= 11) { - cause.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem - cause.getCause - } else { - cause.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem - val cause2 = cause.getCause - cause2.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem - cause2.getCause + object SessionRenegotiationFirstOne extends PayloadScenario { + override def flow = logCipherSuite + def inputs = NegotiateNewSession.withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil + def output = ByteString("TLS_RSA_WITH_AES_128_CBC_SHAhello") + } + + object SessionRenegotiationFirstTwo extends PayloadScenario { + override def flow = logCipherSuite + def inputs = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil + def output = ByteString("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHAhello") + } + + val scenarios = + Seq( + SingleBytes, + MediumMessages, + LargeMessages, + EmptyBytesFirst, + EmptyBytesInTheMiddle, + EmptyBytesLast, + CompletedImmediately, + CancellingRHS, + CancellingRHSIgnoresBoth, + LHSIgnoresBoth, + BothSidesIgnoreBoth) ++ + (if (protocol == "TLSv1.2") + Seq( + SessionRenegotiationBySender, + SessionRenegotiationByReceiver, + SessionRenegotiationFirstOne, + SessionRenegotiationFirstTwo) + else // TLSv1.3 doesn't support renegotiation + Nil) + + for { + commPattern <- communicationPatterns + scenario <- scenarios + } { + s"work in mode ${commPattern.name} while sending ${scenario.name}" in assertAllStagesStopped { + val onRHS = debug.via(scenario.flow) + val output = + Source(scenario.inputs) + .via(commPattern.decorateFlow(scenario.leftClosing, scenario.rightClosing, onRHS)) + .via(new SimpleLinearGraphStage[SslTlsInbound] { + override def createLogic(inheritedAttributes: Attributes) = + new GraphStageLogic(shape) with InHandler with OutHandler { + setHandlers(in, out, this) + + override def onPush() = push(out, grab(in)) + override def onPull() = pull(in) + + override def onDownstreamFinish(cause: Throwable) = { + system.log.debug(s"me cancelled, cause {}", cause) + completeStage() + } + } + }) + .via(debug) + .collect { case SessionBytes(_, b) => b } + .scan(ByteString.empty)(_ ++ _) + .filter(_.nonEmpty) + .via(new Timeout(10.seconds)) + .dropWhile(_.size < scenario.output.size) + .runWith(Sink.headOption) + + Await.result(output, 12.seconds).getOrElse(ByteString.empty).utf8String should be(scenario.output.utf8String) + + commPattern.cleanup() } - rootCause.getClass should ===(classOf[CertificateException]) - rootCause.getMessage should ===("No name matching unknown.example.org found") - } + } + "emit an error if the TLS handshake fails certificate checks" in assertAllStagesStopped { + val getError = Flow[SslTlsInbound] + .map[Either[SslTlsInbound, SSLException]](i => Left(i)) + .recover { case e: SSLException => Right(e) } 
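+          // inbound traffic is wrapped in Left; a handshake failure is recovered
+          // into Right(e), and the collect below keeps just that SSLException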
+ .collect { case Right(e) => e } + .toMat(Sink.head)(Keep.right) + + val simple = Flow.fromSinkAndSourceMat(getError, Source.maybe[SslTlsOutbound])(Keep.left) + + // The creation of actual TCP connections is necessary. It is the easiest way to decouple the client and server + // under error conditions, and has the bonus of matching most actual SSL deployments. + val (server, serverErr) = Tcp() + .bind("localhost", 0) + .mapAsync(1)(c => + c.flow.joinMat(serverTls(IgnoreBoth).reversed.joinMat(simple)(Keep.right))(Keep.right).run()) + .toMat(Sink.head)(Keep.both) + .run() + + val clientErr = simple + .join(badClientTls(IgnoreBoth)) + .join(Tcp().outgoingConnection(Await.result(server, 1.second).localAddress)) + .run() + + Await.result(serverErr, 1.second).getMessage should include("certificate_unknown") + val clientErrText = Await.result(clientErr, 1.second).getMessage + if (JavaVersion.majorVersion >= 11) + clientErrText should include("unable to find valid certification path to requested target") + else + clientErrText should equal("General SSLEngine problem") + } + + "reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped { + val ex = new Exception("hello") + val (sub, out1, out2) = + RunnableGraph + .fromGraph( + GraphDSL.create(Source.asSubscriber[SslTlsOutbound], Sink.head[ByteString], Sink.head[SslTlsInbound])( + (_, _, _)) { implicit b => (s, o1, o2) => + val tls = b.add(clientTls(EagerClose)) + s ~> tls.in1 + tls.out1 ~> o1 + o2 <~ tls.out2 + tls.in2 <~ Source.failed(ex) + ClosedShape + }) + .run() + the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) + the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) + Thread.sleep(500) + val pub = TestPublisher.probe() + pub.subscribe(sub) + pub.expectSubscription().expectCancellation() + } + + "reliably cancel subscriptions when UserIn fails early" in assertAllStagesStopped { + val ex = new Exception("hello") + val (sub, out1, out2) = + RunnableGraph + .fromGraph( + GraphDSL.create(Source.asSubscriber[ByteString], Sink.head[ByteString], Sink.head[SslTlsInbound])( + (_, _, _)) { implicit b => (s, o1, o2) => + val tls = b.add(clientTls(EagerClose)) + Source.failed[SslTlsOutbound](ex) ~> tls.in1 + tls.out1 ~> o1 + o2 <~ tls.out2 + tls.in2 <~ s + ClosedShape + }) + .run() + the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) + the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) + Thread.sleep(500) + val pub = TestPublisher.probe() + pub.subscribe(sub) + pub.expectSubscription().expectCancellation() + } + + "complete if TLS connection is truncated" in assertAllStagesStopped { + + val ks = KillSwitches.shared("ks") + + val scenario = SingleBytes + + val outFlow = { + val terminator = BidiFlow.fromFlows(Flow[ByteString], ks.flow[ByteString]) + clientTls(scenario.leftClosing) + .atop(terminator) + .atop(serverTls(scenario.rightClosing).reversed) + .join(debug.via(scenario.flow)) + .via(debug) + } + + val inFlow = Flow[SslTlsInbound] + .collect { case SessionBytes(_, b) => b } + .scan(ByteString.empty)(_ ++ _) + .via(new Timeout(6.seconds.dilated)) + .dropWhile(_.size < scenario.output.size) + + val f = + Source(scenario.inputs) + .via(outFlow) + .via(inFlow) + .map(result => { + ks.shutdown() + result + }) + .runWith(Sink.last) + Await.result(f, 8.second.dilated).utf8String should be(scenario.output.utf8String) + } + + "verify hostname" in assertAllStagesStopped { + def run(hostName: String): Future[akka.Done] = { + val rhs = Flow[SslTlsInbound].map { + case 
SessionTruncated => SendBytes(ByteString.empty) + case SessionBytes(_, b) => SendBytes(b) + } + val clientTls = TLS( + () => createSSLEngine2(sslContext, Client, hostnameVerification = true, hostInfo = Some((hostName, 80))), + EagerClose) + + val flow = clientTls.atop(serverTls(EagerClose).reversed).join(rhs) + + Source.single(SendBytes(ByteString.empty)).via(flow).runWith(Sink.ignore) + } + + Await.result(run("akka-remote"), 3.seconds) // CN=akka-remote + val cause = intercept[Exception] { + Await.result(run("unknown.example.org"), 3.seconds) + } + + val rootCause = + if (JavaVersion.majorVersion >= 11) { + cause.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem + cause.getCause + } else { + cause.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem + val cause2 = cause.getCause + cause2.getClass should ===(classOf[SSLHandshakeException]) //General SSLEngine problem + cause2.getCause + } + rootCause.getClass should ===(classOf[CertificateException]) + rootCause.getMessage should ===("No name matching unknown.example.org found") + } + } } "A SslTlsPlacebo" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/Timers.scala b/akka-stream/src/main/scala/akka/stream/impl/Timers.scala index 917fbfcb5a..525bb9087d 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Timers.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Timers.scala @@ -255,20 +255,27 @@ import akka.stream.stage._ if (isClosed(in)) completeStage() else pull(in) } else { - val time = System.nanoTime - if (nextDeadline - time < 0) { - nextDeadline = time + timeout.toNanos + val now = System.nanoTime() + // Idle timeout triggered a while ago and we were just waiting for pull. + // In the case of now == deadline, the deadline has not passed strictly, but scheduling another thunk + // for that seems wasteful. + if (now - nextDeadline >= 0) { + nextDeadline = now + timeout.toNanos push(out, inject()) - } else scheduleOnce(GraphStageLogicTimer, FiniteDuration(nextDeadline - time, TimeUnit.NANOSECONDS)) + } else + scheduleOnce(GraphStageLogicTimer, FiniteDuration(nextDeadline - now, TimeUnit.NANOSECONDS)) } } override protected def onTimer(timerKey: Any): Unit = { - val time = System.nanoTime - if ((nextDeadline - time < 0) && isAvailable(out)) { - push(out, inject()) - nextDeadline = time + timeout.toNanos - } + val now = System.nanoTime() + // Timer is reliably cancelled if a regular element arrives first. Scheduler rather schedules too late + // than too early so the deadline must have passed at this time. 
+ assert( + now - nextDeadline >= 0, + s"Timer should have triggered only after deadline but now is $now and deadline was $nextDeadline diff ${now - nextDeadline}.") + push(out, inject()) + nextDeadline = now + timeout.toNanos } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/FutureFlow.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/FutureFlow.scala index 8dfd00b1b7..c3c99cc8c1 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/FutureFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/FutureFlow.scala @@ -12,6 +12,7 @@ import akka.stream.stage.{ GraphStageLogic, GraphStageWithMaterializedValue, InH import akka.util.OptionVal import scala.concurrent.{ Future, Promise } +import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } @InternalApi private[akka] final class FutureFlow[In, Out, M](futureFlow: Future[Flow[In, Out, M]]) @@ -44,9 +45,8 @@ import scala.util.{ Failure, Success, Try } } override def postStop(): Unit = { - if (!innerMatValue.isCompleted) { + if (!innerMatValue.isCompleted) innerMatValue.failure(new AbruptStageTerminationException(this)) - } } object Initializing extends InHandler with OutHandler { @@ -93,49 +93,37 @@ import scala.util.{ Failure, Success, Try } subSource.setHandler { new OutHandler { override def onPull(): Unit = if (!isClosed(in)) tryPull(in) - override def onDownstreamFinish(cause: Throwable): Unit = if (!isClosed(in)) cancel(in, cause) } } subSink.setHandler { new InHandler { override def onPush(): Unit = push(out, subSink.grab()) - override def onUpstreamFinish(): Unit = complete(out) - override def onUpstreamFailure(ex: Throwable): Unit = fail(out, ex) } } - Try { - Source.fromGraph(subSource.source).viaMat(flow)(Keep.right).to(subSink.sink).run()(subFusingMaterializer) - } match { - case Success(matVal) => - innerMatValue.success(matVal) - upstreamFailure match { - case OptionVal.Some(ex) => - subSource.fail(ex) - case OptionVal.None => - if (isClosed(in)) - subSource.complete() - } - downstreamCause match { - case OptionVal.Some(cause) => - subSink.cancel(cause) - case OptionVal.None => - if (isAvailable(out)) subSink.pull() - } - setHandlers(in, out, new InHandler with OutHandler { - override def onPull(): Unit = subSink.pull() - - override def onDownstreamFinish(cause: Throwable): Unit = subSink.cancel(cause) - - override def onPush(): Unit = subSource.push(grab(in)) - - override def onUpstreamFinish(): Unit = subSource.complete() - - override def onUpstreamFailure(ex: Throwable): Unit = subSource.fail(ex) - }) - case Failure(ex) => + try { + val matVal = + Source.fromGraph(subSource.source).viaMat(flow)(Keep.right).to(subSink.sink).run()(subFusingMaterializer) + innerMatValue.success(matVal) + upstreamFailure match { + case OptionVal.Some(ex) => subSource.fail(ex) + case OptionVal.None => if (isClosed(in)) subSource.complete() + } + downstreamCause match { + case OptionVal.Some(cause) => subSink.cancel(cause) + case OptionVal.None => if (isAvailable(out)) subSink.pull() + } + setHandlers(in, out, new InHandler with OutHandler { + override def onPull(): Unit = subSink.pull() + override def onDownstreamFinish(cause: Throwable): Unit = subSink.cancel(cause) + override def onPush(): Unit = subSource.push(grab(in)) + override def onUpstreamFinish(): Unit = subSource.complete() + override def onUpstreamFailure(ex: Throwable): Unit = subSource.fail(ex) + }) + } catch { + case NonFatal(ex) => innerMatValue.failure(new NeverMaterializedException(ex)) failStage(ex) } diff --git 
a/akka-stream/src/main/scala/akka/stream/impl/io/TLSActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TLSActor.scala index dd5a9ce955..ca75256f65 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TLSActor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TLSActor.scala @@ -269,7 +269,7 @@ import akka.util.ByteString } def completeOrFlush(): Unit = - if (engine.isOutboundDone) nextPhase(completedPhase) + if (engine.isOutboundDone || (engine.isInboundDone && userInChoppingBlock.isEmpty)) nextPhase(completedPhase) else nextPhase(flushingOutbound) private def doInbound(isOutboundClosed: Boolean, inboundState: TransferState): Boolean = @@ -395,7 +395,9 @@ import akka.util.ByteString result.getStatus match { case OK => result.getHandshakeStatus match { - case NEED_WRAP => flushToUser() + case NEED_WRAP => + flushToUser() + transportInChoppingBlock.putBack(transportInBuffer) case FINISHED => flushToUser() handshakeFinished() @@ -406,8 +408,7 @@ import akka.util.ByteString } case CLOSED => flushToUser() - if (engine.isOutboundDone) nextPhase(completedPhase) - else nextPhase(flushingOutbound) + completeOrFlush() case BUFFER_UNDERFLOW => flushToUser() case BUFFER_OVERFLOW => diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/TLS.scala b/akka-stream/src/main/scala/akka/stream/javadsl/TLS.scala index 2b456fbb92..e0bdf57aec 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/TLS.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/TLS.scala @@ -36,7 +36,7 @@ import akka.util.ByteString * * '''IMPORTANT NOTE''' * - * The TLS specification does not permit half-closing of the user data session + * The TLS specification until version 1.2 did not permit half-closing of the user data session * that it transports—to be precise a half-close will always promptly lead to a * full close. This means that canceling the plaintext output or completing the * plaintext input of the SslTls operator will lead to full termination of the @@ -50,7 +50,8 @@ import akka.util.ByteString * order to terminate the connection the client will then need to cancel the * plaintext output as soon as all expected bytes have been received. When * ignoring both types of events the operator will shut down once both events have - * been received. See also [[TLSClosing]]. + * been received. See also [[TLSClosing]]. For now, half-closing is also not + * supported with TLS 1.3 where the spec allows it. */ object TLS { diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/TLS.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/TLS.scala index 381618fa35..410f4b7ac8 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/TLS.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/TLS.scala @@ -37,7 +37,7 @@ import akka.util.ByteString * * '''IMPORTANT NOTE''' * - * The TLS specification does not permit half-closing of the user data session + * The TLS specification until version 1.2 did not permit half-closing of the user data session * that it transports—to be precise a half-close will always promptly lead to a * full close. This means that canceling the plaintext output or completing the * plaintext input of the SslTls operator will lead to full termination of the @@ -51,7 +51,8 @@ import akka.util.ByteString * order to terminate the connection the client will then need to cancel the * plaintext output as soon as all expected bytes have been received. When * ignoring both types of events the operator will shut down once both events have - * been received. 
See also [[TLSClosing]]. + * been received. See also [[TLSClosing]]. For now, half-closing is also not + * supported with TLS 1.3 where the spec allows it. */ object TLS {