=str fix race condition in case of early termination of connections source

This commit is contained in:
Konrad Malawski 2015-11-30 12:49:12 +01:00
parent f8fa5978d9
commit 71740f3fcd
4 changed files with 51 additions and 23 deletions

View file

@ -22,7 +22,7 @@ import akka.testkit.EventFilter
class TcpSpec extends AkkaSpec(
"""
|akka.io.tcp.windows-connection-abort-workaround-enabled=auto
|akka.stream.materializer.subscription-timeout.timeout = 3s""".stripMargin) with TcpHelper {
|akka.stream.materializer.subscription-timeout.timeout = 2s""".stripMargin) with TcpHelper {
var demand = 0L
"Outgoing TCP stream" must {
@ -374,11 +374,11 @@ class TcpSpec extends AkkaSpec(
conn.flow.join(Flow[ByteString]).run()
})(Keep.left).run(), 3.seconds)
val result = Source(immutable.Iterable.fill(10000)(ByteString(0)))
val result = Source(immutable.Iterable.fill(1000)(ByteString(0)))
.via(Tcp().outgoingConnection(serverAddress, halfClose = true))
.runFold(0)(_ + _.size)
Await.result(result, 3.seconds) should ===(10000)
Await.result(result, 3.seconds) should ===(1000)
binding.unbind()
}
@ -498,7 +498,10 @@ class TcpSpec extends AkkaSpec(
"not shut down connections after the connection stream cancelled" in assertAllStagesStopped {
val address = temporaryServerAddress()
Tcp().bind(address.getHostName, address.getPort).take(1).runForeach(_.flow.join(Flow[ByteString]).run())
Tcp().bind(address.getHostName, address.getPort).take(1).runForeach { tcp
Thread.sleep(1000) // we're testing here to see if it survives such race
tcp.flow.join(Flow[ByteString]).run()
}
val total = Source(immutable.Iterable.fill(1000)(ByteString(0)))
.via(Tcp().outgoingConnection(address))
@ -507,7 +510,7 @@ class TcpSpec extends AkkaSpec(
Await.result(total, 3.seconds) should ===(1000)
}
"shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped {
"xoxoxo shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped {
val address = temporaryServerAddress()
val firstClientConnected = Promise[Unit]()
val takeTwoAndDropSecond = Flow[IncomingConnection].map(conn {
@ -518,20 +521,19 @@ class TcpSpec extends AkkaSpec(
.via(takeTwoAndDropSecond)
.runForeach(_.flow.join(Flow[ByteString]).run())
val folder = Source(immutable.Iterable.fill(1000)(ByteString(0)))
val folder = Source(immutable.Iterable.fill(100)(ByteString(0)))
.via(Tcp().outgoingConnection(address))
.fold(0)(_ + _.size).toMat(Sink.head)(Keep.right)
val total = folder.run()
awaitAssert(firstClientConnected.future, 2.seconds)
val rejected = folder.run()
Await.result(total, 3.seconds) should ===(1000)
Await.result(total, 10.seconds) should ===(100)
a[StreamTcpException] should be thrownBy {
Await.result(rejected, 5.seconds) should ===(1000)
Await.result(rejected, 5.seconds) should ===(100)
}
}
}

View file

@ -4,7 +4,7 @@
package akka.stream.impl.io
import java.net.InetSocketAddress
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.{ AtomicLong, AtomicBoolean }
import akka.actor.{ ActorRef, Terminated }
import akka.dispatch.ExecutionContexts
@ -32,13 +32,15 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef,
val backlog: Int,
val options: immutable.Traversable[SocketOption],
val halfClose: Boolean,
val idleTimeout: Duration)
val idleTimeout: Duration,
val bindShutdownTimeout: FiniteDuration)
extends GraphStageWithMaterializedValue[SourceShape[StreamTcp.IncomingConnection], Future[StreamTcp.ServerBinding]] {
import ConnectionSourceStage._
val out: Outlet[StreamTcp.IncomingConnection] = Outlet("IncomingConnections.out")
val shape: SourceShape[StreamTcp.IncomingConnection] = SourceShape(out)
private val BindTimer = "BindTimer"
private val connectionFlowsAwaitingInitialization = new AtomicLong()
// TODO: Timeout on bind
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[ServerBinding]) = {
@ -72,9 +74,10 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef,
case c: Connected
push(out, connectionFor(c, sender))
case Unbind
if (!isClosed(out) && (listener ne null)) listener ! Unbind
if (!isClosed(out) && (listener ne null)) tryUnbind()
case Unbound // If we're unbound then just shut down
completeStage()
if (connectionFlowsAwaitingInitialization.get() == 0) completeStage()
else scheduleOnce(BindShutdownTimer, bindShutdownTimeout)
case Terminated(ref) if ref == listener
failStage(new IllegalStateException("IO Listener actor terminated unexpectedly"))
}
@ -89,10 +92,19 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef,
override def onDownstreamFinish(): Unit = tryUnbind()
})
// because when we tryUnbind, we must wait for the Ubound signal before terminating
override def keepGoingAfterAllPortsClosed = true
private def connectionFor(connected: Connected, connection: ActorRef): StreamTcp.IncomingConnection = {
connectionFlowsAwaitingInitialization.incrementAndGet()
val tcpFlow =
Flow.fromGraph(new IncomingConnectionStage(connection, connected.remoteAddress, halfClose))
.via(new Detacher[ByteString]) // must read ahead for proper completions
.mapMaterializedValue { m
connectionFlowsAwaitingInitialization.decrementAndGet()
m
}
// FIXME: Previous code was wrong, must add new tests
val handler = idleTimeout match {
@ -107,8 +119,15 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef,
}
private def tryUnbind(): Unit = {
if (listener ne null) listener ! Unbind
else completeStage()
if (listener ne null) {
self.unwatch(listener)
listener ! Unbind
}
}
override def onTimer(timerKey: Any): Unit = timerKey match {
case BindShutdownTimer
completeStage() // TODO need to manually shut down instead right?
}
override def postStop(): Unit = {
@ -122,6 +141,11 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef,
}
private[stream] object ConnectionSourceStage {
val BindTimer = "BindTimer"
val BindShutdownTimer = "BindTimer"
}
/**
* INTERNAL API
*/
@ -194,11 +218,9 @@ private[stream] object TcpConnectionStage {
val sender = evt._1
val msg = evt._2
msg match {
case Terminated(_)
failStage(new StreamTcpException("The connection actor has terminated. Stopping now."))
case Terminated(_) failStage(new StreamTcpException("The connection actor has terminated. Stopping now."))
case CommandFailed(cmd) failStage(new StreamTcpException(s"Tcp command [$cmd] failed"))
case ErrorClosed(cause) failStage(new StreamTcpException(s"The connection closed with error $cause"))
case ErrorClosed(cause) failStage(new StreamTcpException(s"The connection closed with error: $cause"))
case Aborted failStage(new StreamTcpException("The connection has been aborted"))
case Closed completeStage()
case ConfirmedClosed completeStage()
@ -253,7 +275,7 @@ private[stream] object TcpConnectionStage {
case Outbound(_, _, localAddressPromise, _)
// Fail if has not been completed with an address eariler
localAddressPromise.tryFailure(new StreamTcpException("Connection failed."))
case _
case _ // do nothing...
}
}
}
@ -299,7 +321,6 @@ private[stream] class OutgoingConnectionStage(manager: ActorRef,
val shape: FlowShape[ByteString, ByteString] = FlowShape(bytesIn, bytesOut)
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[StreamTcp.OutgoingConnection]) = {
// FIXME: A method like this would make soo much sense on Duration (i.e. toOption)
val connTimeout = connectTimeout match {
case x: FiniteDuration Some(x)

View file

@ -62,6 +62,9 @@ object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider {
class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension {
import Tcp._
// TODO maybe this should be a new setting, like `akka.stream.tcp.bind.timeout` / `shutdown-timeout` instead?
val bindShutdownTimeout = ActorMaterializer()(system).settings.subscriptionTimeoutSettings.timeout
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`.
*
@ -95,7 +98,8 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension {
backlog,
options,
halfClose,
idleTimeout))
idleTimeout,
bindShutdownTimeout))
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`

View file

@ -200,6 +200,7 @@ object GraphStageLogic {
protected def sendTerminated(): Unit = {
val watchedBy = _watchedBy.getAndSet(StageTerminatedTombstone)
if (!(watchedBy == StageTerminatedTombstone) && !watchedBy.isEmpty) {
watchedBy foreach sendTerminated(ifLocal = false)
watchedBy foreach sendTerminated(ifLocal = true)
}