diff --git a/akka-actor/src/main/scala/akka/util/CircuitBreaker.scala b/akka-actor/src/main/scala/akka/util/CircuitBreaker.scala index 64d353eac2..132b676b9a 100644 --- a/akka-actor/src/main/scala/akka/util/CircuitBreaker.scala +++ b/akka-actor/src/main/scala/akka/util/CircuitBreaker.scala @@ -4,200 +4,233 @@ package akka.util -// ==================================================================== // -// Based on code by Christopher Schmidt released under Apache 2 license // -// ==================================================================== // - -import scala.collection.immutable.HashMap import System._ -import java.util.concurrent.atomic.{ AtomicLong, AtomicReference, AtomicInteger } +import scala.annotation.tailrec + +import java.util.concurrent.atomic.AtomicReference /** - * Holder companion object for creating and retrieving all configured CircuitBreaker (CircuitBreaker) instances. - * (Enhancements could be to put some clever ThreadLocal stuff in here). *
- *   try withCircuitBreaker(circuitBreakerName) {
- *     // do something dangerous here
- *   } catch {
- *     case e: CircuitBreakerOpenException     ⇒ // handle...
- *     case e: CircuitBreakerHalfOpenException ⇒ // handle...
- *     case e: Exception                       ⇒ // handle...
+ *   // Create the CircuitBreaker
+ *   val circuitBreaker =
+ *     CircuitBreaker(CircuitBreaker.Config(60.seconds, 5))
+ *
+ *   // Configure the CircuitBreaker actions
+ *   circuitBreaker
+ *     onOpen {
+ *       ...
+ *     } onClose {
+ *       ...
+ *     } onHalfOpen {
+ *       ...
+ *     }
+ *
+ *   circuitBreaker {
+ *     ...
  *   }
  * 
*/ object CircuitBreaker { + case class Config(timeout: Duration, failureThreshold: Int) - private var circuitBreaker = Map.empty[String, CircuitBreaker] + private[akka] def apply(implicit config: Config): CircuitBreaker = new CircuitBreaker(config) + +} + +class CircuitBreaker private (val config: CircuitBreaker.Config) { + import CircuitBreaker._ + + private object InternalState { + def apply(circuitBreaker: CircuitBreaker): InternalState = + InternalState( + 0L, + Closed(circuitBreaker), + circuitBreaker.config.timeout.toMillis, + circuitBreaker.config.failureThreshold, + 0L, + 0) + } /** - * Factory mathod that creates a new CircuitBreaker with a given name and configuration. - * - * @param name name or id of the new CircuitBreaker - * @param config CircuitBreakerConfiguration to configure the new CircuitBreaker + * Represents the internal state of the CircuitBreaker. */ - def addCircuitBreaker(name: String, config: CircuitBreakerConfiguration): Unit = { - circuitBreaker.get(name) match { - case None ⇒ circuitBreaker += ((name, new CircuitBreakerImpl(config))) - case Some(x) ⇒ throw new IllegalArgumentException("CircuitBreaker [" + name + "] already configured") + private case class InternalState( + version: Long, + state: CircuitBreakerState, + timeout: Long, + failureThreshold: Int, + tripTime: Long, + failureCount: Int, + onOpenListeners: List[() ⇒ Unit] = Nil, + onCloseListeners: List[() ⇒ Unit] = Nil, + onHalfOpenListeners: List[() ⇒ Unit] = Nil) + + private[akka] trait CircuitBreakerState { + val circuitBreaker: CircuitBreaker + + def onError(e: Throwable) + + def preInvoke() + + def postInvoke() + } + + /** + * CircuitBreaker is CLOSED, normal operation. + */ + private[akka] case class Closed(circuitBreaker: CircuitBreaker) extends CircuitBreakerState { + + def onError(e: Throwable) = { + circuitBreaker.incrementFailureCount() + val currentCount = circuitBreaker.failureCount + val threshold = circuitBreaker.failureThreshold + if (currentCount >= threshold) circuitBreaker.trip() } + def preInvoke() {} + + def postInvoke() { circuitBreaker.resetFailureCount() } } - def hasCircuitBreaker(name: String) = circuitBreaker.contains(name) - /** - * CircuitBreaker retrieve method. - * - * @param name String name or id of the CircuitBreaker - * @return CircuitBreaker with name or id name + * CircuitBreaker is OPEN. Calls are failing fast. */ - private[akka] def apply(name: String): CircuitBreaker = { - circuitBreaker.get(name) match { - case Some(x) ⇒ x - case None ⇒ throw new IllegalArgumentException("CircuitBreaker [" + name + "] not configured") + private[akka] case class Open(circuitBreaker: CircuitBreaker) extends CircuitBreakerState { + + def onError(e: Throwable) {} + + def preInvoke() { + val now = currentTimeMillis + val elapsed = now - circuitBreaker.tripTime + if (elapsed <= circuitBreaker.timeout) + circuitBreaker.notifyOpen() + circuitBreaker.attemptReset() } - } -} -/** - * Basic Mixin for using CircuitBreaker Scope method. - */ -trait UsingCircuitBreaker { - def withCircuitBreaker[T](name: String)(f: ⇒ T): T = { - CircuitBreaker(name).invoke(f) - } -} - -/** - * @param timeout timout for trying again - * @param failureThreshold threshold of errors till breaker will open - */ -case class CircuitBreakerConfiguration(timeout: Long, failureThreshold: Int) - -/** - * Interface definition for CircuitBreaker - */ -private[akka] trait CircuitBreaker { - - /** - * increments and gets the actual failure count - * - * @return Int failure count - */ - var failureCount: Int - - /** - * @return Long milliseconds at trip - */ - var tripTime: Long - - /** - * function that has to be applied in CircuitBreaker scope - */ - def invoke[T](f: ⇒ T): T - - /** - * trip CircuitBreaker, store trip time - */ - def trip() - - /** - * sets failure count to 0 - */ - def resetFailureCount() - - /** - * set state to Half Open - */ - def attemptReset() - - /** - * reset CircuitBreaker to configured defaults - */ - def reset() - - /** - * @return Int configured failure threshold - */ - def failureThreshold: Int - - /** - * @return Long configured timeout - */ - def timeout: Long -} - -/** - * CircuitBreaker base class for all configuration things - * holds all thread safe (atomic) private members - */ -private[akka] abstract class CircuitBreakerBase(config: CircuitBreakerConfiguration) extends CircuitBreaker { - private var _state = new AtomicReference[States] - private var _failureThreshold = new AtomicInteger(config.failureThreshold) - private var _timeout = new AtomicLong(config.timeout) - private var _failureCount = new AtomicInteger(0) - private var _tripTime = new AtomicLong - - protected def state_=(s: States) { - _state.set(s) + def postInvoke() {} } - protected def state = _state.get + /** + * CircuitBreaker is HALF OPEN. Calls are still failing after timeout. + */ + private[akka] case class HalfOpen(circuitBreaker: CircuitBreaker) extends CircuitBreakerState { - def failureThreshold = _failureThreshold get + def onError(e: Throwable) { + circuitBreaker.trip() + circuitBreaker.notifyHalfOpen() + } - def timeout = _timeout get + def preInvoke() {} - def failureCount_=(i: Int) { - _failureCount.set(i) + def postInvoke() { circuitBreaker.reset() } } - def failureCount = _failureCount.incrementAndGet + private val ref = new AtomicReference(InternalState(this)) - def tripTime_=(l: Long) { - _tripTime.set(l) + def timeout = ref.get.timeout + + def failureThreshold = ref.get.failureThreshold + + def failureCount = ref.get.failureCount + + def tripTime = ref.get.tripTime + + @tailrec + final def incrementFailureCount() { + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + failureCount = oldState.failureCount + 1) + if (!ref.compareAndSet(oldState, newState)) incrementFailureCount() } - def tripTime = _tripTime.get - -} - -/** - * CircuitBreaker implementation class for changing states. - */ -private[akka] class CircuitBreakerImpl(config: CircuitBreakerConfiguration) extends CircuitBreakerBase(config) { - reset() - - def reset() { - resetFailureCount - state = new ClosedState(this) + @tailrec + final def reset() { + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + failureCount = 0, + state = Closed(this)) + if (!ref.compareAndSet(oldState, newState)) reset() } - def resetFailureCount() { - failureCount = 0 + @tailrec + final def resetFailureCount() { + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + failureCount = 0) + if (!ref.compareAndSet(oldState, newState)) resetFailureCount() } - def attemptReset() { - state = new HalfOpenState(this) + @tailrec + final def attemptReset() { + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + state = HalfOpen(this)) + if (!ref.compareAndSet(oldState, newState)) attemptReset() } - def trip() { - tripTime = currentTimeMillis - state = new OpenState(this) + @tailrec + final def trip() { + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + state = Open(this), + tripTime = currentTimeMillis) + if (!ref.compareAndSet(oldState, newState)) trip() } - def invoke[T](f: ⇒ T): T = { - state.preInvoke + def apply[T](body: ⇒ T): T = { + val oldState = ref.get + oldState.state.preInvoke() try { - val ret = f - state.postInvoke + val ret = body + oldState.state.postInvoke() ret } catch { - case e: Throwable ⇒ { - state.onError(e) + case e: Throwable ⇒ + oldState.state.onError(e) throw e - } } } -} \ No newline at end of file + + @tailrec + final def onClose(body: ⇒ Unit): CircuitBreaker = { + val f = () ⇒ body + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + onCloseListeners = f :: oldState.onCloseListeners) + if (!ref.compareAndSet(oldState, newState)) onClose(f) + else this + } + + @tailrec + final def onOpen(body: ⇒ Unit): CircuitBreaker = { + val f = () ⇒ body + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + onOpenListeners = f :: oldState.onOpenListeners) + if (!ref.compareAndSet(oldState, newState)) onOpen(() ⇒ f) + else this + } + + @tailrec + final def onHalfOpen(body: ⇒ Unit): CircuitBreaker = { + val f = () ⇒ body + val oldState = ref.get + val newState = oldState copy (version = oldState.version + 1, + onHalfOpenListeners = f :: oldState.onHalfOpenListeners) + if (!ref.compareAndSet(oldState, newState)) onHalfOpen(() ⇒ f) + else this + } + + def notifyOpen() { + ref.get.onOpenListeners foreach (f ⇒ f()) + } + + def notifyHalfOpen() { + ref.get.onHalfOpenListeners foreach (f ⇒ f()) + } + + def notifyClosed() { + ref.get.onCloseListeners foreach (f ⇒ f()) + } +}