diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala index f33047e4f9..af2eb376b8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala @@ -3,14 +3,12 @@ */ package akka.stream.impl -import scala.collection.immutable -import scala.util.{ Failure, Success } -import scala.util.control.NonFatal import org.reactivestreams.api.Processor import org.reactivestreams.spi.Subscriber -import akka.actor.{ Actor, ActorLogging, ActorRef, Props } +import akka.actor._ import akka.stream.MaterializerSettings import akka.event.LoggingReceive +import akka.stream.impl._ /** * INTERNAL API @@ -28,199 +26,171 @@ private[akka] object ActorProcessor { } } -class ActorProcessor[I, O]( final val impl: ActorRef) extends Processor[I, O] with ActorConsumerLike[I] with ActorProducerLike[O] +/** + * INTERNAL API + */ +private[akka] class ActorProcessor[I, O]( final val impl: ActorRef) extends Processor[I, O] with ActorConsumerLike[I] with ActorProducerLike[O] /** * INTERNAL API */ -private[akka] abstract class ActorProcessorImpl(val settings: MaterializerSettings) - extends Actor - with SubscriberManagement[Any] - with ActorLogging - with SoftShutdown { - - import ActorBasedFlowMaterializer._ - - type S = ActorSubscription[Any] - - override def maxBufferSize: Int = settings.maxFanOutBufferSize - override def initialBufferSize: Int = settings.initialFanOutBufferSize - override def createSubscription(subscriber: Subscriber[Any]): S = new ActorSubscription(self, subscriber) - - override def receive = waitingExposedPublisher - +private[akka] trait PrimaryInputs { + this: Actor ⇒ + // FIXME: have a NoInputs here to avoid nulls protected var primaryInputs: Inputs = _ - ////////////////////// Startup phases ////////////////////// + def initialInputBufferSize: Int + def maximumInputBufferSize: Int - var exposedPublisher: ActorPublisher[Any] = _ - - def waitingExposedPublisher: Receive = { - case ExposedPublisher(publisher) ⇒ - exposedPublisher = publisher - publisherExposed() - context.become(waitingForUpstream) - case _ ⇒ throw new IllegalStateException("The first message must be ExposedPublisher") - } - - // WARNING: DO NOT SEND messages from the constructor (that includes subscribing to other streams) since their reply - // might arrive earlier than ExposedPublisher. Override this method to schedule such events. - protected def publisherExposed(): Unit = () - - def waitingForUpstream: Receive = downstreamManagement orElse { + def waitingForUpstream: Receive = { case OnComplete ⇒ // Instead of introducing an edge case, handle it in the general way primaryInputs = EmptyInputs transitionToRunningWhenReady() case OnSubscribe(subscription) ⇒ assert(subscription != null) - primaryInputs = new BatchingInputBuffer(subscription, settings.initialInputBufferSize) + primaryInputs = new BatchingInputBuffer(subscription, initialInputBufferSize) transitionToRunningWhenReady() - case OnError(cause) ⇒ failureReceived(cause) + case OnError(cause) ⇒ primaryInputOnError(cause) } def transitionToRunningWhenReady(): Unit = if (primaryInputs ne null) { primaryInputs.prefetch() - transferState = initialTransferState - context.become(running) + primaryInputsReady() } - ////////////////////// Management of subscribers ////////////////////// + def upstreamManagement: Receive = { + case OnNext(element) ⇒ + primaryInputs.enqueueInputElement(element) + pumpInputs() + case OnComplete ⇒ + primaryInputs.complete() + primaryInputOnComplete() + pumpInputs() + case OnError(cause) ⇒ primaryInputOnError(cause) + } + + def pumpInputs(): Unit + def primaryInputsReady(): Unit + def primaryInputOnError(cause: Throwable): Unit + def primaryInputOnComplete(): Unit +} + +/** + * INTERNAL API + */ +private[akka] trait PrimaryOutputs { + this: Actor ⇒ + protected var exposedPublisher: ActorPublisher[Any] = _ + + def initialFanOutBufferSize: Int + def maxFanOutBufferSize: Int + + object PrimaryOutputs extends FanoutOutputs(maxFanOutBufferSize, initialFanOutBufferSize) { + override type S = ActorSubscription[Any] + override def createSubscription(subscriber: Subscriber[Any]): ActorSubscription[Any] = + new ActorSubscription(self, subscriber) + override def afterShutdown(completed: Boolean): Unit = primaryOutputsFinished(completed) + } + + def waitingExposedPublisher: Receive = { + case ExposedPublisher(publisher) ⇒ + exposedPublisher = publisher + primaryOutputsReady() + case _ ⇒ throw new IllegalStateException("The first message must be ExposedPublisher") + } - // All methods called here are implemented by SubscriberManagement def downstreamManagement: Receive = { case SubscribePending ⇒ subscribePending() case RequestMore(subscription, elements) ⇒ - moreRequested(subscription.asInstanceOf[S], elements) - pump() + PrimaryOutputs.handleRequest(subscription.asInstanceOf[ActorSubscription[Any]], elements) + pumpOutputs() case Cancel(subscription) ⇒ - unregisterSubscription(subscription.asInstanceOf[S]) - pump() + PrimaryOutputs.removeSubscription(subscription.asInstanceOf[ActorSubscription[Any]]) + pumpOutputs() } private def subscribePending(): Unit = - exposedPublisher.takePendingSubscribers() foreach registerSubscriber + exposedPublisher.takePendingSubscribers() foreach PrimaryOutputs.addSubscriber - ////////////////////// Active state ////////////////////// + def primaryOutputsFinished(completed: Boolean): Unit + def primaryOutputsReady(): Unit - def running: Receive = LoggingReceive(downstreamManagement orElse { - case OnNext(element) ⇒ - primaryInputs.enqueueInputElement(element) - pump() - case OnComplete ⇒ - primaryInputs.complete() - flushAndComplete() - pump() - case OnError(cause) ⇒ failureReceived(cause) - }) + def pumpOutputs(): Unit - // Called by SubscriberManagement when all subscribers are gone. - // The method shutdown() is called automatically by SubscriberManagement after it called this method. - override def cancelUpstream(): Unit = { - if (primaryInputs ne null) primaryInputs.cancel() - PrimaryOutputs.cancel() +} + +/** + * INTERNAL API + */ +private[akka] abstract class ActorProcessorImpl(val settings: MaterializerSettings) + extends Actor + with ActorLogging + with SoftShutdown + with PrimaryInputs + with PrimaryOutputs + with Pump { + + val initialInputBufferSize: Int = settings.initialInputBufferSize + val maximumInputBufferSize: Int = settings.maximumInputBufferSize + val initialFanOutBufferSize: Int = settings.initialFanOutBufferSize + val maxFanOutBufferSize: Int = settings.maxFanOutBufferSize + + override def receive = waitingExposedPublisher + + override def primaryInputOnError(e: Throwable): Unit = fail(e) + override def primaryInputOnComplete(): Unit = context.become(flushing) + override def primaryInputsReady(): Unit = { + setTransferState(initialTransferState) + context.become(running) } - // Called by SubscriberManagement whenever the output buffer is ready to accept additional elements - override protected def requestFromUpstream(elements: Int): Unit = { - // FIXME: Remove debug logging - log.debug(s"received downstream demand from buffer: $elements") - PrimaryOutputs.enqueueOutputDemand(elements) + override def primaryOutputsReady(): Unit = context.become(downstreamManagement orElse waitingForUpstream) + override def primaryOutputsFinished(completed: Boolean): Unit = { + isShuttingDown = true + if (completed) + shutdownReason = None + shutdown() } - def failureReceived(e: Throwable): Unit = fail(e) - - def fail(e: Throwable): Unit = { - shutdownReason = Some(e) - log.error(e, "failure during processing") // FIXME: escalate to supervisor instead - abortDownstream(e) - if (primaryInputs ne null) primaryInputs.cancel() - exposedPublisher.shutdown(shutdownReason) - softShutdown() - } - - object PrimaryOutputs extends Outputs { - private var downstreamBufferSpace = 0 - private var downstreamCompleted = false - def demandAvailable = downstreamBufferSpace > 0 - - def enqueueOutputDemand(demand: Int): Unit = downstreamBufferSpace += demand - def enqueueOutputElement(elem: Any): Unit = { - downstreamBufferSpace -= 1 - pushToDownstream(elem) - } - - def complete(): Unit = downstreamCompleted = true - def cancel(): Unit = downstreamCompleted = true - def isClosed: Boolean = downstreamCompleted - override val NeedsDemand: TransferState = new TransferState { - def isReady = demandAvailable - def isCompleted = downstreamCompleted - } - override def NeedsDemandOrCancel: TransferState = new TransferState { - def isReady = demandAvailable || downstreamCompleted - def isCompleted = false - } - } - - lazy val needsPrimaryInputAndDemand = primaryInputs.NeedsInput && PrimaryOutputs.NeedsDemand - - private var transferState: TransferState = NotInitialized - protected def setTransferState(t: TransferState): Unit = transferState = t - protected def initialTransferState: TransferState - - // Exchange input buffer elements and output buffer "requests" until one of them becomes empty. - // Generate upstream requestMore for every Nth consumed input element - final protected def pump(): Unit = { - try while (transferState.isExecutable) { - // FIXME: Remove debug logging - log.debug(s"iterating the pump with state $transferState and buffer $bufferDebug") - transferState = withCtx(context)(transfer()) - } catch { case NonFatal(e) ⇒ fail(e) } - - // FIXME: Remove debug logging - log.debug(s"finished iterating the pump with state $transferState and buffer $bufferDebug") - - if (transferState.isCompleted) { - if (!isShuttingDown) { - // FIXME: Remove debug logging - log.debug("shutting down the pump") - if (primaryInputs.isOpen) primaryInputs.cancel() - primaryInputs.clear() - context.become(flushing) - isShuttingDown = true - } - completeDownstream() - } - } - - // Needs to be implemented by Processor implementations. Transfers elements from the input buffer to the output - // buffer. - protected def transfer(): TransferState - - ////////////////////// Completing and Flushing ////////////////////// - - protected def flushAndComplete(): Unit = context.become(flushing) + def running: Receive = LoggingReceive(downstreamManagement orElse upstreamManagement) def flushing: Receive = downstreamManagement orElse { case OnSubscribe(subscription) ⇒ throw new IllegalStateException("Cannot subscribe shutdown subscriber") case _ ⇒ // ignore everything else } + protected def fail(e: Throwable): Unit = { + shutdownReason = Some(e) + log.error(e, "failure during processing") // FIXME: escalate to supervisor instead + PrimaryOutputs.cancel(e) + shutdown() + } + + lazy val needsPrimaryInputAndDemand = primaryInputs.NeedsInput && PrimaryOutputs.NeedsDemand + + protected def initialTransferState: TransferState + + override val pumpContext = context + override def pumpInputs(): Unit = pump() + override def pumpOutputs(): Unit = pump() + + override def pumpFinished(): Unit = { + if (primaryInputs.isOpen) primaryInputs.cancel() + context.become(flushing) + PrimaryOutputs.complete() + } + override def pumpFailed(e: Throwable): Unit = fail(e) + ////////////////////// Shutdown and cleanup (graceful and abort) ////////////////////// var isShuttingDown = false - var shutdownReason: Option[Throwable] = ActorPublisher.NormalShutdownReason - // Called by SubscriberManagement to signal that output buffer finished (flushed or aborted) - override def shutdown(completed: Boolean): Unit = { - isShuttingDown = true - if (completed) - shutdownReason = None - PrimaryOutputs.complete() + def shutdown(): Unit = { + if (primaryInputs ne null) primaryInputs.cancel() exposedPublisher.shutdown(shutdownReason) softShutdown() } @@ -230,10 +200,7 @@ private[akka] abstract class ActorProcessorImpl(val settings: MaterializerSettin exposedPublisher.shutdown(shutdownReason) // Non-gracefully stopped, do our best here if (!isShuttingDown) - abortDownstream(new IllegalStateException("Processor actor terminated abruptly")) - - // FIXME what about upstream subscription before we got - // case OnSubscribe(subscription) ⇒ subscription.cancel() + PrimaryOutputs.cancel(new IllegalStateException("Processor actor terminated abruptly")) } override def preRestart(reason: Throwable, message: Option[Any]): Unit = { diff --git a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala index 302e725976..a3965a7f0a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala @@ -26,9 +26,7 @@ private[akka] class GroupByProcessorImpl(settings: MaterializerSettings, val key import GroupByProcessorImpl._ var keyToSubstreamOutputs = collection.mutable.Map.empty[Any, SubstreamOutputs] - var substreamPendingState: SubstreamElementState = NoPending - def substreamsFinished: Boolean = keyToSubstreamOutputs.isEmpty override def initialTransferState = needsPrimaryInputAndDemand @@ -40,7 +38,7 @@ private[akka] class GroupByProcessorImpl(settings: MaterializerSettings, val key // Just drop, we do not open any more substreams } else { val substreamOutput = newSubstream() - pushToDownstream((key, substreamOutput.processor)) + PrimaryOutputs.enqueueOutputElement((key, substreamOutput.processor)) keyToSubstreamOutputs(key) = substreamOutput substreamPendingState = PendingElement(elem, key) } @@ -53,7 +51,7 @@ private[akka] class GroupByProcessorImpl(settings: MaterializerSettings, val key val elem = primaryInputs.dequeueInputElement() val key = keyFor(elem) if (keyToSubstreamOutputs.contains(key)) { - substreamPendingState = PendingElement(elem, key) + substreamPendingState = if (keyToSubstreamOutputs(key).isOpen) PendingElement(elem, key) else NoPending } else if (PrimaryOutputs.isOpen) { substreamPendingState = PendingElementForNewStream(elem, key) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala index 6ae8783c9e..746ce1f85b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala @@ -75,10 +75,10 @@ private[akka] class RecoverProcessorImpl(_settings: MaterializerSettings, _op: A override def running: Receive = wrapInSuccess orElse super.running - override def failureReceived(e: Throwable): Unit = { + override def primaryInputOnError(e: Throwable): Unit = { primaryInputs.enqueueInputElement(Failure(e)) primaryInputs.complete() - flushAndComplete() + context.become(flushing) pump() } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala index ce92e038df..4aa5ba5ca0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala @@ -51,7 +51,7 @@ private[akka] class SplitWhenProcessorImpl(_settings: MaterializerSettings, val pendingElement = NoPending case PendingElementForNewStream(elem) ⇒ val substreamOutput = newSubstream() - pushToDownstream(substreamOutput.processor) + PrimaryOutputs.enqueueOutputElement(substreamOutput.processor) currentSubstream = substreamOutput pendingElement = PendingElement(elem) } @@ -65,7 +65,7 @@ private[akka] class SplitWhenProcessorImpl(_settings: MaterializerSettings, val override def invalidateSubstream(substream: ActorRef): Unit = { pendingElement match { - case PendingElement(_) ⇒ + case PendingElement(_) if substream == currentSubstream.substream ⇒ setTransferState(primaryInputs.NeedsInput) pendingElement = NoPending case _ ⇒ diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala index 7af5d00690..1100af4896 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamOfStreamProcessors.scala @@ -46,7 +46,10 @@ private[akka] abstract class MultiStreamOutputProcessor(_settings: MaterializerS completed = true } - override def cancel(): Unit = completed = true + override def cancel(e: Throwable): Unit = { + if (!completed) substream ! OnError(e) + completed = true + } override def enqueueOutputElement(elem: Any): Unit = { demands -= 1 @@ -72,28 +75,30 @@ private[akka] abstract class MultiStreamOutputProcessor(_settings: MaterializerS outputs } + def fullyCompleted: Boolean = isShuttingDown && isPumpFinished && context.children.isEmpty + protected def invalidateSubstream(substream: ActorRef): Unit = { substreamOutputs(substream).complete() substreamOutputs -= substream - if ((isShuttingDown || PrimaryOutputs.isClosed) && context.children.isEmpty) context.stop(self) + if (fullyCompleted) shutdown() pump() } override def fail(e: Throwable): Unit = { - context.children foreach (_ ! OnError(e)) + substreamOutputs.values foreach (_.cancel(e)) super.fail(e) } - override def shutdown(completed: Boolean): Unit = { + override def primaryOutputsFinished(completed: Boolean): Unit = { // If the master stream is cancelled (no one consumes substreams as elements from the master stream) // then this callback does not mean we are shutting down // We can only shut down after all substreams (our children) are closed - if (context.children.isEmpty) super.shutdown(completed) + if (fullyCompleted) shutdown() } - override def completeDownstream(): Unit = { + override def pumpFinished(): Unit = { context.children foreach (_ ! OnComplete) - super.completeDownstream() + super.pumpFinished() } override val downstreamManagement: Receive = super.downstreamManagement orElse { @@ -131,8 +136,10 @@ private[akka] abstract class TwoStreamInputProcessor(_settings: MaterializerSett var secondaryInputs: Inputs = _ - override def publisherExposed(): Unit = + override def primaryOutputsReady(): Unit = { other.getPublisher.subscribe(new OtherActorSubscriber(self)) + super.primaryOutputsReady() + } override def waitingForUpstream: Receive = super.waitingForUpstream orElse { case OtherStreamOnComplete ⇒ @@ -150,13 +157,13 @@ private[akka] abstract class TwoStreamInputProcessor(_settings: MaterializerSett pump() case OtherStreamOnComplete ⇒ secondaryInputs.complete() - flushAndComplete() + primaryInputOnComplete() pump() } - override def flushAndComplete(): Unit = { + override def primaryInputOnComplete(): Unit = { if (secondaryInputs.isClosed && primaryInputs.isClosed) - super.flushAndComplete() + super.primaryInputOnComplete() } override def transitionToRunningWhenReady(): Unit = if ((primaryInputs ne null) && (secondaryInputs ne null)) { @@ -169,9 +176,9 @@ private[akka] abstract class TwoStreamInputProcessor(_settings: MaterializerSett super.fail(cause) } - override def cancelUpstream(): Unit = { + override def primaryOutputsFinished(completed: Boolean) { if (secondaryInputs ne null) secondaryInputs.cancel() - super.cancelUpstream() + super.primaryOutputsFinished(completed) } } \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala index 0c089b4017..2cf0ea3f30 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala @@ -3,10 +3,15 @@ */ package akka.stream.impl -import org.reactivestreams.spi.Subscription +import org.reactivestreams.spi.{ Subscriber, Subscription } import java.util.Arrays +import scala.util.control.NonFatal +import akka.actor.ActorRefFactory -trait Inputs { +/** + * INTERNAL API + */ +private[akka] trait Inputs { def NeedsInput: TransferState def NeedsInputOrComplete: TransferState @@ -19,13 +24,29 @@ trait Inputs { def isOpen: Boolean = !isClosed def prefetch(): Unit - def clear(): Unit def inputsDepleted: Boolean def inputsAvailable: Boolean } -trait Outputs { +/** + * INTERNAL API + */ +private[akka] trait DefaultInputTransferStates extends Inputs { + override val NeedsInput: TransferState = new TransferState { + def isReady = inputsAvailable + def isCompleted = inputsDepleted + } + override val NeedsInputOrComplete: TransferState = new TransferState { + def isReady = inputsAvailable || inputsDepleted + def isCompleted = false + } +} + +/** + * INTERNAL API + */ +private[akka] trait Outputs { def NeedsDemand: TransferState def NeedsDemandOrCancel: TransferState @@ -33,13 +54,30 @@ trait Outputs { def enqueueOutputElement(elem: Any): Unit def complete(): Unit - def cancel(): Unit + def cancel(e: Throwable): Unit def isClosed: Boolean def isOpen: Boolean = !isClosed } +/** + * INTERNAL API + */ +private[akka] trait DefaultOutputTransferStates extends Outputs { + override val NeedsDemand: TransferState = new TransferState { + def isReady = demandAvailable + def isCompleted = isClosed + } + override def NeedsDemandOrCancel: TransferState = new TransferState { + def isReady = demandAvailable || isClosed + def isCompleted = false + } +} + // States of the operation that is executed by this processor -trait TransferState { +/** + * INTERNAL API + */ +private[akka] trait TransferState { def isReady: Boolean def isCompleted: Boolean def isExecutable = isReady && !isCompleted @@ -55,17 +93,26 @@ trait TransferState { } } -object Completed extends TransferState { +/** + * INTERNAL API + */ +private[akka] object Completed extends TransferState { def isReady = false def isCompleted = true } -object NotInitialized extends TransferState { +/** + * INTERNAL API + */ +private[akka] object NotInitialized extends TransferState { def isReady = false def isCompleted = false } -object EmptyInputs extends Inputs { +/** + * INTERNAL API + */ +private[akka] object EmptyInputs extends Inputs { override def inputsAvailable: Boolean = false override def inputsDepleted: Boolean = true override def isClosed: Boolean = true @@ -73,7 +120,6 @@ object EmptyInputs extends Inputs { override def complete(): Unit = () override def cancel(): Unit = () override def prefetch(): Unit = () - override def clear(): Unit = () override def dequeueInputElement(): Any = throw new UnsupportedOperationException("Cannot dequeue from EmptyInputs") override def enqueueInputElement(elem: Any): Unit = throw new UnsupportedOperationException("Cannot enqueue to EmptyInputs") @@ -85,7 +131,38 @@ object EmptyInputs extends Inputs { override val NeedsInput: TransferState = Completed } -class BatchingInputBuffer(val upstream: Subscription, val size: Int) extends Inputs { +/** + * INTERNAL API + */ +private[akka] trait Pump { + protected def pumpContext: ActorRefFactory + private var transferState: TransferState = NotInitialized + def setTransferState(t: TransferState): Unit = transferState = t + + def isPumpFinished: Boolean = transferState.isCompleted + + // Exchange input buffer elements and output buffer "requests" until one of them becomes empty. + // Generate upstream requestMore for every Nth consumed input element + final def pump(): Unit = { + try while (transferState.isExecutable) { + transferState = ActorBasedFlowMaterializer.withCtx(pumpContext)(transfer()) + } catch { case NonFatal(e) ⇒ pumpFailed(e) } + + if (isPumpFinished) pumpFinished() + } + + protected def pumpFailed(e: Throwable): Unit + protected def pumpFinished(): Unit + + // Needs to be implemented by Processor implementations. Transfers elements from the input buffer to the output + // buffer. + protected def transfer(): TransferState +} + +/** + * INTERNAL API + */ +private[akka] class BatchingInputBuffer(val upstream: Subscription, val size: Int) extends DefaultInputTransferStates { // TODO: buffer and batch sizing heuristics private var inputBuffer = Array.ofDim[AnyRef](size) private var inputBufferElements = 0 @@ -125,10 +202,11 @@ class BatchingInputBuffer(val upstream: Subscription, val size: Int) extends Inp override def cancel(): Unit = { if (!upstreamCompleted) upstream.cancel() upstreamCompleted = true + clear() } override def isClosed: Boolean = upstreamCompleted - override def clear(): Unit = { + private def clear(): Unit = { Arrays.fill(inputBuffer, 0, inputBuffer.length, null) inputBufferElements = 0 } @@ -136,12 +214,42 @@ class BatchingInputBuffer(val upstream: Subscription, val size: Int) extends Inp override def inputsDepleted = upstreamCompleted && inputBufferElements == 0 override def inputsAvailable = inputBufferElements > 0 - override val NeedsInput: TransferState = new TransferState { - def isReady = inputsAvailable - def isCompleted = inputsDepleted +} + +/** + * INTERNAL API + */ +private[akka] abstract class FanoutOutputs(val maxBufferSize: Int, val initialBufferSize: Int) extends DefaultOutputTransferStates with SubscriberManagement[Any] { + private var downstreamBufferSpace = 0 + private var downstreamCompleted = false + def demandAvailable = downstreamBufferSpace > 0 + + def enqueueOutputDemand(demand: Int): Unit = downstreamBufferSpace += demand + def enqueueOutputElement(elem: Any): Unit = { + downstreamBufferSpace -= 1 + pushToDownstream(elem) } - override val NeedsInputOrComplete: TransferState = new TransferState { - def isReady = inputsAvailable || inputsDepleted - def isCompleted = false + + def complete(): Unit = { + if (!downstreamCompleted) completeDownstream() + downstreamCompleted = true + } + + def cancel(e: Throwable): Unit = { + downstreamCompleted = true + abortDownstream(e) + } + def isClosed: Boolean = downstreamCompleted + + def handleRequest(subscription: S, elements: Int): Unit = super.moreRequested(subscription, elements) + def addSubscriber(subscriber: Subscriber[Any]): Unit = super.registerSubscriber(subscriber) + def removeSubscription(subscription: S): Unit = super.unregisterSubscription(subscription) + + def afterShutdown(completed: Boolean): Unit + + override protected def requestFromUpstream(elements: Int): Unit = enqueueOutputDemand(elements) + override protected def shutdown(completed: Boolean): Unit = afterShutdown(completed) + override protected def cancelUpstream(): Unit = { + downstreamCompleted = true } } diff --git a/akka-stream/src/main/scala/akka/stream/io/StreamIO.scala b/akka-stream/src/main/scala/akka/stream/io/StreamIO.scala new file mode 100644 index 0000000000..df36823f58 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/StreamIO.scala @@ -0,0 +1,103 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.io + +import akka.util.ByteString +import org.reactivestreams.api.{ Processor, Producer, Consumer } +import java.net.InetSocketAddress +import akka.actor._ +import scala.collection._ +import scala.concurrent.duration.FiniteDuration +import akka.io.Inet.SocketOption +import akka.io.{ IO, Tcp } +import akka.stream.impl.{ ActorPublisher, ExposedPublisher, ActorProcessor } +import akka.stream.MaterializerSettings +import akka.io.Tcp.CommandFailed +import akka.stream.io.StreamTcp.OutgoingTcpConnection + +object StreamIO { + trait Extension extends akka.actor.Extension { + def manager: ActorRef + } + + def apply[T <: Extension](key: ExtensionId[T])(implicit system: ActorSystem): ActorRef = key(system).manager + +} + +object StreamTcp extends ExtensionId[StreamTcpExt] with ExtensionIdProvider { + + override def lookup = StreamTcp + override def createExtension(system: ExtendedActorSystem): StreamTcpExt = new StreamTcpExt(system) + override def get(system: ActorSystem): StreamTcpExt = super.get(system) + + case class OutgoingTcpConnection(remoteAddress: InetSocketAddress, + localAddress: InetSocketAddress, + processor: Processor[ByteString, ByteString]) { + def outputStream: Consumer[ByteString] = processor + def inputStream: Producer[ByteString] = processor + } + + case class TcpServerBinding(localAddress: InetSocketAddress, + connectionStream: Producer[IncomingTcpConnection]) + + case class IncomingTcpConnection(remoteAddress: InetSocketAddress, + inputStream: Producer[ByteString], + outputStream: Consumer[ByteString]) { + def handleWith(processor: Processor[ByteString, ByteString]): Unit = { + processor.produceTo(outputStream) + inputStream.produceTo(processor) + } + } + + case class Connect(remoteAddress: InetSocketAddress, + localAddress: Option[InetSocketAddress] = None, + options: immutable.Traversable[SocketOption] = Nil, + timeout: Option[FiniteDuration] = None, + settings: MaterializerSettings) + + case class Bind(localAddress: InetSocketAddress, + backlog: Int = 100, + options: immutable.Traversable[SocketOption] = Nil, + settings: MaterializerSettings) + +} + +/** + * INTERNAL API + */ +private[akka] class StreamTcpExt(system: ExtendedActorSystem) extends StreamIO.Extension { + val manager: ActorRef = system.systemActorOf(Props[StreamTcpManager], name = "IO-TCP-STREAM") +} + +/** + * INTERNAL API + */ +private[akka] object StreamTcpManager { + private[akka] case class ExposedProcessor(processor: Processor[ByteString, ByteString]) +} + +/** + * INTERNAL API + */ +private[akka] class StreamTcpManager extends Actor { + import StreamTcpManager._ + + def receive: Receive = { + case StreamTcp.Connect(remoteAddress, localAddress, options, timeout, settings) ⇒ + val processorActor = context.actorOf(TcpStreamActor.outboundProps( + Tcp.Connect(remoteAddress, localAddress, options, timeout, pullMode = true), + requester = sender(), + settings)) + processorActor ! ExposedProcessor(new ActorProcessor[ByteString, ByteString](processorActor)) + + case StreamTcp.Bind(localAddress, backlog, options, settings) ⇒ + val publisherActor = context.actorOf(TcpListenStreamActor.props( + Tcp.Bind(context.system.deadLetters, localAddress, backlog, options, pullMode = true), + requester = sender(), + settings)) + publisherActor ! ExposedPublisher(new ActorPublisher(publisherActor)) + } + +} + diff --git a/akka-stream/src/main/scala/akka/stream/io/TcpConnectionStream.scala b/akka-stream/src/main/scala/akka/stream/io/TcpConnectionStream.scala new file mode 100644 index 0000000000..11b5ea7fcf --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/TcpConnectionStream.scala @@ -0,0 +1,207 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.io + +import akka.io.{ IO, Tcp } +import scala.util.control.NoStackTrace +import akka.actor.{ ActorRefFactory, Actor, Props, ActorRef } +import akka.stream.impl._ +import akka.util.ByteString +import akka.io.Tcp._ +import akka.stream.MaterializerSettings +import org.reactivestreams.api.Processor +import java.net.InetSocketAddress + +/** + * INTERNAL API + */ +private[akka] object TcpStreamActor { + case object WriteAck extends Tcp.Event + class TcpStreamException(msg: String) extends RuntimeException(msg) with NoStackTrace + + def outboundProps(connectCmd: Connect, requester: ActorRef, settings: MaterializerSettings): Props = + Props(new OutboundTcpStreamActor(connectCmd, requester, settings)) + def inboundProps(connection: ActorRef, settings: MaterializerSettings): Props = + Props(new InboundTcpStreamActor(connection, settings)) +} + +/** + * INTERNAL API + */ +private[akka] abstract class TcpStreamActor(settings: MaterializerSettings) extends Actor + with PrimaryInputs + with PrimaryOutputs { + + import TcpStreamActor._ + def connection: ActorRef + + val initialInputBufferSize: Int = settings.initialInputBufferSize + val maximumInputBufferSize: Int = settings.maximumInputBufferSize + val initialFanOutBufferSize: Int = settings.initialFanOutBufferSize + val maxFanOutBufferSize: Int = settings.maxFanOutBufferSize + + object TcpInputs extends DefaultInputTransferStates { + private var closed: Boolean = false + private var pendingElement: ByteString = null + + override def inputsAvailable: Boolean = pendingElement ne null + override def inputsDepleted: Boolean = closed && !inputsAvailable + override def prefetch(): Unit = connection ! ResumeReading + override def isClosed: Boolean = closed + override def complete(): Unit = closed = true + override def cancel(): Unit = { + closed = true + pendingElement = null + } + override def dequeueInputElement(): Any = { + val elem = pendingElement + pendingElement = null + connection ! ResumeReading + elem + } + override def enqueueInputElement(elem: Any): Unit = pendingElement = elem.asInstanceOf[ByteString] + + } + + object TcpOutputs extends DefaultOutputTransferStates { + private var closed: Boolean = false + private var pendingDemand = true + override def isClosed: Boolean = closed + override def cancel(e: Throwable): Unit = { + if (!closed) connection ! Abort + closed = true + } + override def complete(): Unit = { + if (!closed) connection ! ConfirmedClose + closed = true + } + override def enqueueOutputElement(elem: Any): Unit = { + connection ! Write(elem.asInstanceOf[ByteString], WriteAck) + pendingDemand = false + } + def enqueueDemand(): Unit = pendingDemand = true + + override def demandAvailable: Boolean = pendingDemand + } + + object WritePump extends Pump { + lazy val NeedsInputAndDemand = primaryInputs.NeedsInput && TcpOutputs.NeedsDemand + override protected def transfer(): TransferState = { + var batch = ByteString.empty + while (primaryInputs.inputsAvailable) batch ++= primaryInputs.dequeueInputElement().asInstanceOf[ByteString] + TcpOutputs.enqueueOutputElement(batch) + NeedsInputAndDemand + } + override protected def pumpFinished(): Unit = TcpOutputs.complete() + override protected def pumpFailed(e: Throwable): Unit = fail(e) + override protected def pumpContext: ActorRefFactory = context + } + + object ReadPump extends Pump { + lazy val NeedsInputAndDemand = TcpInputs.NeedsInput && PrimaryOutputs.NeedsDemand + override protected def transfer(): TransferState = { + PrimaryOutputs.enqueueOutputElement(TcpInputs.dequeueInputElement()) + NeedsInputAndDemand + } + override protected def pumpFinished(): Unit = PrimaryOutputs.complete() + override protected def pumpFailed(e: Throwable): Unit = fail(e) + override protected def pumpContext: ActorRefFactory = context + } + + override def pumpInputs(): Unit = WritePump.pump() + override def pumpOutputs(): Unit = ReadPump.pump() + + override def receive = waitingExposedPublisher + + override def primaryInputOnError(e: Throwable): Unit = fail(e) + override def primaryInputOnComplete(): Unit = shutdown() + override def primaryInputsReady(): Unit = { + connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) + ReadPump.setTransferState(ReadPump.NeedsInputAndDemand) + WritePump.setTransferState(WritePump.NeedsInputAndDemand) + TcpInputs.prefetch() + context.become(running) + } + + override def primaryOutputsReady(): Unit = context.become(downstreamManagement orElse waitingForUpstream) + override def primaryOutputsFinished(completed: Boolean): Unit = shutdown() + + val running: Receive = upstreamManagement orElse downstreamManagement orElse { + case WriteAck ⇒ + TcpOutputs.enqueueDemand() + pumpInputs() + case Received(data) ⇒ + TcpInputs.enqueueInputElement(data) + pumpOutputs() + case Closed ⇒ + TcpInputs.complete() + TcpOutputs.complete() + WritePump.pump() + ReadPump.pump() + case ConfirmedClosed ⇒ + TcpInputs.complete() + pumpOutputs() + case PeerClosed ⇒ + println("closed") + TcpInputs.complete() + pumpOutputs() + case ErrorClosed(cause) ⇒ fail(new TcpStreamException(s"The connection closed with error $cause")) + case CommandFailed(cmd) ⇒ fail(new TcpStreamException(s"Tcp command [$cmd] failed")) + case Aborted ⇒ fail(new TcpStreamException("The connection has been aborted")) + } + + def fail(e: Throwable): Unit = { + TcpInputs.cancel() + TcpOutputs.cancel(e) + if (primaryInputs ne null) primaryInputs.cancel() + PrimaryOutputs.cancel(e) + exposedPublisher.shutdown(Some(e)) + } + + def shutdown(): Unit = { + if (TcpOutputs.isClosed && PrimaryOutputs.isClosed) { + context.stop(self) + exposedPublisher.shutdown(None) + } + } + +} + +/** + * INTERNAL API + */ +private[akka] class InboundTcpStreamActor( + val connection: ActorRef, _settings: MaterializerSettings) + extends TcpStreamActor(_settings) { + +} + +/** + * INTERNAL API + */ +private[akka] class OutboundTcpStreamActor(val connectCmd: Connect, val requester: ActorRef, _settings: MaterializerSettings) + extends TcpStreamActor(_settings) { + import TcpStreamActor._ + var connection: ActorRef = _ + import context.system + + override def primaryOutputsReady(): Unit = context.become(waitingExposedProcessor) + + val waitingExposedProcessor: Receive = { + case StreamTcpManager.ExposedProcessor(processor) ⇒ + IO(Tcp) ! connectCmd + context.become(waitConnection(processor)) + case _ ⇒ throw new IllegalStateException("The second message must be ExposedProcessor") + } + + def waitConnection(exposedProcessor: Processor[ByteString, ByteString]): Receive = { + case Connected(remoteAddress, localAddress) ⇒ + connection = sender() + requester ! StreamTcp.OutgoingTcpConnection(remoteAddress, localAddress, exposedProcessor) + context.become(downstreamManagement orElse waitingForUpstream) + case f: CommandFailed ⇒ + requester ! f + fail(new TcpStreamException("Connection failed.")) + } +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/io/TcpListenStreamActor.scala b/akka-stream/src/main/scala/akka/stream/io/TcpListenStreamActor.scala new file mode 100644 index 0000000000..b8f10ab85a --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/TcpListenStreamActor.scala @@ -0,0 +1,138 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.io + +import scala.util.control.NoStackTrace +import akka.actor.{ ActorRefFactory, Actor, Props, ActorRef } +import akka.stream.MaterializerSettings +import akka.stream.impl._ +import akka.io.Tcp._ +import akka.util.ByteString +import scala.Some +import akka.stream.MaterializerSettings +import akka.io.{ IO, Tcp } +import akka.io.Tcp.Connected +import akka.io.Tcp.CommandFailed +import scala.Some +import akka.stream.MaterializerSettings +import akka.io.Tcp.ResumeAccepting +import org.reactivestreams.api.{ Consumer, Producer } +import org.reactivestreams.spi.Publisher + +/** + * INTERNAL API + */ +private[akka] object TcpListenStreamActor { + class TcpListenStreamException(msg: String) extends RuntimeException(msg) with NoStackTrace + + def props(bindCmd: Tcp.Bind, requester: ActorRef, settings: MaterializerSettings): Props = + Props(new TcpListenStreamActor(bindCmd, requester, settings)) + + case class ConnectionProducer(getPublisher: Publisher[StreamTcp.IncomingTcpConnection]) + extends Producer[StreamTcp.IncomingTcpConnection] { + + def produceTo(consumer: Consumer[StreamTcp.IncomingTcpConnection]): Unit = + getPublisher.subscribe(consumer.getSubscriber) + } +} + +/** + * INTERNAL API + */ +private[akka] class TcpListenStreamActor(bindCmd: Tcp.Bind, requester: ActorRef, settings: MaterializerSettings) extends Actor + with PrimaryOutputs with Pump { + import TcpListenStreamActor._ + import context.system + + var listener: ActorRef = _ + + override def maxFanOutBufferSize: Int = settings.maxFanOutBufferSize + override def initialFanOutBufferSize: Int = settings.initialFanOutBufferSize + + override def receive: Actor.Receive = waitingExposedPublisher + override def primaryOutputsReady(): Unit = { + IO(Tcp) ! bindCmd.copy(handler = self) + context.become(waitBound) + } + + val waitBound: Receive = { + case Bound(localAddress) ⇒ + listener = sender() + setTransferState(NeedsInputAndDemand) + IncomingConnections.prefetch() + requester ! StreamTcp.TcpServerBinding( + localAddress, + ConnectionProducer(exposedPublisher.asInstanceOf[Publisher[StreamTcp.IncomingTcpConnection]])) + context.become(running) + case f: CommandFailed ⇒ + requester ! f + fail(new TcpListenStreamException("Bind failed")) + } + + val running: Receive = downstreamManagement orElse { + case c: Connected ⇒ + IncomingConnections.enqueueInputElement((c, sender())) + pump() + case f: CommandFailed ⇒ + fail(new TcpListenStreamException(s"Command [${f.cmd}] failed")) + } + + override def pumpOutputs(): Unit = pump() + + override def primaryOutputsFinished(completed: Boolean): Unit = shutdown() + + lazy val NeedsInputAndDemand = PrimaryOutputs.NeedsDemand && IncomingConnections.NeedsInput + + override protected def transfer(): TransferState = { + val (connected, connection) = IncomingConnections.dequeueInputElement().asInstanceOf[(Connected, ActorRef)] + val tcpStreamActor = context.actorOf(TcpStreamActor.inboundProps(connection, settings)) + val processor = new ActorProcessor[ByteString, ByteString](tcpStreamActor) + PrimaryOutputs.enqueueOutputElement(StreamTcp.IncomingTcpConnection(connected.remoteAddress, processor, processor)) + NeedsInputAndDemand + } + + override protected def pumpFinished(): Unit = IncomingConnections.cancel() + override protected def pumpFailed(e: Throwable): Unit = fail(e) + override protected def pumpContext: ActorRefFactory = context + + object IncomingConnections extends DefaultInputTransferStates { + private var closed: Boolean = false + private var pendingConnection: (Connected, ActorRef) = null + + override def inputsAvailable: Boolean = pendingConnection ne null + override def inputsDepleted: Boolean = closed && !inputsAvailable + override def prefetch(): Unit = listener ! ResumeAccepting(1) + override def isClosed: Boolean = closed + override def complete(): Unit = { + if (!closed) listener ! Unbind + closed = true + } + override def cancel(): Unit = { + if (!closed) listener ! Unbind + closed = true + // pendingConnection._2 ! Abort + pendingConnection = null + } + override def dequeueInputElement(): Any = { + val elem = pendingConnection + pendingConnection = null + listener ! ResumeAccepting(1) + elem + } + override def enqueueInputElement(elem: Any): Unit = pendingConnection = elem.asInstanceOf[(Connected, ActorRef)] + + } + + def fail(e: Throwable): Unit = { + IncomingConnections.cancel() + PrimaryOutputs.cancel(e) + exposedPublisher.shutdown(Some(e)) + } + + def shutdown(): Unit = { + IncomingConnections.complete() + PrimaryOutputs.complete() + exposedPublisher.shutdown(None) + } +} diff --git a/akka-stream/src/test/scala/akka/stream/FlowConcatSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowConcatSpec.scala index e845a44606..6df050fd48 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowConcatSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowConcatSpec.scala @@ -80,12 +80,12 @@ class FlowConcatSpec extends TwoStreamsSetup { "work with one immediately failed and one nonempty producer" in { val consumer1 = setup(failedPublisher, nonemptyPublisher((1 to 4).iterator)) - consumer1.expectError(TestException) + consumer1.expectErrorOrSubscriptionFollowedByError(TestException) val consumer2 = setup(nonemptyPublisher((1 to 4).iterator), failedPublisher) val subscription2 = consumer2.expectSubscription() subscription2.requestMore(5) - consumer2.expectError(TestException) + consumer2.expectErrorOrSubscriptionFollowedByError(TestException) } "work with one delayed failed and one nonempty producer" in { diff --git a/akka-stream/src/test/scala/akka/stream/FlowDropSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowDropSpec.scala index eb0b54957a..f341a8fdf8 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowDropSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowDropSpec.scala @@ -7,7 +7,7 @@ import akka.testkit.AkkaSpec import akka.stream.testkit.ScriptedTest import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -class StreamDropSpec extends AkkaSpec with ScriptedTest { +class FlowDropSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, diff --git a/akka-stream/src/test/scala/akka/stream/FlowFilterSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowFilterSpec.scala index 4f8feabd7e..868dccab31 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowFilterSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowFilterSpec.scala @@ -4,10 +4,12 @@ package akka.stream import akka.testkit.AkkaSpec -import akka.stream.testkit.ScriptedTest +import akka.stream.testkit.{ StreamTestKit, ScriptedTest } import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } +import akka.stream.scaladsl.Flow +import akka.stream.impl.ActorBasedFlowMaterializer -class StreamFilterSpec extends AkkaSpec with ScriptedTest { +class FlowFilterSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, @@ -22,6 +24,27 @@ class StreamFilterSpec extends AkkaSpec with ScriptedTest { (1 to 50) foreach (_ ⇒ runScript(script, settings)(_.filter(_ % 2 == 0))) } + "not blow up with high request counts" in { + val gen = new ActorBasedFlowMaterializer(MaterializerSettings( + initialInputBufferSize = 1, + maximumInputBufferSize = 1, + initialFanOutBufferSize = 1, + maxFanOutBufferSize = 1), system) + + val probe = StreamTestKit.consumerProbe[Int] + Flow(Iterator.fill(1000)(0) ++ List(1)).filter(_ != 0). + toProducer(gen).produceTo(probe) + + val subscription = probe.expectSubscription() + for (_ ← 1 to 10000) { + subscription.requestMore(Int.MaxValue) + } + + probe.expectNext(1) + probe.expectComplete() + + } + } } \ No newline at end of file diff --git a/akka-stream/src/test/scala/akka/stream/FlowFoldSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowFoldSpec.scala index abc77de2a8..70e62340c8 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowFoldSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowFoldSpec.scala @@ -7,7 +7,7 @@ import akka.testkit.AkkaSpec import akka.stream.testkit.ScriptedTest import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -class StreamFoldSpec extends AkkaSpec with ScriptedTest { +class FlowFoldSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, diff --git a/akka-stream/src/test/scala/akka/stream/FlowForeachTest.scala b/akka-stream/src/test/scala/akka/stream/FlowForeachSpec.scala similarity index 100% rename from akka-stream/src/test/scala/akka/stream/FlowForeachTest.scala rename to akka-stream/src/test/scala/akka/stream/FlowForeachSpec.scala diff --git a/akka-stream/src/test/scala/akka/stream/FlowGroupBySpec.scala b/akka-stream/src/test/scala/akka/stream/FlowGroupBySpec.scala index 954580aeda..4d3fa00b92 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowGroupBySpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowGroupBySpec.scala @@ -67,7 +67,6 @@ class FlowGroupBySpec extends AkkaSpec { s1.expectNoMsg(100.millis) val s2 = StreamPuppet(getSubproducer(0)) - masterConsumer.expectNoMsg(100.millis) s2.expectNoMsg(100.millis) s2.requestMore(2) @@ -95,7 +94,6 @@ class FlowGroupBySpec extends AkkaSpec { StreamPuppet(getSubproducer(1)).cancel() val substream = StreamPuppet(getSubproducer(0)) - masterConsumer.expectNoMsg(100.millis) substream.requestMore(2) substream.expectNext(2) substream.expectNext(4) diff --git a/akka-stream/src/test/scala/akka/stream/FlowGroupedSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowGroupedSpec.scala index 3f87d63188..b509e57f3b 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowGroupedSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowGroupedSpec.scala @@ -8,7 +8,7 @@ import akka.stream.testkit.ScriptedTest import scala.collection.immutable import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -class StreamGroupedSpec extends AkkaSpec with ScriptedTest { +class FlowGroupedSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, diff --git a/akka-stream/src/test/scala/akka/stream/FlowMapSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowMapSpec.scala index 94523f393a..e19cf68275 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowMapSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowMapSpec.scala @@ -4,10 +4,12 @@ package akka.stream import akka.testkit.AkkaSpec -import akka.stream.testkit.ScriptedTest +import akka.stream.testkit.{ StreamTestKit, ScriptedTest } import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } +import akka.stream.scaladsl.Flow +import akka.stream.impl.ActorBasedFlowMaterializer -class StreamMapSpec extends AkkaSpec with ScriptedTest { +class FlowMapSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, @@ -15,6 +17,8 @@ class StreamMapSpec extends AkkaSpec with ScriptedTest { initialFanOutBufferSize = 1, maxFanOutBufferSize = 16) + val gen = new ActorBasedFlowMaterializer(settings, system) + "A Map" must { "map" in { @@ -22,6 +26,22 @@ class StreamMapSpec extends AkkaSpec with ScriptedTest { (1 to 50) foreach (_ ⇒ runScript(script, settings)(_.map(_.toString))) } + "not blow up with high request counts" in { + val probe = StreamTestKit.consumerProbe[Int] + Flow(List(1).iterator). + map(_ + 1).map(_ + 1).map(_ + 1).map(_ + 1).map(_ + 1). + toProducer(gen).produceTo(probe) + + val subscription = probe.expectSubscription() + for (_ ← 1 to 10000) { + subscription.requestMore(Int.MaxValue) + } + + probe.expectNext(6) + probe.expectComplete() + + } + } } \ No newline at end of file diff --git a/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala index eb71a6ea6c..4de417f0d3 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowTakeSpec.scala @@ -10,7 +10,7 @@ import akka.stream.impl.OnNext import akka.stream.impl.OnComplete import akka.stream.impl.RequestMore -class StreamTakeSpec extends AkkaSpec with ScriptedTest { +class FlowTakeSpec extends AkkaSpec with ScriptedTest { val settings = MaterializerSettings( initialInputBufferSize = 2, diff --git a/akka-stream/src/test/scala/akka/stream/FlowZipSpec.scala b/akka-stream/src/test/scala/akka/stream/FlowZipSpec.scala index 7fec69ea82..16249a37f4 100644 --- a/akka-stream/src/test/scala/akka/stream/FlowZipSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/FlowZipSpec.scala @@ -68,12 +68,12 @@ class FlowZipSpec extends TwoStreamsSetup { "work with one immediately failed and one nonempty producer" in { val consumer1 = setup(failedPublisher, nonemptyPublisher((1 to 4).iterator)) - consumer1.expectError(TestException) + consumer1.expectErrorOrSubscriptionFollowedByError(TestException) val consumer2 = setup(nonemptyPublisher((1 to 4).iterator), failedPublisher) val subscription2 = consumer2.expectSubscription() subscription2.requestMore(4) - consumer2.expectError(TestException) + consumer2.expectErrorOrSubscriptionFollowedByError(TestException) } "work with one delayed failed and one nonempty producer" in { diff --git a/akka-stream/src/test/scala/akka/stream/TwoStreamsSetup.scala b/akka-stream/src/test/scala/akka/stream/TwoStreamsSetup.scala index 381aa7658d..43144324d3 100644 --- a/akka-stream/src/test/scala/akka/stream/TwoStreamsSetup.scala +++ b/akka-stream/src/test/scala/akka/stream/TwoStreamsSetup.scala @@ -97,7 +97,7 @@ abstract class TwoStreamsSetup extends AkkaSpec { "work with two immediately failed producers" in { val consumer = setup(failedPublisher, failedPublisher) - consumer.expectError(TestException) + consumer.expectErrorOrSubscriptionFollowedByError(TestException) } "work with two delayed failed producers" in { diff --git a/akka-stream/src/test/scala/akka/stream/io/TcpFlowSpec.scala b/akka-stream/src/test/scala/akka/stream/io/TcpFlowSpec.scala new file mode 100644 index 0000000000..b44deb0c31 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/io/TcpFlowSpec.scala @@ -0,0 +1,377 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.io + +import akka.testkit.{ TestProbe, AkkaSpec } +import akka.io.{ Tcp, IO } +import java.nio.channels.ServerSocketChannel +import java.net.InetSocketAddress +import akka.stream.MaterializerSettings +import akka.stream.impl.{ ActorProcessor, ActorBasedFlowMaterializer } +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import akka.stream.testkit.StreamTestKit +import org.reactivestreams.api.Processor +import akka.actor.{ Props, ActorRef, Actor } +import scala.collection.immutable.Queue +import scala.concurrent.{ Future, Await } +import scala.concurrent.duration._ + +object TcpFlowSpec { + + case class ClientWrite(bytes: ByteString) + case class ClientRead(count: Int, readTo: ActorRef) + case class ClientClose(cmd: Tcp.CloseCommand) + + case object WriteAck extends Tcp.Event + + def testClientProps(connection: ActorRef): Props = Props(new TestClient(connection)) + def testServerProps(address: InetSocketAddress, probe: ActorRef): Props = Props(new TestServer(address, probe)) + + class TestClient(connection: ActorRef) extends Actor { + connection ! Tcp.Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) + + var queuedWrites = Queue.empty[ByteString] + var writePending = false + + var toRead = 0 + var readBuffer = ByteString.empty + var readTo: ActorRef = context.system.deadLetters + + var closeAfterWrite: Option[Tcp.CloseCommand] = None + + // FIXME: various close scenarios + def receive = { + case ClientWrite(bytes) if !writePending ⇒ + writePending = true + connection ! Tcp.Write(bytes, WriteAck) + case ClientWrite(bytes) ⇒ + queuedWrites = queuedWrites.enqueue(bytes) + case WriteAck if queuedWrites.nonEmpty ⇒ + val (next, remaining) = queuedWrites.dequeue + queuedWrites = remaining + connection ! Tcp.Write(next, WriteAck) + case WriteAck ⇒ + writePending = false + closeAfterWrite match { + case Some(cmd) ⇒ connection ! cmd + case None ⇒ + } + case ClientRead(count, requester) ⇒ + readTo = requester + toRead = count + connection ! Tcp.ResumeReading + case Tcp.Received(bytes) ⇒ + readBuffer ++= bytes + if (readBuffer.size >= toRead) { + readTo ! readBuffer + readBuffer = ByteString.empty + toRead = 0 + readTo = context.system.deadLetters + } else connection ! Tcp.ResumeReading + + case ClientClose(cmd) ⇒ + if (!writePending) connection ! cmd + else closeAfterWrite = Some(cmd) + } + + } + + case object ServerClose + + class TestServer(serverAddress: InetSocketAddress, probe: ActorRef) extends Actor { + import context.system + IO(Tcp) ! Tcp.Bind(self, serverAddress, pullMode = true) + var listener: ActorRef = _ + + def receive = { + case b @ Tcp.Bound(_) ⇒ + listener = sender() + listener ! Tcp.ResumeAccepting(1) + probe ! b + case Tcp.Connected(_, _) ⇒ + val handler = context.actorOf(testClientProps(sender())) + listener ! Tcp.ResumeAccepting(1) + probe ! handler + case ServerClose ⇒ + listener ! Tcp.Unbind + context.stop(self) + } + + } + +} + +class TcpFlowSpec extends AkkaSpec { + import TcpFlowSpec._ + + val genSettings = MaterializerSettings( + initialInputBufferSize = 4, + maximumInputBufferSize = 4, + initialFanOutBufferSize = 2, + maxFanOutBufferSize = 2) + + val gen = new ActorBasedFlowMaterializer(genSettings, system) + + def temporaryServerAddress: InetSocketAddress = { + val serverSocket = ServerSocketChannel.open().socket() + serverSocket.bind(new InetSocketAddress("127.0.0.1", 0)) + val address = new InetSocketAddress("127.0.0.1", serverSocket.getLocalPort) + serverSocket.close() + address + } + + class Server(val address: InetSocketAddress = temporaryServerAddress) { + val serverProbe = TestProbe() + val serverRef = system.actorOf(testServerProps(address, serverProbe.ref)) + serverProbe.expectMsgType[Tcp.Bound] + + def waitAccept(): ServerConnection = new ServerConnection(serverProbe.expectMsgType[ActorRef]) + def close(): Unit = serverRef ! ServerClose + } + + class ServerConnection(val connectionActor: ActorRef) { + val connectionProbe = TestProbe() + def write(bytes: ByteString): Unit = connectionActor ! ClientWrite(bytes) + + def read(count: Int): Unit = connectionActor ! ClientRead(count, connectionProbe.ref) + + def waitRead(): ByteString = connectionProbe.expectMsgType[ByteString] + def confirmedClose(): Unit = connectionActor ! ClientClose(Tcp.ConfirmedClose) + def close(): Unit = connectionActor ! ClientClose(Tcp.Close) + def abort(): Unit = connectionActor ! ClientClose(Tcp.Abort) + } + + class TcpReadProbe(tcpProcessor: Processor[ByteString, ByteString]) { + val consumerProbe = StreamTestKit.consumerProbe[ByteString]() + tcpProcessor.produceTo(consumerProbe) + val tcpReadSubscription = consumerProbe.expectSubscription() + + def read(count: Int): ByteString = { + var result = ByteString.empty + while (result.size < count) { + tcpReadSubscription.requestMore(1) + result ++= consumerProbe.expectNext() + } + result + } + + def close(): Unit = tcpReadSubscription.cancel() + } + + class TcpWriteProbe(tcpProcessor: Processor[ByteString, ByteString]) { + val producerProbe = StreamTestKit.producerProbe[ByteString]() + producerProbe.produceTo(tcpProcessor) + val tcpWriteSubscription = producerProbe.expectSubscription() + var demand = 0 + + def write(bytes: ByteString): Unit = { + if (demand == 0) demand += tcpWriteSubscription.expectRequestMore() + tcpWriteSubscription.sendNext(bytes) + demand -= 1 + } + + def close(): Unit = tcpWriteSubscription.sendComplete() + } + + def connect(server: Server): (Processor[ByteString, ByteString], ServerConnection) = { + val tcpProbe = TestProbe() + tcpProbe.send(StreamIO(StreamTcp), StreamTcp.Connect(server.address, settings = genSettings)) + val client = server.waitAccept() + val outgoingConnection = tcpProbe.expectMsgType[StreamTcp.OutgoingTcpConnection] + + (outgoingConnection.processor, client) + } + + def connect(serverAddress: InetSocketAddress): StreamTcp.OutgoingTcpConnection = { + val connectProbe = TestProbe() + connectProbe.send(StreamIO(StreamTcp), StreamTcp.Connect(serverAddress, settings = genSettings)) + connectProbe.expectMsgType[StreamTcp.OutgoingTcpConnection] + } + + def bind(serverAddress: InetSocketAddress = temporaryServerAddress): StreamTcp.TcpServerBinding = { + val bindProbe = TestProbe() + bindProbe.send(StreamIO(StreamTcp), StreamTcp.Bind(serverAddress, settings = genSettings)) + bindProbe.expectMsgType[StreamTcp.TcpServerBinding] + } + + def echoServer(serverAddress: InetSocketAddress = temporaryServerAddress): Future[Unit] = + Flow(bind(serverAddress).connectionStream).foreach { conn ⇒ + conn.inputStream.produceTo(conn.outputStream) + }.toFuture(gen) + + "Outgoing TCP stream" must { + + "work in the happy case" in { + val testData = ByteString(1, 2, 3, 4, 5) + + val server = new Server() + + val (tcpProcessor, serverConnection) = connect(server) + + val tcpReadProbe = new TcpReadProbe(tcpProcessor) + val tcpWriteProbe = new TcpWriteProbe(tcpProcessor) + + serverConnection.write(testData) + serverConnection.read(5) + tcpReadProbe.read(5) should be(testData) + tcpWriteProbe.write(testData) + serverConnection.waitRead() should be(testData) + + tcpWriteProbe.close() + tcpReadProbe.close() + + //client.read() should be(ByteString.empty) + server.close() + } + + "be able to write a sequence of ByteStrings" in { + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val testInput = Iterator.range(0, 256).map(ByteString(_)) + val expectedOutput = ByteString(Array.tabulate(256)(_.asInstanceOf[Byte])) + + serverConnection.read(256) + Flow(tcpProcessor).consume(gen) + + Flow(testInput).toProducer(gen).produceTo(tcpProcessor) + serverConnection.waitRead() should be(expectedOutput) + + } + + "be able to read a sequence of ByteStrings" in { + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val testInput = Iterator.range(0, 256).map(ByteString(_)) + val expectedOutput = ByteString(Array.tabulate(256)(_.asInstanceOf[Byte])) + + for (in ← testInput) serverConnection.write(in) + new TcpWriteProbe(tcpProcessor) // Just register an idle upstream + + val resultFuture = Flow(tcpProcessor).fold(ByteString.empty)((acc, in) ⇒ acc ++ in).toFuture(gen) + serverConnection.confirmedClose() + Await.result(resultFuture, 3.seconds) should be(expectedOutput) + + } + + "half close the connection when output stream is closed" in { + val testData = ByteString(1, 2, 3, 4, 5) + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val tcpWriteProbe = new TcpWriteProbe(tcpProcessor) + val tcpReadProbe = new TcpReadProbe(tcpProcessor) + + tcpWriteProbe.close() + // FIXME: expect PeerClosed on server + serverConnection.write(testData) + tcpReadProbe.read(5) should be(testData) + serverConnection.confirmedClose() + tcpReadProbe.consumerProbe.expectComplete() + } + + "stop reading when the input stream is cancelled" in { + val testData = ByteString(1, 2, 3, 4, 5) + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val tcpWriteProbe = new TcpWriteProbe(tcpProcessor) + val tcpReadProbe = new TcpReadProbe(tcpProcessor) + + tcpReadProbe.close() + // FIXME: expect PeerClosed on server + serverConnection.write(testData) + tcpReadProbe.consumerProbe.expectNoMsg(1.second) + serverConnection.read(5) + tcpWriteProbe.write(testData) + serverConnection.waitRead() should be(testData) + tcpWriteProbe.close() + } + + "keep write side open when remote half-closes" in { + val testData = ByteString(1, 2, 3, 4, 5) + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val tcpWriteProbe = new TcpWriteProbe(tcpProcessor) + val tcpReadProbe = new TcpReadProbe(tcpProcessor) + + // FIXME: here (and above tests) add a chitChat() method ensuring this works even after prior communication + // there should be a chitchat and non-chitchat version + + serverConnection.confirmedClose() + tcpReadProbe.consumerProbe.expectComplete() + + serverConnection.read(5) + tcpWriteProbe.write(testData) + serverConnection.waitRead() should be(testData) + + tcpWriteProbe.close() + // FIXME: expect closed event + } + + "shut down both streams when connection is completely closed" in { + // Client gets a PeerClosed event and does not know that the write side is also closed + val testData = ByteString(1, 2, 3, 4, 5) + val server = new Server() + val (tcpProcessor, serverConnection) = connect(server) + + val tcpWriteProbe = new TcpWriteProbe(tcpProcessor) + val tcpReadProbe = new TcpReadProbe(tcpProcessor) + + serverConnection.abort() + tcpReadProbe.consumerProbe.expectError() + tcpWriteProbe.tcpWriteSubscription.expectCancellation() + } + + "close the connection when input stream and oputput streams are closed" in { + pending + } + + } + + "TCP listen stream" must { + + "be able to implement echo" in { + + val serverAddress = temporaryServerAddress + val server = echoServer(serverAddress) + val conn = connect(serverAddress) + + val testInput = Iterator.range(0, 256).map(ByteString(_)) + val expectedOutput = ByteString(Array.tabulate(256)(_.asInstanceOf[Byte])) + + Flow(testInput).toProducer(gen).produceTo(conn.outputStream) + val resultFuture = Flow(conn.inputStream).fold(ByteString.empty)((acc, in) ⇒ acc ++ in).toFuture(gen) + + Await.result(resultFuture, 3.seconds) should be(expectedOutput) + + } + + "work with a chain of echoes" in { + + val serverAddress = temporaryServerAddress + val server = echoServer(serverAddress) + + val conn1 = connect(serverAddress) + val conn2 = connect(serverAddress) + val conn3 = connect(serverAddress) + + val testInput = Iterator.range(0, 256).map(ByteString(_)) + val expectedOutput = ByteString(Array.tabulate(256)(_.asInstanceOf[Byte])) + + Flow(testInput).toProducer(gen).produceTo(conn1.outputStream) + conn1.inputStream.produceTo(conn2.outputStream) + conn2.inputStream.produceTo(conn3.outputStream) + val resultFuture = Flow(conn3.inputStream).fold(ByteString.empty)((acc, in) ⇒ acc ++ in).toFuture(gen) + + Await.result(resultFuture, 3.seconds) should be(expectedOutput) + + } + + } + +}