Merge branch 'master' into wip-scala210M4-√

This commit is contained in:
Viktor Klang 2012-06-21 12:06:30 +02:00
commit f3c693a245
24 changed files with 570 additions and 385 deletions

View file

@ -43,9 +43,23 @@ akka {
# a quick detection in the event of a real crash. Conversely, a high
# threshold generates fewer mistakes but needs more time to detect
# actual crashes
threshold = 8
threshold = 8.0
implementation-class = ""
# Minimum standard deviation to use for the normal distribution in
# AccrualFailureDetector. Too low standard deviation might result in
# too much sensitivity for sudden, but normal, deviations in heartbeat
# inter arrival times.
min-std-deviation = 100 ms
# Number of potentially lost/delayed heartbeats that will be
# accepted before considering it to be an anomaly.
# It is a factor of heartbeat-interval.
# This margin is important to be able to survive sudden, occasional,
# pauses in heartbeat arrivals, due to for example garbage collect or
# network drop.
acceptable-heartbeat-pause = 3s
implementation-class = "akka.cluster.AccrualFailureDetector"
max-sample-size = 1000
}

View file

@ -7,50 +7,103 @@ package akka.cluster
import akka.actor.{ ActorSystem, Address, ExtendedActorSystem }
import akka.remote.RemoteActorRefProvider
import akka.event.Logging
import scala.collection.immutable.Map
import scala.annotation.tailrec
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.TimeUnit.NANOSECONDS
import akka.util.Duration
import akka.util.duration._
object AccrualFailureDetector {
private def realClock: () Long = () NANOSECONDS.toMillis(System.nanoTime)
}
/**
* Implementation of 'The Phi Accrual Failure Detector' by Hayashibara et al. as defined in their paper:
* [http://ddg.jaist.ac.jp/pub/HDY+04.pdf]
* <p/>
* A low threshold is prone to generate many wrong suspicions but ensures a quick detection in the event
* of a real crash. Conversely, a high threshold generates fewer mistakes but needs more time to detect
* actual crashes
* <p/>
* Default threshold is 8, but can be configured in the Akka config.
*
* The suspicion level of failure is given by a value called φ (phi).
* The basic idea of the φ failure detector is to express the value of φ on a scale that
* is dynamically adjusted to reflect current network conditions. A configurable
* threshold is used to decide if φ is considered to be a failure.
*
* The value of φ is calculated as:
*
* {{{
* φ = -log10(1 - F(timeSinceLastHeartbeat))
* }}}
* where F is the cumulative distribution function of a normal distribution with mean
* and standard deviation estimated from historical heartbeat inter-arrival times.
*
*
* @param system Belongs to the [[akka.actor.ActorSystem]]. Used for logging.
*
* @param threshold A low threshold is prone to generate many wrong suspicions but ensures a quick detection in the event
* of a real crash. Conversely, a high threshold generates fewer mistakes but needs more time to detect
* actual crashes
*
* @param maxSampleSize Number of samples to use for calculation of mean and standard deviation of
* inter-arrival times.
*
* @param minStdDeviation Minimum standard deviation to use for the normal distribution used when calculating phi.
* Too low standard deviation might result in too much sensitivity for sudden, but normal, deviations
* in heartbeat inter arrival times.
*
* @param acceptableHeartbeatPause Duration corresponding to number of potentially lost/delayed
* heartbeats that will be accepted before considering it to be an anomaly.
* This margin is important to be able to survive sudden, occasional, pauses in heartbeat
* arrivals, due to for example garbage collect or network drop.
*
* @param firstHeartbeatEstimate Bootstrap the stats with heartbeats that correspond
* to this duration, with a rather high standard deviation (since the environment is unknown
* in the beginning)
*
* @param clock The clock, returning current time in milliseconds, but can be faked for testing
* purposes. It is only used for measuring intervals (duration).
*
*/
class AccrualFailureDetector(
val system: ActorSystem,
val threshold: Int = 8,
val maxSampleSize: Int = 1000,
val timeMachine: () Long = System.currentTimeMillis) extends FailureDetector {
val threshold: Double,
val maxSampleSize: Int,
val minStdDeviation: Duration,
val acceptableHeartbeatPause: Duration,
val firstHeartbeatEstimate: Duration,
val clock: () Long) extends FailureDetector {
import AccrualFailureDetector._
/**
* Constructor that picks configuration from the settings.
*/
def this(
system: ActorSystem,
settings: ClusterSettings,
timeMachine: () Long = System.currentTimeMillis) =
clock: () Long = AccrualFailureDetector.realClock) =
this(
system,
settings.FailureDetectorThreshold,
settings.FailureDetectorMaxSampleSize,
timeMachine)
private final val PhiFactor = 1.0 / math.log(10.0)
settings.FailureDetectorAcceptableHeartbeatPause,
settings.FailureDetectorMinStdDeviation,
// we use a conservative estimate for the first heartbeat because
// gossip needs to spread back to the joining node before the
// first real heartbeat is sent. Initial heartbeat is added when joining.
// FIXME this can be changed to HeartbeatInterval when ticket #2249 is fixed
settings.GossipInterval * 3 + settings.HeartbeatInterval,
clock)
private val log = Logging(system, "FailureDetector")
/**
* Holds the failure statistics for a specific node Address.
*/
private case class FailureStats(mean: Double = 0.0, variance: Double = 0.0, deviation: Double = 0.0)
// guess statistics for first heartbeat,
// important so that connections with only one heartbeat becomes unavailble
private val failureStatsFirstHeartbeat = FailureStats(mean = 1000.0)
// important so that connections with only one heartbeat becomes unavailable
private val firstHeartbeat: HeartbeatHistory = {
// bootstrap with 2 entries with rather high standard deviation
val mean = firstHeartbeatEstimate.toMillis
val stdDeviation = mean / 4
HeartbeatHistory(maxSampleSize) :+ (mean - stdDeviation) :+ (mean + stdDeviation)
}
private val acceptableHeartbeatPauseMillis = acceptableHeartbeatPause.toMillis
/**
* Implement using optimistic lockless concurrency, all state is represented
@ -58,8 +111,7 @@ class AccrualFailureDetector(
*/
private case class State(
version: Long = 0L,
failureStats: Map[Address, FailureStats] = Map.empty[Address, FailureStats],
intervalHistory: Map[Address, IndexedSeq[Long]] = Map.empty[Address, IndexedSeq[Long]],
history: Map[Address, HeartbeatHistory] = Map.empty,
timestamps: Map[Address, Long] = Map.empty[Address, Long],
explicitRemovals: Set[Address] = Set.empty[Address])
@ -78,95 +130,76 @@ class AccrualFailureDetector(
final def heartbeat(connection: Address) {
log.debug("Heartbeat from connection [{}] ", connection)
val timestamp = clock()
val oldState = state.get
val latestTimestamp = oldState.timestamps.get(connection)
if (latestTimestamp.isEmpty) {
// this is heartbeat from a new connection
// add starter records for this new connection
val newState = oldState copy (
version = oldState.version + 1,
failureStats = oldState.failureStats + (connection -> failureStatsFirstHeartbeat),
intervalHistory = oldState.intervalHistory + (connection -> IndexedSeq.empty[Long]),
timestamps = oldState.timestamps + (connection -> timeMachine()),
explicitRemovals = oldState.explicitRemovals - connection)
// if we won the race then update else try again
if (!state.compareAndSet(oldState, newState)) heartbeat(connection) // recur
} else {
// this is a known connection
val timestamp = timeMachine()
val interval = timestamp - latestTimestamp.get
val newIntervalsForConnection = (oldState.intervalHistory.get(connection) match {
case Some(history) if history.size >= maxSampleSize
// reached max history, drop first interval
history drop 1
case Some(history) history
case _ IndexedSeq.empty[Long]
}) :+ interval
val newFailureStats = {
val newMean: Double = newIntervalsForConnection.sum.toDouble / newIntervalsForConnection.size
val oldConnectionFailureStats = oldState.failureStats.get(connection).getOrElse {
throw new IllegalStateException("Can't calculate new failure statistics due to missing heartbeat history")
}
val deviationSum = (0.0d /: newIntervalsForConnection) { (mean, interval)
mean + interval.toDouble - newMean
}
val newVariance: Double = deviationSum / newIntervalsForConnection.size
val newDeviation: Double = math.sqrt(newVariance)
val newFailureStats = oldConnectionFailureStats copy (mean = newMean, deviation = newDeviation, variance = newVariance)
oldState.failureStats + (connection -> newFailureStats)
}
val newState = oldState copy (version = oldState.version + 1,
failureStats = newFailureStats,
intervalHistory = oldState.intervalHistory + (connection -> newIntervalsForConnection),
timestamps = oldState.timestamps + (connection -> timestamp), // record new timestamp,
explicitRemovals = oldState.explicitRemovals - connection)
// if we won the race then update else try again
if (!state.compareAndSet(oldState, newState)) heartbeat(connection) // recur
val newHistory = oldState.timestamps.get(connection) match {
case None
// this is heartbeat from a new connection
// add starter records for this new connection
firstHeartbeat
case Some(latestTimestamp)
// this is a known connection
val interval = timestamp - latestTimestamp
oldState.history(connection) :+ interval
}
val newState = oldState copy (version = oldState.version + 1,
history = oldState.history + (connection -> newHistory),
timestamps = oldState.timestamps + (connection -> timestamp), // record new timestamp,
explicitRemovals = oldState.explicitRemovals - connection)
// if we won the race then update else try again
if (!state.compareAndSet(oldState, newState)) heartbeat(connection) // recur
}
/**
* Calculates how likely it is that the connection has failed.
* <p/>
* The suspicion level of the accrual failure detector.
*
* If a connection does not have any records in failure detector then it is
* considered healthy.
* <p/>
* Implementations of 'Cumulative Distribution Function' for Exponential Distribution.
* For a discussion on the math read [https://issues.apache.org/jira/browse/CASSANDRA-2597].
*/
def phi(connection: Address): Double = {
val oldState = state.get
val oldTimestamp = oldState.timestamps.get(connection)
val phi =
// if connection has been removed explicitly
if (oldState.explicitRemovals.contains(connection)) Double.MaxValue
else if (oldTimestamp.isEmpty) 0.0 // treat unmanaged connections, e.g. with zero heartbeats, as healthy connections
else {
val timestampDiff = timeMachine() - oldTimestamp.get
// if connection has been removed explicitly
if (oldState.explicitRemovals.contains(connection)) Double.MaxValue
else if (oldTimestamp.isEmpty) 0.0 // treat unmanaged connections, e.g. with zero heartbeats, as healthy connections
else {
val timeDiff = clock() - oldTimestamp.get
val mean = oldState.failureStats.get(connection) match {
case Some(FailureStats(mean, _, _)) mean
case _ throw new IllegalStateException("Can't calculate Failure Detector Phi value for a node that have no heartbeat history")
}
val history = oldState.history(connection)
val mean = history.mean
val stdDeviation = ensureValidStdDeviation(history.stdDeviation)
if (mean == 0.0) 0.0
else PhiFactor * timestampDiff / mean
}
val φ = phi(timeDiff, mean + acceptableHeartbeatPauseMillis, stdDeviation)
log.debug("Phi value [{}] and threshold [{}] for connection [{}] ", phi, threshold, connection)
phi
// FIXME change to debug log level, when failure detector is stable
if (φ > 1.0) log.info("Phi value [{}] for connection [{}], after [{} ms], based on [{}]",
φ, connection, timeDiff, "N(" + mean + ", " + stdDeviation + ")")
φ
}
}
/**
 * Calculates φ, the suspicion level, for a heartbeat that is `timeDiff` late, given
 * a normal distribution N(mean, stdDeviation) of heartbeat inter-arrival times:
 *
 * {{{
 * φ = -log10(1 - F(timeDiff))
 * }}}
 *
 * F is approximated with the same logistic formula (and constants) as
 * `cumulativeDistributionFunction`, i.e. `F = 1 / (1 + e)` where
 * `e = exp(-y * (1.5976 + 0.070566 * y * y))` and `y = (timeDiff - mean) / stdDeviation`.
 *
 * For `timeDiff > mean` the tail probability `1 - F` is computed directly as
 * `e / (1 + e)` instead of subtracting a value close to 1.0 from 1.0. The subtraction
 * loses all significant digits once F rounds to 1.0, which made φ jump to
 * `Infinity` after only a few standard deviations. With the direct form φ keeps
 * growing smoothly until `e` itself underflows to 0.0 (then the result is still
 * `Infinity`).
 *
 * @param timeDiff time since the last heartbeat, in the same unit as `mean`
 * @param mean mean of the inter-arrival distribution
 * @param stdDeviation standard deviation of the inter-arrival distribution;
 *   callers are expected to pass a value > 0
 * @return the φ value; larger means more suspicious
 */
private[cluster] def phi(timeDiff: Long, mean: Double, stdDeviation: Double): Double = {
  val y = (timeDiff - mean) / stdDeviation
  val e = math.exp(-y * (1.5976 + 0.070566 * y * y))
  // 1 - 1/(1 + e) == e/(1 + e); pick the form that avoids cancellation on each side of the mean
  if (timeDiff > mean) -math.log10(e / (1.0 + e))
  else -math.log10(1.0 - 1.0 / (1.0 + e))
}
// Configured lower bound for the standard deviation, converted to milliseconds once.
private val minStdDeviationMillis = minStdDeviation.toMillis

/**
 * Clamps the given standard deviation from below so that it is never smaller
 * than the configured minimum (`minStdDeviationMillis`).
 */
private def ensureValidStdDeviation(stdDeviation: Double): Double =
  math.max(minStdDeviationMillis.toDouble, stdDeviation)
/**
 * Approximation of the cumulative distribution function of the normal
 * distribution N(mean, stdDeviation), evaluated at `x`.
 * The approximation formula is the one defined in β Mathematics Handbook.
 *
 * @param x the point at which the distribution function is evaluated
 * @param mean mean of the distribution
 * @param stdDeviation standard deviation of the distribution (used as divisor)
 * @return approximated probability that a sample is less than or equal to `x`
 */
private[cluster] def cumulativeDistributionFunction(x: Double, mean: Double, stdDeviation: Double): Double = {
  // express x in units of standard deviations from the mean, i.e. map to N(0, 1)
  val standardized = (x - mean) / stdDeviation
  // logistic approximation of the N(0, 1) distribution function
  val exponent = -standardized * (1.5976 + 0.070566 * standardized * standardized)
  1.0 / (1.0 + math.exp(exponent))
}
/**
@ -177,10 +210,9 @@ class AccrualFailureDetector(
log.debug("Remove connection [{}] ", connection)
val oldState = state.get
if (oldState.failureStats.contains(connection)) {
if (oldState.history.contains(connection)) {
val newState = oldState copy (version = oldState.version + 1,
failureStats = oldState.failureStats - connection,
intervalHistory = oldState.intervalHistory - connection,
history = oldState.history - connection,
timestamps = oldState.timestamps - connection,
explicitRemovals = oldState.explicitRemovals + connection)
@ -189,3 +221,66 @@ class AccrualFailureDetector(
}
}
}
private[cluster] object HeartbeatHistory {

  /**
   * Creates a HeartbeatHistory that holds no intervals yet.
   * It is only meant as the starting point for appending intervals with `:+`.
   *
   * The stats (mean, variance, stdDeviation) are not defined for an empty
   * HeartbeatHistory: they divide by the number of intervals using
   * floating point division, so they evaluate to NaN until an interval is added.
   *
   * @param maxSampleSize maximum number of intervals to keep, must be >= 1
   */
  def apply(maxSampleSize: Int): HeartbeatHistory =
    HeartbeatHistory(maxSampleSize, IndexedSeq.empty[Long], 0L, 0L)

}

/**
 * Immutable heartbeat interval statistics for a specific node Address.
 * At most `maxSampleSize` intervals are retained; appending to a full
 * history evicts the oldest interval first.
 *
 * The running sums are carried along with the intervals so that mean and
 * variance can be derived without traversing the samples.
 *
 * The stats (mean, variance, stdDeviation) are not defined for an empty
 * HeartbeatHistory (they evaluate to NaN because of the floating point
 * division by a size of zero).
 *
 * @param maxSampleSize maximum number of retained intervals, must be >= 1
 * @param intervals the retained intervals, oldest first
 * @param intervalSum sum of `intervals`, must be >= 0
 * @param squaredIntervalSum sum of the squares of `intervals`, must be >= 0
 */
private[cluster] case class HeartbeatHistory private (
  maxSampleSize: Int,
  intervals: IndexedSeq[Long],
  intervalSum: Long,
  squaredIntervalSum: Long) {

  // constructor-time validation of the invariants
  if (maxSampleSize < 1)
    throw new IllegalArgumentException("maxSampleSize must be >= 1, got [%s]" format maxSampleSize)
  if (intervalSum < 0L)
    throw new IllegalArgumentException("intervalSum must be >= 0, got [%s]" format intervalSum)
  if (squaredIntervalSum < 0L)
    throw new IllegalArgumentException("squaredIntervalSum must be >= 0, got [%s]" format squaredIntervalSum)

  /** Arithmetic mean of the retained intervals. */
  def mean: Double = intervalSum.toDouble / intervals.size

  /** Population variance of the retained intervals, computed as E[x²] - (E[x])². */
  def variance: Double = {
    val meanOfSquares = squaredIntervalSum.toDouble / intervals.size
    val squaredMean = mean * mean
    meanOfSquares - squaredMean
  }

  /** Standard deviation, i.e. the square root of [[variance]]. */
  def stdDeviation: Double = math.sqrt(variance)

  /**
   * Returns a new history with `interval` appended. When the history is
   * already at `maxSampleSize` the oldest interval is removed before appending.
   */
  @tailrec
  final def :+(interval: Long): HeartbeatHistory = {
    if (intervals.size >= maxSampleSize)
      dropOldest :+ interval // recur on the shrunken history
    else
      copy(
        intervals = intervals :+ interval,
        intervalSum = intervalSum + interval,
        squaredIntervalSum = squaredIntervalSum + pow2(interval))
  }

  // Removes the oldest interval and takes its contribution out of both running sums.
  private def dropOldest: HeartbeatHistory = {
    val remaining = intervals drop 1
    val oldest = intervals.head
    copy(
      intervals = remaining,
      intervalSum = intervalSum - oldest,
      squaredIntervalSum = squaredIntervalSum - pow2(oldest))
  }

  private def pow2(value: Long) = value * value
}

View file

@ -403,14 +403,12 @@ object Cluster extends ExtensionId[Cluster] with ExtensionIdProvider {
override def createExtension(system: ExtendedActorSystem): Cluster = {
val clusterSettings = new ClusterSettings(system.settings.config, system.name)
val failureDetector = clusterSettings.FailureDetectorImplementationClass match {
case None new AccrualFailureDetector(system, clusterSettings)
case Some(fqcn)
system.dynamicAccess.createInstanceFor[FailureDetector](
fqcn, Seq((classOf[ActorSystem], system), (classOf[ClusterSettings], clusterSettings))) match {
case Right(fd) fd
case Left(e) throw new ConfigurationException("Could not create custom failure detector [" + fqcn + "] due to:" + e.toString)
}
val failureDetector = {
import clusterSettings.{ FailureDetectorImplementationClass fqcn }
system.dynamicAccess.createInstanceFor[FailureDetector](
fqcn, Seq(classOf[ActorSystem] -> system, classOf[ClusterSettings] -> clusterSettings)).fold(
e throw new ConfigurationException("Could not create custom failure detector [" + fqcn + "] due to:" + e.toString),
identity)
}
new Cluster(system, failureDetector)

View file

@ -13,12 +13,15 @@ import akka.actor.AddressFromURIString
class ClusterSettings(val config: Config, val systemName: String) {
import config._
final val FailureDetectorThreshold = getInt("akka.cluster.failure-detector.threshold")
final val FailureDetectorThreshold = getDouble("akka.cluster.failure-detector.threshold")
final val FailureDetectorMaxSampleSize = getInt("akka.cluster.failure-detector.max-sample-size")
final val FailureDetectorImplementationClass: Option[String] = getString("akka.cluster.failure-detector.implementation-class") match {
case "" None
case fqcn Some(fqcn)
}
final val FailureDetectorImplementationClass = getString("akka.cluster.failure-detector.implementation-class")
final val FailureDetectorMinStdDeviation: Duration =
Duration(getMilliseconds("akka.cluster.failure-detector.min-std-deviation"), MILLISECONDS)
final val FailureDetectorAcceptableHeartbeatPause: Duration =
Duration(getMilliseconds("akka.cluster.failure-detector.acceptable-heartbeat-pause"), MILLISECONDS)
final val NodeToJoin: Option[Address] = getString("akka.cluster.node-to-join") match {
case "" None
case AddressFromURIString(addr) Some(addr)

View file

@ -9,7 +9,7 @@ import akka.remote.testkit.MultiNodeSpec
import akka.util.duration._
import akka.testkit._
object GossipingAccrualFailureDetectorMultiJvmSpec extends MultiNodeConfig {
object ClusterAccrualFailureDetectorMultiJvmSpec extends MultiNodeConfig {
val first = role("first")
val second = role("second")
val third = role("third")
@ -19,22 +19,22 @@ object GossipingAccrualFailureDetectorMultiJvmSpec extends MultiNodeConfig {
withFallback(MultiNodeClusterSpec.clusterConfig))
}
class GossipingWithAccrualFailureDetectorMultiJvmNode1 extends GossipingAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
class GossipingWithAccrualFailureDetectorMultiJvmNode2 extends GossipingAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
class GossipingWithAccrualFailureDetectorMultiJvmNode3 extends GossipingAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
class ClusterAccrualFailureDetectorMultiJvmNode1 extends ClusterAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
class ClusterAccrualFailureDetectorMultiJvmNode2 extends ClusterAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
class ClusterAccrualFailureDetectorMultiJvmNode3 extends ClusterAccrualFailureDetectorSpec with AccrualFailureDetectorStrategy
abstract class GossipingAccrualFailureDetectorSpec
extends MultiNodeSpec(GossipingAccrualFailureDetectorMultiJvmSpec)
abstract class ClusterAccrualFailureDetectorSpec
extends MultiNodeSpec(ClusterAccrualFailureDetectorMultiJvmSpec)
with MultiNodeClusterSpec {
import GossipingAccrualFailureDetectorMultiJvmSpec._
import ClusterAccrualFailureDetectorMultiJvmSpec._
"A Gossip-driven Failure Detector" must {
"A heartbeat driven Failure Detector" must {
"receive gossip heartbeats so that all member nodes in the cluster are marked 'available'" taggedAs LongRunningTest in {
"receive heartbeats so that all member nodes in the cluster are marked 'available'" taggedAs LongRunningTest in {
awaitClusterUp(first, second, third)
5.seconds.dilated.sleep // let them gossip
5.seconds.dilated.sleep // let them heartbeat
cluster.failureDetector.isAvailable(first) must be(true)
cluster.failureDetector.isAvailable(second) must be(true)
cluster.failureDetector.isAvailable(third) must be(true)
@ -47,9 +47,11 @@ abstract class GossipingAccrualFailureDetectorSpec
testConductor.shutdown(third, 0)
}
enterBarrier("third-shutdown")
runOn(first, second) {
// remaining nodes should detect failure...
awaitCond(!cluster.failureDetector.isAvailable(third), 10.seconds)
awaitCond(!cluster.failureDetector.isAvailable(third), 15.seconds)
// other connections still ok
cluster.failureDetector.isAvailable(first) must be(true)
cluster.failureDetector.isAvailable(second) must be(true)

View file

@ -55,7 +55,7 @@ trait AccrualFailureDetectorStrategy extends FailureDetectorStrategy { self: Mul
override val failureDetector: FailureDetector = new AccrualFailureDetector(system, new ClusterSettings(system.settings.config, system.name))
override def markNodeAsAvailable(address: Address): Unit = { /* no-op */ }
override def markNodeAsAvailable(address: Address): Unit = ()
override def markNodeAsUnavailable(address: Address): Unit = { /* no-op */ }
override def markNodeAsUnavailable(address: Address): Unit = ()
}

View file

@ -50,7 +50,7 @@ abstract class LeaderElectionSpec
assertLeaderIn(sortedRoles)
}
enterBarrier("after")
enterBarrier("after-1")
}
def shutdownLeaderAndVerifyNewLeader(alreadyShutdown: Int): Unit = {
@ -97,10 +97,12 @@ abstract class LeaderElectionSpec
"be able to 're-elect' a single leader after leader has left" taggedAs LongRunningTest in {
shutdownLeaderAndVerifyNewLeader(alreadyShutdown = 0)
enterBarrier("after-2")
}
"be able to 're-elect' a single leader after leader has left (again)" taggedAs LongRunningTest in {
shutdownLeaderAndVerifyNewLeader(alreadyShutdown = 1)
enterBarrier("after-3")
}
}
}

View file

@ -35,6 +35,8 @@ abstract class LeaderLeavingSpec
import LeaderLeavingMultiJvmSpec._
val leaderHandoffWaitingTime = 30.seconds.dilated
"A LEADER that is LEAVING" must {
"be moved to LEAVING, then to EXITING, then to REMOVED, then be shut down and then a new LEADER should be elected" taggedAs LongRunningTest in {
@ -62,19 +64,19 @@ abstract class LeaderLeavingSpec
enterBarrier("leader-left")
// verify that the LEADER is LEAVING
awaitCond(cluster.latestGossip.members.exists(m m.status == MemberStatus.Leaving && m.address == oldLeaderAddress)) // wait on LEAVING
awaitCond(cluster.latestGossip.members.exists(m m.status == MemberStatus.Leaving && m.address == oldLeaderAddress), leaderHandoffWaitingTime) // wait on LEAVING
// verify that the LEADER is EXITING
awaitCond(cluster.latestGossip.members.exists(m m.status == MemberStatus.Exiting && m.address == oldLeaderAddress)) // wait on EXITING
awaitCond(cluster.latestGossip.members.exists(m m.status == MemberStatus.Exiting && m.address == oldLeaderAddress), leaderHandoffWaitingTime) // wait on EXITING
// verify that the LEADER is no longer part of the 'members' set
awaitCond(cluster.latestGossip.members.forall(_.address != oldLeaderAddress))
awaitCond(cluster.latestGossip.members.forall(_.address != oldLeaderAddress), leaderHandoffWaitingTime)
// verify that the LEADER is not part of the 'unreachable' set
awaitCond(cluster.latestGossip.overview.unreachable.forall(_.address != oldLeaderAddress))
awaitCond(cluster.latestGossip.overview.unreachable.forall(_.address != oldLeaderAddress), leaderHandoffWaitingTime)
// verify that we have a new LEADER
awaitCond(cluster.leader != oldLeaderAddress)
awaitCond(cluster.leader != oldLeaderAddress, leaderHandoffWaitingTime)
}
enterBarrier("finished")

View file

@ -20,8 +20,7 @@ object TransitionMultiJvmSpec extends MultiNodeConfig {
val fifth = role("fifth")
commonConfig(debugConfig(on = false).
withFallback(ConfigFactory.parseString(
"akka.cluster.periodic-tasks-initial-delay = 300 s # turn off all periodic tasks")).
withFallback(ConfigFactory.parseString("akka.cluster.periodic-tasks-initial-delay = 300 s # turn off all periodic tasks")).
withFallback(MultiNodeClusterSpec.clusterConfig))
}
@ -396,7 +395,7 @@ abstract class TransitionSpec
seenLatestGossip must be(Set(fifth))
}
testConductor.enter("after-second-unavailble")
enterBarrier("after-second-unavailble")
// spread the word
val gossipRound = List(fifth, fourth, third, first, third, fourth, fifth)
@ -414,7 +413,7 @@ abstract class TransitionSpec
awaitMemberStatus(second, Down)
}
testConductor.enter("after-second-down")
enterBarrier("after-second-down")
// spread the word
val gossipRound2 = List(third, fourth, fifth, first, third, fourth, fifth)

View file

@ -6,6 +6,9 @@ package akka.cluster
import akka.actor.Address
import akka.testkit.{ LongRunningTest, AkkaSpec }
import scala.collection.immutable.TreeMap
import akka.util.duration._
import akka.util.Duration
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class AccrualFailureDetectorSpec extends AkkaSpec("""
@ -27,33 +30,72 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
timeGenerator
}
val defaultFakeTimeIntervals = Vector.fill(20)(1000L)
def createFailureDetector(
threshold: Double = 8.0,
maxSampleSize: Int = 1000,
minStdDeviation: Duration = 10.millis,
acceptableLostDuration: Duration = Duration.Zero,
firstHeartbeatEstimate: Duration = 1.second,
clock: () Long = fakeTimeGenerator(defaultFakeTimeIntervals)): AccrualFailureDetector =
new AccrualFailureDetector(system,
threshold,
maxSampleSize,
minStdDeviation,
acceptableLostDuration,
firstHeartbeatEstimate = firstHeartbeatEstimate,
clock = clock)
"use good enough cumulative distribution function" in {
val fd = createFailureDetector()
fd.cumulativeDistributionFunction(0.0, 0, 1) must be(0.5 plusOrMinus (0.001))
fd.cumulativeDistributionFunction(0.6, 0, 1) must be(0.7257 plusOrMinus (0.001))
fd.cumulativeDistributionFunction(1.5, 0, 1) must be(0.9332 plusOrMinus (0.001))
fd.cumulativeDistributionFunction(2.0, 0, 1) must be(0.97725 plusOrMinus (0.01))
fd.cumulativeDistributionFunction(2.5, 0, 1) must be(0.9379 plusOrMinus (0.1))
fd.cumulativeDistributionFunction(3.5, 0, 1) must be(0.99977 plusOrMinus (0.1))
fd.cumulativeDistributionFunction(4.0, 0, 1) must be(0.99997 plusOrMinus (0.1))
for (x :: y :: Nil (0.0 to 4.0 by 0.1).toList.sliding(2)) {
fd.cumulativeDistributionFunction(x, 0, 1) must be < (
fd.cumulativeDistributionFunction(y, 0, 1))
}
fd.cumulativeDistributionFunction(2.2, 2.0, 0.3) must be(0.7475 plusOrMinus (0.001))
}
"return realistic phi values" in {
val fd = createFailureDetector()
val test = TreeMap(0 -> 0.0, 500 -> 0.1, 1000 -> 0.3, 1200 -> 1.6, 1400 -> 4.7, 1600 -> 10.8, 1700 -> 15.3)
for ((timeDiff, expectedPhi) test) {
fd.phi(timeDiff = timeDiff, mean = 1000.0, stdDeviation = 100.0) must be(expectedPhi plusOrMinus (0.1))
}
// larger stdDeviation results => lower phi
fd.phi(timeDiff = 1100, mean = 1000.0, stdDeviation = 500.0) must be < (
fd.phi(timeDiff = 1100, mean = 1000.0, stdDeviation = 100.0))
}
"return phi value of 0.0 on startup for each address, when no heartbeats" in {
val fd = new AccrualFailureDetector(system)
val fd = createFailureDetector()
fd.phi(conn) must be(0.0)
fd.phi(conn2) must be(0.0)
}
"return phi based on guess when only one heartbeat" in {
// 1 second ticks
val timeInterval = Vector.fill(30)(1000L)
val fd = new AccrualFailureDetector(system,
timeMachine = fakeTimeGenerator(timeInterval))
val timeInterval = List[Long](0, 1000, 1000, 1000, 1000)
val fd = createFailureDetector(firstHeartbeatEstimate = 1.seconds,
clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.phi(conn) must be > (0.0)
// let time go
for (n 2 to 8)
fd.phi(conn) must be < (4.0)
for (n 9 to 18)
fd.phi(conn) must be < (8.0)
fd.phi(conn) must be > (8.0)
fd.phi(conn) must be(0.3 plusOrMinus 0.2)
fd.phi(conn) must be(4.5 plusOrMinus 0.3)
fd.phi(conn) must be > (15.0)
}
"return phi value using first interval after second heartbeat" in {
val timeInterval = List[Long](0, 100, 100, 100)
val fd = new AccrualFailureDetector(system,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.phi(conn) must be > (0.0)
@ -63,8 +105,7 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
"mark node as available after a series of successful heartbeats" in {
val timeInterval = List[Long](0, 1000, 100, 100)
val fd = new AccrualFailureDetector(system,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.heartbeat(conn)
@ -75,8 +116,7 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
"mark node as dead after explicit removal of connection" in {
val timeInterval = List[Long](0, 1000, 100, 100, 100)
val fd = new AccrualFailureDetector(system,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.heartbeat(conn)
@ -89,8 +129,7 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
"mark node as available after explicit removal of connection and receiving heartbeat again" in {
val timeInterval = List[Long](0, 1000, 100, 1100, 1100, 1100, 1100, 1100, 100)
val fd = new AccrualFailureDetector(system,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn) //0
@ -112,40 +151,65 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
}
"mark node as dead if heartbeat are missed" in {
val timeInterval = List[Long](0, 1000, 100, 100, 5000)
val timeInterval = List[Long](0, 1000, 100, 100, 7000)
val ft = fakeTimeGenerator(timeInterval)
val fd = new AccrualFailureDetector(system, threshold = 3,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(threshold = 3, clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn) //0
fd.heartbeat(conn) //1000
fd.heartbeat(conn) //1100
fd.isAvailable(conn) must be(true) //1200
fd.isAvailable(conn) must be(false) //6200
fd.isAvailable(conn) must be(false) //8200
}
"mark node as available if it starts heartbeat again after being marked dead due to detection of failure" in {
val timeInterval = List[Long](0, 1000, 100, 1100, 5000, 100, 1000, 100, 100)
val fd = new AccrualFailureDetector(system, threshold = 3,
timeMachine = fakeTimeGenerator(timeInterval))
val timeInterval = List[Long](0, 1000, 100, 1100, 7000, 100, 1000, 100, 100)
val fd = createFailureDetector(threshold = 3, clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn) //0
fd.heartbeat(conn) //1000
fd.heartbeat(conn) //1100
fd.isAvailable(conn) must be(true) //1200
fd.isAvailable(conn) must be(false) //6200
fd.heartbeat(conn) //6300
fd.heartbeat(conn) //7300
fd.heartbeat(conn) //7400
fd.isAvailable(conn) must be(false) //8200
fd.heartbeat(conn) //8300
fd.heartbeat(conn) //9300
fd.heartbeat(conn) //9400
fd.isAvailable(conn) must be(true) //7500
fd.isAvailable(conn) must be(true) //9500
}
"accept some configured missing heartbeats" in {
val timeInterval = List[Long](0, 1000, 1000, 1000, 4000, 1000, 1000)
val fd = createFailureDetector(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.isAvailable(conn) must be(true)
fd.heartbeat(conn)
fd.isAvailable(conn) must be(true)
}
"fail after configured acceptable missing heartbeats" in {
val timeInterval = List[Long](0, 1000, 1000, 1000, 1000, 1000, 500, 500, 5000)
val fd = createFailureDetector(acceptableLostDuration = 3.seconds, clock = fakeTimeGenerator(timeInterval))
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.heartbeat(conn)
fd.isAvailable(conn) must be(true)
fd.heartbeat(conn)
fd.isAvailable(conn) must be(false)
}
"use maxSampleSize heartbeats" in {
val timeInterval = List[Long](0, 100, 100, 100, 100, 600, 1000, 1000, 1000, 1000, 1000)
val fd = new AccrualFailureDetector(system, maxSampleSize = 3,
timeMachine = fakeTimeGenerator(timeInterval))
val fd = createFailureDetector(maxSampleSize = 3, clock = fakeTimeGenerator(timeInterval))
// 100 ms interval
fd.heartbeat(conn) //0
@ -163,4 +227,33 @@ class AccrualFailureDetectorSpec extends AkkaSpec("""
}
}
"Statistics for heartbeats" must {
"calculate correct mean and variance" in {
val samples = Seq(100, 200, 125, 340, 130)
val stats = (HeartbeatHistory(maxSampleSize = 20) /: samples) { (stats, value) stats :+ value }
stats.mean must be(179.0 plusOrMinus 0.00001)
stats.variance must be(7584.0 plusOrMinus 0.00001)
}
"have 0.0 variance for one sample" in {
(HeartbeatHistory(600) :+ 1000L).variance must be(0.0 plusOrMinus 0.00001)
}
"be capped by the specified maxSampleSize" in {
val history3 = HeartbeatHistory(maxSampleSize = 3) :+ 100 :+ 110 :+ 90
history3.mean must be(100.0 plusOrMinus 0.00001)
history3.variance must be(66.6666667 plusOrMinus 0.00001)
val history4 = history3 :+ 140
history4.mean must be(113.333333 plusOrMinus 0.00001)
history4.variance must be(422.222222 plusOrMinus 0.00001)
val history5 = history4 :+ 80
history5.mean must be(103.333333 plusOrMinus 0.00001)
history5.variance must be(688.88888889 plusOrMinus 0.00001)
}
}
}

View file

@ -16,9 +16,11 @@ class ClusterConfigSpec extends AkkaSpec {
"be able to parse generic cluster config elements" in {
val settings = new ClusterSettings(system.settings.config, system.name)
import settings._
FailureDetectorThreshold must be(8)
FailureDetectorThreshold must be(8.0 plusOrMinus 0.0001)
FailureDetectorMaxSampleSize must be(1000)
FailureDetectorImplementationClass must be(None)
FailureDetectorImplementationClass must be(classOf[AccrualFailureDetector].getName)
FailureDetectorMinStdDeviation must be(100 millis)
FailureDetectorAcceptableHeartbeatPause must be(3 seconds)
NodeToJoin must be(None)
PeriodicTasksInitialDelay must be(1 seconds)
GossipInterval must be(1 second)

View file

@ -196,4 +196,4 @@ Licenses for Dependency Libraries
---------------------------------
Each dependency and its license can be seen in the project build file (the comment on the side of each dependency):
`<https://github.com/akka/akka/blob/master/project/build/AkkaProject.scala#L127>`_
`<https://github.com/akka/akka/blob/master/project/AkkaBuild.scala#L497>`_

View file

@ -122,7 +122,8 @@ akka {
# (I) Length in akka.time-unit how long core threads will be kept alive if idling
execution-pool-keepalive = 60s
# (I) Size of the core pool of the remote execution unit
# (I) Size in number of threads of the core pool of the remote execution unit.
# A value of 0 will turn this off, which can lead to deadlocks under some configurations!
execution-pool-size = 4
# (I) Maximum channel size, 0 for off
@ -204,10 +205,10 @@ akka {
# There are three options, in increasing order of security:
# "" or SecureRandom => (default)
# "SHA1PRNG" => Can be slow because of blocking issues on Linux
# "AES128CounterRNGFast" => fastest startup and based on AES encryption algorithm
# "AES128CounterSecureRNG" => fastest startup and based on AES encryption algorithm
# The following use one of 3 possible seed sources, depending on availability: /dev/random, random.org and SecureRandom (provided by Java)
# "AES128CounterRNGSecure"
# "AES256CounterRNGSecure" (Install JCE Unlimited Strength Jurisdiction Policy Files first)
# "AES128CounterInetRNG"
# "AES256CounterInetRNG" (Install JCE Unlimited Strength Jurisdiction Policy Files first)
# Setting a value here may require you to supply the appropriate cipher suite (see enabled-algorithms section above)
random-number-generator = ""
}

View file

@ -106,6 +106,7 @@ case class RemoteServerShutdown(
case class RemoteServerError(
@BeanProperty val cause: Throwable,
@transient @BeanProperty remote: RemoteTransport) extends RemoteServerLifeCycleEvent {
override def logLevel: Logging.LogLevel = Logging.ErrorLevel
override def toString: String = "RemoteServerError@" + remote + "] Error[" + cause + "]"
}

View file

@ -18,6 +18,7 @@ import akka.AkkaException
import akka.event.Logging
import akka.actor.{ DeadLetter, Address, ActorRef }
import akka.util.{ NonFatal, Switch }
import org.jboss.netty.handler.ssl.SslHandler
/**
* This is the abstract baseclass for netty remote clients, currently there's only an
@ -115,15 +116,27 @@ private[akka] class ActiveRemoteClient private[akka] (
*/
def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = {
def sendSecureCookie(connection: ChannelFuture) {
val handshake = RemoteControlProtocol.newBuilder.setCommandType(CommandType.CONNECT)
if (settings.SecureCookie.nonEmpty) handshake.setCookie(settings.SecureCookie.get)
handshake.setOrigin(RemoteProtocol.AddressProtocol.newBuilder
.setSystem(localAddress.system)
.setHostname(localAddress.host.get)
.setPort(localAddress.port.get)
.build)
connection.getChannel.write(netty.createControlEnvelope(handshake.build))
// Returns whether the handshake was written to the channel or not
def sendSecureCookie(connection: ChannelFuture): Boolean = {
val future =
if (!connection.isSuccess || !settings.EnableSSL) connection
else connection.getChannel.getPipeline.get[SslHandler](classOf[SslHandler]).handshake().awaitUninterruptibly()
if (!future.isSuccess) {
notifyListeners(RemoteClientError(future.getCause, netty, remoteAddress))
false
} else {
ChannelAddress.set(connection.getChannel, Some(remoteAddress))
val handshake = RemoteControlProtocol.newBuilder.setCommandType(CommandType.CONNECT)
if (settings.SecureCookie.nonEmpty) handshake.setCookie(settings.SecureCookie.get)
handshake.setOrigin(RemoteProtocol.AddressProtocol.newBuilder
.setSystem(localAddress.system)
.setHostname(localAddress.host.get)
.setPort(localAddress.port.get)
.build)
connection.getChannel.write(netty.createControlEnvelope(handshake.build))
true
}
}
def attemptReconnect(): Boolean = {
@ -131,14 +144,7 @@ private[akka] class ActiveRemoteClient private[akka] (
log.debug("Remote client reconnecting to [{}|{}]", remoteAddress, remoteIP)
connection = bootstrap.connect(new InetSocketAddress(remoteIP, remoteAddress.port.get))
openChannels.add(connection.awaitUninterruptibly.getChannel) // Wait until the connection attempt succeeds or fails.
if (!connection.isSuccess) {
notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
false
} else {
sendSecureCookie(connection)
true
}
sendSecureCookie(connection)
}
runSwitch switchOn {
@ -163,24 +169,19 @@ private[akka] class ActiveRemoteClient private[akka] (
openChannels.add(connection.awaitUninterruptibly.getChannel) // Wait until the connection attempt succeeds or fails.
if (!connection.isSuccess) {
notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
false
} else {
ChannelAddress.set(connection.getChannel, Some(remoteAddress))
sendSecureCookie(connection)
if (sendSecureCookie(connection)) {
notifyListeners(RemoteClientStarted(netty, remoteAddress))
true
} else {
connection.getChannel.close()
openChannels.remove(connection.getChannel)
false
}
} match {
case true true
case false if reconnectIfAlreadyConnected
connection.getChannel.close()
openChannels.remove(connection.getChannel)
log.debug("Remote client reconnecting to [{}]", remoteAddress)
attemptReconnect()
case false false
}
}

View file

@ -24,7 +24,7 @@ import akka.remote.{ RemoteTransportException, RemoteTransport, RemoteActorRefPr
import akka.util.NonFatal
import akka.actor.{ ExtendedActorSystem, Address, ActorRef }
object ChannelAddress extends ChannelLocal[Option[Address]] {
private[akka] object ChannelAddress extends ChannelLocal[Option[Address]] {
override def initialValue(ch: Channel): Option[Address] = None
}
@ -54,9 +54,7 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider
* in implementations of ChannelPipelineFactory.
*/
def apply(handlers: Seq[ChannelHandler]): DefaultChannelPipeline =
handlers.foldLeft(new DefaultChannelPipeline) {
(pipe, handler) pipe.addLast(Logging.simpleName(handler.getClass), handler); pipe
}
(new DefaultChannelPipeline /: handlers) { (p, h) p.addLast(Logging.simpleName(h.getClass), h); p }
/**
* Constructs the NettyRemoteTransport default pipeline with the given head handler, which
@ -65,21 +63,18 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider
* @param withTimeout determines whether an IdleStateHandler shall be included
*/
def apply(endpoint: Seq[ChannelHandler], withTimeout: Boolean, isClient: Boolean): ChannelPipelineFactory =
new ChannelPipelineFactory {
def getPipeline = apply(defaultStack(withTimeout, isClient) ++ endpoint)
}
new ChannelPipelineFactory { override def getPipeline = apply(defaultStack(withTimeout, isClient) ++ endpoint) }
/**
* Construct a default protocol stack, excluding the head handler (i.e. the one which
* actually dispatches the received messages to the local target actors).
*/
def defaultStack(withTimeout: Boolean, isClient: Boolean): Seq[ChannelHandler] =
(if (settings.EnableSSL) NettySSLSupport(settings, NettyRemoteTransport.this.log, isClient) :: Nil else Nil) :::
(if (withTimeout) timeout :: Nil else Nil) :::
(if (settings.EnableSSL) List(NettySSLSupport(settings, NettyRemoteTransport.this.log, isClient)) else Nil) :::
(if (withTimeout) List(timeout) else Nil) :::
msgFormat :::
authenticator :::
executionHandler ::
Nil
executionHandler
/**
* Construct an IdleStateHandler which uses [[akka.remote.netty.NettyRemoteTransport]].timer.
@ -103,20 +98,22 @@ private[akka] class NettyRemoteTransport(_system: ExtendedActorSystem, _provider
* happen on a netty thread (that could be bad if re-sending over the network for
* remote-deployed actors).
*/
val executionHandler = new ExecutionHandler(new OrderedMemoryAwareThreadPoolExecutor(
settings.ExecutionPoolSize,
settings.MaxChannelMemorySize,
settings.MaxTotalMemorySize,
settings.ExecutionPoolKeepalive.length,
settings.ExecutionPoolKeepalive.unit,
system.threadFactory))
val executionHandler = if (settings.ExecutionPoolSize != 0)
List(new ExecutionHandler(new OrderedMemoryAwareThreadPoolExecutor(
settings.ExecutionPoolSize,
settings.MaxChannelMemorySize,
settings.MaxTotalMemorySize,
settings.ExecutionPoolKeepalive.length,
settings.ExecutionPoolKeepalive.unit,
system.threadFactory)))
else Nil
/**
* Construct an authentication handler which uses the SecureCookie to somewhat
* protect the TCP port from unauthorized use (don't rely on it too much, though,
* as this is NOT a cryptographic feature).
*/
def authenticator = if (settings.RequireCookie) new RemoteServerAuthenticationHandler(settings.SecureCookie) :: Nil else Nil
def authenticator = if (settings.RequireCookie) List(new RemoteServerAuthenticationHandler(settings.SecureCookie)) else Nil
}
/**

View file

@ -9,8 +9,8 @@ import javax.net.ssl.{ KeyManagerFactory, TrustManager, TrustManagerFactory, SSL
import akka.remote.RemoteTransportException
import akka.event.LoggingAdapter
import java.io.{ IOException, FileNotFoundException, FileInputStream }
import java.security.{ SecureRandom, GeneralSecurityException, KeyStore, Security }
import akka.security.provider.AkkaProvider
import java.security._
/**
* Used for adding SSL support to Netty pipeline
@ -18,8 +18,7 @@ import akka.security.provider.AkkaProvider
*/
private[akka] object NettySSLSupport {
val akka = new AkkaProvider
Security.addProvider(akka)
Security addProvider AkkaProvider
/**
* Construct a SSLHandler which can be inserted into a Netty server/client pipeline
@ -33,17 +32,20 @@ private[akka] object NettySSLSupport {
* Using /dev/./urandom is only necessary when using SHA1PRNG on Linux
* <quote>Use 'new SecureRandom()' instead of 'SecureRandom.getInstance("SHA1PRNG")'</quote> to avoid having problems
*/
sourceOfRandomness foreach { path System.setProperty("java.security.egd", path) }
sourceOfRandomness foreach { path
System.setProperty("java.security.egd", path)
System.setProperty("securerandom.source", path)
}
val rng = rngName match {
case Some(r @ ("AES128CounterRNGFast" | "AES128CounterRNGSecure" | "AES256CounterRNGSecure"))
case Some(r @ ("AES128CounterSecureRNG" | "AES128CounterInetRNG" | "AES256CounterInetRNG"))
log.debug("SSL random number generator set to: {}", r)
SecureRandom.getInstance(r, akka)
case Some("SHA1PRNG")
log.debug("SSL random number generator set to: SHA1PRNG")
// This needs /dev/urandom to be the source on Linux to prevent problems with /dev/random blocking
SecureRandom.getInstance(r, AkkaProvider)
case Some(s @ ("SHA1PRNG" | "NativePRNG"))
log.debug("SSL random number generator set to: " + s)
// SHA1PRNG needs /dev/urandom to be the source on Linux to prevent problems with /dev/random blocking
// However, this also makes the seed source insecure as the seed is reused to avoid blocking (not a problem on FreeBSD).
SecureRandom.getInstance("SHA1PRNG")
SecureRandom.getInstance(s)
case Some(unknown)
log.debug("Unknown SSLRandomNumberGenerator [{}] falling back to SecureRandom", unknown)
new SecureRandom
@ -60,12 +62,18 @@ private[akka] object NettySSLSupport {
def constructClientContext(settings: NettySettings, log: LoggingAdapter, trustStorePath: String, trustStorePassword: String, protocol: String): Option[SSLContext] =
try {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
trustStore.load(new FileInputStream(trustStorePath), trustStorePassword.toCharArray) //FIXME does the FileInputStream need to be closed?
trustManagerFactory.init(trustStore)
val trustManagers: Array[TrustManager] = trustManagerFactory.getTrustManagers
Option(SSLContext.getInstance(protocol)) map { ctx ctx.init(null, trustManagers, initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log)); ctx }
val rng = initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log)
val trustManagers: Array[TrustManager] = {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
trustManagerFactory.init({
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
val fin = new FileInputStream(trustStorePath)
try trustStore.load(fin, trustStorePassword.toCharArray) finally fin.close()
trustStore
})
trustManagerFactory.getTrustManagers
}
Option(SSLContext.getInstance(protocol)) map { ctx ctx.init(null, trustManagers, rng); ctx }
} catch {
case e: FileNotFoundException throw new RemoteTransportException("Client SSL connection could not be established because trust store could not be loaded", e)
case e: IOException throw new RemoteTransportException("Client SSL connection could not be established because: " + e.getMessage, e)
@ -82,10 +90,12 @@ private[akka] object NettySSLSupport {
}) match {
case Some(context)
log.debug("Using client SSL context to create SSLEngine ...")
val sslEngine = context.createSSLEngine
sslEngine.setUseClientMode(true)
sslEngine.setEnabledCipherSuites(settings.SSLEnabledAlgorithms.toArray.map(_.toString))
new SslHandler(sslEngine)
new SslHandler({
val sslEngine = context.createSSLEngine
sslEngine.setUseClientMode(true)
sslEngine.setEnabledCipherSuites(settings.SSLEnabledAlgorithms.toArray)
sslEngine
})
case None
throw new GeneralSecurityException(
"""Failed to initialize client SSL because SSL context could not be found." +
@ -101,11 +111,15 @@ private[akka] object NettySSLSupport {
def constructServerContext(settings: NettySettings, log: LoggingAdapter, keyStorePath: String, keyStorePassword: String, protocol: String): Option[SSLContext] =
try {
val rng = initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log)
val factory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
keyStore.load(new FileInputStream(keyStorePath), keyStorePassword.toCharArray) //FIXME does the FileInputStream need to be closed?
factory.init(keyStore, keyStorePassword.toCharArray)
Option(SSLContext.getInstance(protocol)) map { ctx ctx.init(factory.getKeyManagers, null, initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, log)); ctx }
factory.init({
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
val fin = new FileInputStream(keyStorePath)
try keyStore.load(fin, keyStorePassword.toCharArray) finally fin.close()
keyStore
}, keyStorePassword.toCharArray)
Option(SSLContext.getInstance(protocol)) map { ctx ctx.init(factory.getKeyManagers, null, rng); ctx }
} catch {
case e: FileNotFoundException throw new RemoteTransportException("Server SSL connection could not be established because key store could not be loaded", e)
case e: IOException throw new RemoteTransportException("Server SSL connection could not be established because: " + e.getMessage, e)
@ -121,7 +135,7 @@ private[akka] object NettySSLSupport {
log.debug("Using server SSL context to create SSLEngine ...")
val sslEngine = context.createSSLEngine
sslEngine.setUseClientMode(false)
sslEngine.setEnabledCipherSuites(settings.SSLEnabledAlgorithms.toArray.map(_.toString))
sslEngine.setEnabledCipherSuites(settings.SSLEnabledAlgorithms.toArray)
new SslHandler(sslEngine)
case None throw new GeneralSecurityException(
"""Failed to initialize server SSL because SSL context could not be found.

View file

@ -8,6 +8,7 @@ import akka.util.Duration
import java.util.concurrent.TimeUnit._
import java.net.InetAddress
import akka.ConfigurationException
import scala.collection.JavaConverters.iterableAsScalaIterableConverter
private[akka] class NettySettings(config: Config, val systemName: String) {
@ -72,7 +73,7 @@ private[akka] class NettySettings(config: Config, val systemName: String) {
val ExecutionPoolKeepalive: Duration = Duration(getMilliseconds("execution-pool-keepalive"), MILLISECONDS)
val ExecutionPoolSize: Int = getInt("execution-pool-size") match {
case sz if sz < 1 throw new IllegalArgumentException("akka.remote.netty.execution-pool-size is less than 1")
case sz if sz < 0 throw new IllegalArgumentException("akka.remote.netty.execution-pool-size is less than 0")
case sz sz
}
@ -106,7 +107,7 @@ private[akka] class NettySettings(config: Config, val systemName: String) {
case password Some(password)
}
val SSLEnabledAlgorithms = getStringList("ssl.enabled-algorithms").toArray.toSet
val SSLEnabledAlgorithms = iterableAsScalaIterableConverter(getStringList("ssl.enabled-algorithms")).asScala.toSet[String]
val SSLProtocol = getString("ssl.protocol") match {
case "" None

View file

@ -7,12 +7,16 @@ import org.uncommons.maths.random.{ AESCounterRNG, DefaultSeedGenerator }
/**
* Internal API
* This class is a wrapper around the 128-bit AESCounterRNG algorithm provided by http://maths.uncommons.org/
* It uses the default seed generator which uses one of the following 3 random seed sources:
* Depending on availability: /dev/random, random.org and SecureRandom (provided by Java)
* The only method used by netty ssl is engineNextBytes(bytes)
*/
class AES128CounterRNGSecure extends java.security.SecureRandomSpi {
class AES128CounterInetRNG extends java.security.SecureRandomSpi {
private val rng = new AESCounterRNG()
/**
* This is managed internally only
* This is managed internally by AESCounterRNG
*/
override protected def engineSetSeed(seed: Array[Byte]): Unit = ()
@ -24,6 +28,7 @@ class AES128CounterRNGSecure extends java.security.SecureRandomSpi {
override protected def engineNextBytes(bytes: Array[Byte]): Unit = rng.nextBytes(bytes)
/**
* Unused method
* Returns the given number of seed bytes. This call may be used to
* seed other random number generators.
*

View file

@ -4,16 +4,18 @@
package akka.security.provider
import org.uncommons.maths.random.{ AESCounterRNG, SecureRandomSeedGenerator }
import java.security.SecureRandom
/**
* Internal API
* This class is a wrapper around the AESCounterRNG algorithm provided by http://maths.uncommons.org/
* The only method used by netty ssl is engineNextBytes(bytes)
* This RNG is good to use to prevent startup delay when you don't have Internet access to random.org
*/
class AES128CounterRNGFast extends java.security.SecureRandomSpi {
class AES128CounterSecureRNG extends java.security.SecureRandomSpi {
private val rng = new AESCounterRNG(new SecureRandomSeedGenerator())
/**
* This is managed internally only
* This is managed internally by AESCounterRNG
*/
override protected def engineSetSeed(seed: Array[Byte]): Unit = ()
@ -25,12 +27,13 @@ class AES128CounterRNGFast extends java.security.SecureRandomSpi {
override protected def engineNextBytes(bytes: Array[Byte]): Unit = rng.nextBytes(bytes)
/**
* Unused method
* Returns the given number of seed bytes. This call may be used to
* seed other random number generators.
*
* @param numBytes the number of seed bytes to generate.
* @return the seed bytes.
*/
override protected def engineGenerateSeed(numBytes: Int): Array[Byte] = (new SecureRandom).generateSeed(numBytes)
override protected def engineGenerateSeed(numBytes: Int): Array[Byte] = (new SecureRandomSeedGenerator()).generateSeed(numBytes)
}

View file

@ -7,12 +7,22 @@ import org.uncommons.maths.random.{ AESCounterRNG, DefaultSeedGenerator }
/**
* Internal API
* This class is a wrapper around the 256-bit AESCounterRNG algorithm provided by http://maths.uncommons.org/
* It uses the default seed generator which uses one of the following 3 random seed sources:
* Depending on availability: /dev/random, random.org and SecureRandom (provided by Java)
* The only method used by netty ssl is engineNextBytes(bytes)
*/
class AES256CounterRNGSecure extends java.security.SecureRandomSpi {
private val rng = new AESCounterRNG(32) // Magic number is magic
class AES256CounterInetRNG extends java.security.SecureRandomSpi {
/**
* From AESCounterRNG API docs:
* Valid values are 16 (128 bits), 24 (192 bits) and 32 (256 bits).
* Any other values will result in an exception from the AES implementation.
*/
private val AES_256_BIT = 32 // Magic number is magic
private val rng = new AESCounterRNG(AES_256_BIT)
/**
* This is managed internally only
* This is managed internally by AESCounterRNG
*/
override protected def engineSetSeed(seed: Array[Byte]): Unit = ()
@ -24,6 +34,7 @@ class AES256CounterRNGSecure extends java.security.SecureRandomSpi {
override protected def engineNextBytes(bytes: Array[Byte]): Unit = rng.nextBytes(bytes)
/**
* Unused method
* Returns the given number of seed bytes. This call may be used to
* seed other random number generators.
*

View file

@ -3,23 +3,23 @@
*/
package akka.security.provider
import java.security.{ PrivilegedAction, AccessController, Provider }
import java.security.{ PrivilegedAction, AccessController, Provider, Security }
/**
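* A provider for the AES counter-based random number generators (AES128CounterSecureRNG, AES128CounterInetRNG and AES256CounterInetRNG), cryptographically secure random number generators made available through SecureRandom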
*/
final class AkkaProvider extends Provider("Akka", 1.0, "Akka provider 1.0 that implements a secure AES random number generator") {
AccessController.doPrivileged(new PrivilegedAction[AkkaProvider] {
object AkkaProvider extends Provider("Akka", 1.0, "Akka provider 1.0 that implements a secure AES random number generator") {
AccessController.doPrivileged(new PrivilegedAction[this.type] {
def run = {
//SecureRandom
put("SecureRandom.AES128CounterRNGFast", "akka.security.provider.AES128CounterRNGFast")
put("SecureRandom.AES128CounterRNGSecure", "akka.security.provider.AES128CounterRNGSecure")
put("SecureRandom.AES256CounterRNGSecure", "akka.security.provider.AES256CounterRNGSecure")
put("SecureRandom.AES128CounterSecureRNG", classOf[AES128CounterSecureRNG].getName)
put("SecureRandom.AES128CounterInetRNG", classOf[AES128CounterInetRNG].getName)
put("SecureRandom.AES256CounterInetRNG", classOf[AES256CounterInetRNG].getName)
//Implementation type: software or hardware
put("SecureRandom.AES128CounterRNGFast ImplementedIn", "Software")
put("SecureRandom.AES128CounterRNGSecure ImplementedIn", "Software")
put("SecureRandom.AES256CounterRNGSecure ImplementedIn", "Software")
put("SecureRandom.AES128CounterSecureRNG ImplementedIn", "Software")
put("SecureRandom.AES128CounterInetRNG ImplementedIn", "Software")
put("SecureRandom.AES256CounterInetRNG ImplementedIn", "Software")
null //Magic null is magic
}
})

View file

@ -9,12 +9,12 @@ import com.typesafe.config._
import akka.dispatch.{ Await, Future }
import akka.pattern.ask
import java.io.File
import akka.event.{ NoLogging, LoggingAdapter }
import java.security.{ NoSuchAlgorithmException, SecureRandom, PrivilegedAction, AccessController }
import netty.{ NettySettings, NettySSLSupport }
import javax.net.ssl.SSLException
import akka.util.{ Timeout, Duration }
import akka.util.duration._
import akka.event.{ Logging, NoLogging, LoggingAdapter }
object Configuration {
// set this in your JAVA_OPTS to see all ssl debug info: "-Djavax.net.debug=ssl,keymanager"
@ -32,6 +32,7 @@ object Configuration {
remote.netty {
hostname = localhost
port = %d
ssl {
enable = on
trust-store = "%s"
@ -41,159 +42,104 @@ object Configuration {
sha1prng-random-source = "/dev/./urandom"
}
}
actor.deployment {
/blub.remote = "akka://remote-sys@localhost:12346"
/looker/child.remote = "akka://remote-sys@localhost:12346"
/looker/child/grandchild.remote = "akka://Ticket1978CommunicationSpec@localhost:12345"
}
}
"""
def getCipherConfig(cipher: String, enabled: String*): (String, Boolean, Config) = try {
case class CipherConfig(runTest: Boolean, config: Config, cipher: String, localPort: Int, remotePort: Int)
if (true) throw new IllegalArgumentException("Ticket1978*Spec isn't enabled")
def getCipherConfig(cipher: String, enabled: String*): CipherConfig = {
val localPort, remotePort = { val s = new java.net.ServerSocket(0); try s.getLocalPort finally s.close() }
try {
//if (true) throw new IllegalArgumentException("Ticket1978*Spec isn't enabled")
val config = ConfigFactory.parseString("akka.remote.netty.port=12345").withFallback(ConfigFactory.parseString(conf.format(trustStore, keyStore, cipher, enabled.mkString(", "))))
val fullConfig = config.withFallback(AkkaSpec.testConf).withFallback(ConfigFactory.load).getConfig("akka.remote.netty")
val settings = new NettySettings(fullConfig, "placeholder")
val config = ConfigFactory.parseString(conf.format(localPort, trustStore, keyStore, cipher, enabled.mkString(", ")))
val fullConfig = config.withFallback(AkkaSpec.testConf).withFallback(ConfigFactory.load).getConfig("akka.remote.netty")
val settings = new NettySettings(fullConfig, "placeholder")
val rng = NettySSLSupport.initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, NoLogging)
val rng = NettySSLSupport.initializeCustomSecureRandom(settings.SSLRandomNumberGenerator, settings.SSLRandomSource, NoLogging)
rng.nextInt() // Has to work
settings.SSLRandomNumberGenerator foreach { sRng rng.getAlgorithm == sRng || (throw new NoSuchAlgorithmException(sRng)) }
rng.nextInt() // Has to work
settings.SSLRandomNumberGenerator foreach { sRng rng.getAlgorithm == sRng || (throw new NoSuchAlgorithmException(sRng)) }
val engine = NettySSLSupport.initializeClientSSL(settings, NoLogging).getEngine
val gotAllSupported = enabled.toSet -- engine.getSupportedCipherSuites.toSet
val gotAllEnabled = enabled.toSet -- engine.getEnabledCipherSuites.toSet
gotAllSupported.isEmpty || (throw new IllegalArgumentException("Cipher Suite not supported: " + gotAllSupported))
gotAllEnabled.isEmpty || (throw new IllegalArgumentException("Cipher Suite not enabled: " + gotAllEnabled))
engine.getSupportedProtocols.contains(settings.SSLProtocol.get) || (throw new IllegalArgumentException("Protocol not supported: " + settings.SSLProtocol.get))
val engine = NettySSLSupport.initializeClientSSL(settings, NoLogging).getEngine
val gotAllSupported = enabled.toSet -- engine.getSupportedCipherSuites.toSet
val gotAllEnabled = enabled.toSet -- engine.getEnabledCipherSuites.toSet
gotAllSupported.isEmpty || (throw new IllegalArgumentException("Cipher Suite not supported: " + gotAllSupported))
gotAllEnabled.isEmpty || (throw new IllegalArgumentException("Cipher Suite not enabled: " + gotAllEnabled))
engine.getSupportedProtocols.contains(settings.SSLProtocol.get) || (throw new IllegalArgumentException("Protocol not supported: " + settings.SSLProtocol.get))
(cipher, true, config)
} catch {
case (_: IllegalArgumentException) | (_: NoSuchAlgorithmException) (cipher, false, AkkaSpec.testConf) // Cannot match against the message since the message might be localized :S
CipherConfig(true, config, cipher, localPort, remotePort)
} catch {
case (_: IllegalArgumentException) | (_: NoSuchAlgorithmException) CipherConfig(false, AkkaSpec.testConf, cipher, localPort, remotePort) // Cannot match against the message since the message might be localized :S
}
}
}
import Configuration.getCipherConfig
import Configuration.{ CipherConfig, getCipherConfig }
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978SHA1PRNGSpec extends Ticket1978CommunicationSpec(getCipherConfig("SHA1PRNG", "TLS_RSA_WITH_AES_128_CBC_SHA"))
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978AES128CounterRNGFastSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES128CounterRNGFast", "TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
class Ticket1978AES128CounterSecureRNGSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES128CounterSecureRNG", "TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
/**
* Both of the <quote>Secure</quote> variants require access to the Internet to access random.org.
* Both of the <quote>Inet</quote> variants require access to the Internet to access random.org.
*/
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978AES128CounterRNGSecureSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES128CounterRNGSecure", "TLS_RSA_WITH_AES_128_CBC_SHA"))
class Ticket1978AES128CounterInetRNGSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES128CounterInetRNG", "TLS_RSA_WITH_AES_128_CBC_SHA"))
/**
* Both of the <quote>Secure</quote> variants require access to the Internet to access random.org.
* Both of the <quote>Inet</quote> variants require access to the Internet to access random.org.
*/
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978AES256CounterRNGSecureSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES256CounterRNGSecure", "TLS_RSA_WITH_AES_256_CBC_SHA"))
class Ticket1978AES256CounterInetRNGSpec extends Ticket1978CommunicationSpec(getCipherConfig("AES256CounterInetRNG", "TLS_RSA_WITH_AES_256_CBC_SHA"))
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978DefaultRNGSecureSpec extends Ticket1978CommunicationSpec(getCipherConfig("", "TLS_RSA_WITH_AES_128_CBC_SHA"))
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978NonExistingRNGSecureSpec extends Ticket1978CommunicationSpec(("NonExistingRNG", false, AkkaSpec.testConf))
class Ticket1978CrappyRSAWithMD5OnlyHereToMakeSureThingsWorkSpec extends Ticket1978CommunicationSpec(getCipherConfig("", "SSL_RSA_WITH_NULL_MD5"))
abstract class Ticket1978CommunicationSpec(val cipherEnabledconfig: (String, Boolean, Config)) extends AkkaSpec(cipherEnabledconfig._3) with ImplicitSender {
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class Ticket1978NonExistingRNGSecureSpec extends Ticket1978CommunicationSpec(CipherConfig(false, AkkaSpec.testConf, "NonExistingRNG", 12345, 12346))
implicit val timeout: Timeout = Timeout(5 seconds)
abstract class Ticket1978CommunicationSpec(val cipherConfig: CipherConfig) extends AkkaSpec(cipherConfig.config) with ImplicitSender {
implicit val timeout: Timeout = Timeout(10 seconds)
import RemoteCommunicationSpec._
val other = ActorSystem("remote-sys", ConfigFactory.parseString("akka.remote.netty.port=12346").withFallback(system.settings.config))
lazy val other: ActorSystem = ActorSystem(
"remote-sys",
ConfigFactory.parseString("akka.remote.netty.port=" + cipherConfig.remotePort).withFallback(system.settings.config))
override def atTermination() {
other.shutdown()
other.awaitTermination()
if (cipherConfig.runTest) {
other.shutdown()
other.awaitTermination()
}
}
"SSL Remoting" must {
if (cipherEnabledconfig._2) {
val remote = other.actorOf(Props(new Actor { def receive = { case "ping" sender ! (("pong", sender)) } }), "echo")
("-") must {
if (cipherConfig.runTest) {
val ignoreMe = other.actorOf(Props(new Actor { def receive = { case ("ping", x) sender ! ((("pong", x), sender)) } }), "echo")
val otherAddress = other.asInstanceOf[ExtendedActorSystem].provider.asInstanceOf[RemoteActorRefProvider].transport.address
val here = system.actorFor("akka://remote-sys@localhost:12346/user/echo")
"support tell" in {
val here = system.actorFor(otherAddress.toString + "/user/echo")
"support remote look-ups" in {
here ! "ping"
expectMsgPF(timeout.duration) {
case ("pong", s: AnyRef) if s eq testActor true
}
}
"send error message for wrong address" ignore {
within(timeout.duration) {
EventFilter.error(start = "dropping", occurrences = 1).intercept {
system.actorFor("akka://remotesys@localhost:12346/user/echo") ! "ping"
}(other)
}
for (i 1 to 1000) here ! (("ping", i))
for (i 1 to 1000) expectMsgPF(timeout.duration) { case (("pong", i), `testActor`) true }
}
"support ask" in {
Await.result(here ? "ping", timeout.duration) match {
case ("pong", s: akka.pattern.PromiseActorRef) // good
case m fail(m + " was not (pong, AskActorRef)")
}
val here = system.actorFor(otherAddress.toString + "/user/echo")
val f = for (i 1 to 1000) yield here ? (("ping", i)) mapTo manifest[((String, Int), ActorRef)]
Await.result(Future.sequence(f), timeout.duration).map(_._1._1).toSet must be(Set("pong"))
}
"send dead letters on remote if actor does not exist" in {
within(timeout.duration) {
EventFilter.warning(pattern = "dead.*buh", occurrences = 1).intercept {
system.actorFor("akka://remote-sys@localhost:12346/does/not/exist") ! "buh"
}(other)
}
}
"create and supervise children on remote node" in {
within(timeout.duration) {
val r = system.actorOf(Props[Echo], "blub")
r.path.toString must be === "akka://remote-sys@localhost:12346/remote/Ticket1978CommunicationSpec@localhost:12345/user/blub"
r ! 42
expectMsg(42)
EventFilter[Exception]("crash", occurrences = 1).intercept {
r ! new Exception("crash")
}(other)
expectMsg("preRestart")
r ! 42
expectMsg(42)
}
}
"look-up actors across node boundaries" in {
within(timeout.duration) {
val l = system.actorOf(Props(new Actor {
def receive = {
case (p: Props, n: String) sender ! context.actorOf(p, n)
case s: String sender ! context.actorFor(s)
}
}), "looker")
l ! (Props[Echo], "child")
val r = expectMsgType[ActorRef]
r ! (Props[Echo], "grandchild")
val remref = expectMsgType[ActorRef]
remref.isInstanceOf[LocalActorRef] must be(true)
val myref = system.actorFor(system / "looker" / "child" / "grandchild")
myref.isInstanceOf[RemoteActorRef] must be(true)
myref ! 43
expectMsg(43)
lastSender must be theSameInstanceAs remref
r.asInstanceOf[RemoteActorRef].getParent must be(l)
system.actorFor("/user/looker/child") must be theSameInstanceAs r
Await.result(l ? "child/..", timeout.duration).asInstanceOf[AnyRef] must be theSameInstanceAs l
Await.result(system.actorFor(system / "looker" / "child") ? "..", timeout.duration).asInstanceOf[AnyRef] must be theSameInstanceAs l
}
}
"not fail ask across node boundaries" in {
val f = for (_ 1 to 1000) yield here ? "ping" mapTo manifest[(String, ActorRef)]
Await.result(Future.sequence(f), timeout.duration).map(_._1).toSet must be(Set("pong"))
}
} else {
"not be run when the cipher is not supported by the platform this test is currently being executed on" ignore {

View file

@ -15,12 +15,7 @@ akka {
actor.provider = "akka.remote.RemoteActorRefProvider"
remote.netty {
hostname = localhost
port = 12345
}
actor.deployment {
/blub.remote = "akka://remote-sys@localhost:12346"
/looker/child.remote = "akka://remote-sys@localhost:12346"
/looker/child/grandchild.remote = "akka://RemoteCommunicationSpec@localhost:12345"
port = 0
}
}
""") with ImplicitSender with DefaultTimeout {