From 71740f3fcdac0d38bddfbfc2118762f78038fd72 Mon Sep 17 00:00:00 2001 From: Konrad Malawski Date: Mon, 30 Nov 2015 12:49:12 +0100 Subject: [PATCH] =str fix race condition in case of early termination of connections source --- .../test/scala/akka/stream/io/TcpSpec.scala | 20 ++++---- .../scala/akka/stream/impl/io/TcpStages.scala | 47 ++++++++++++++----- .../main/scala/akka/stream/scaladsl/Tcp.scala | 6 ++- .../scala/akka/stream/stage/GraphStage.scala | 1 + 4 files changed, 51 insertions(+), 23 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index 370f74656b..f94e35e9a5 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -22,7 +22,7 @@ import akka.testkit.EventFilter class TcpSpec extends AkkaSpec( """ |akka.io.tcp.windows-connection-abort-workaround-enabled=auto - |akka.stream.materializer.subscription-timeout.timeout = 3s""".stripMargin) with TcpHelper { + |akka.stream.materializer.subscription-timeout.timeout = 2s""".stripMargin) with TcpHelper { var demand = 0L "Outgoing TCP stream" must { @@ -374,11 +374,11 @@ class TcpSpec extends AkkaSpec( conn.flow.join(Flow[ByteString]).run() })(Keep.left).run(), 3.seconds) - val result = Source(immutable.Iterable.fill(10000)(ByteString(0))) + val result = Source(immutable.Iterable.fill(1000)(ByteString(0))) .via(Tcp().outgoingConnection(serverAddress, halfClose = true)) .runFold(0)(_ + _.size) - Await.result(result, 3.seconds) should ===(10000) + Await.result(result, 3.seconds) should ===(1000) binding.unbind() } @@ -498,7 +498,10 @@ class TcpSpec extends AkkaSpec( "not shut down connections after the connection stream cancelled" in assertAllStagesStopped { val address = temporaryServerAddress() - Tcp().bind(address.getHostName, address.getPort).take(1).runForeach(_.flow.join(Flow[ByteString]).run()) + Tcp().bind(address.getHostName, address.getPort).take(1).runForeach { tcp ⇒ + Thread.sleep(1000) // we're testing here to see if it survives such race + tcp.flow.join(Flow[ByteString]).run() + } val total = Source(immutable.Iterable.fill(1000)(ByteString(0))) .via(Tcp().outgoingConnection(address)) @@ -507,7 +510,7 @@ class TcpSpec extends AkkaSpec( Await.result(total, 3.seconds) should ===(1000) } - "shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped { + "xoxoxo shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped { val address = temporaryServerAddress() val firstClientConnected = Promise[Unit]() val takeTwoAndDropSecond = Flow[IncomingConnection].map(conn ⇒ { @@ -518,20 +521,19 @@ class TcpSpec extends AkkaSpec( .via(takeTwoAndDropSecond) .runForeach(_.flow.join(Flow[ByteString]).run()) - val folder = Source(immutable.Iterable.fill(1000)(ByteString(0))) + val folder = Source(immutable.Iterable.fill(100)(ByteString(0))) .via(Tcp().outgoingConnection(address)) .fold(0)(_ + _.size).toMat(Sink.head)(Keep.right) val total = folder.run() awaitAssert(firstClientConnected.future, 2.seconds) - val rejected = folder.run() - Await.result(total, 3.seconds) should ===(1000) + Await.result(total, 10.seconds) should ===(100) a[StreamTcpException] should be thrownBy { - Await.result(rejected, 5.seconds) should ===(1000) + Await.result(rejected, 5.seconds) should ===(100) } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index 30db86af67..2b091252ef 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -4,7 +4,7 @@ package akka.stream.impl.io import java.net.InetSocketAddress -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{ AtomicLong, AtomicBoolean } import akka.actor.{ ActorRef, Terminated } import akka.dispatch.ExecutionContexts @@ -32,13 +32,15 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, val backlog: Int, val options: immutable.Traversable[SocketOption], val halfClose: Boolean, - val idleTimeout: Duration) + val idleTimeout: Duration, + val bindShutdownTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[StreamTcp.IncomingConnection], Future[StreamTcp.ServerBinding]] { + import ConnectionSourceStage._ val out: Outlet[StreamTcp.IncomingConnection] = Outlet("IncomingConnections.out") val shape: SourceShape[StreamTcp.IncomingConnection] = SourceShape(out) - private val BindTimer = "BindTimer" + private val connectionFlowsAwaitingInitialization = new AtomicLong() // TODO: Timeout on bind override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[ServerBinding]) = { @@ -72,9 +74,10 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, case c: Connected ⇒ push(out, connectionFor(c, sender)) case Unbind ⇒ - if (!isClosed(out) && (listener ne null)) listener ! Unbind + if (!isClosed(out) && (listener ne null)) tryUnbind() case Unbound ⇒ // If we're unbound then just shut down - completeStage() + if (connectionFlowsAwaitingInitialization.get() == 0) completeStage() + else scheduleOnce(BindShutdownTimer, bindShutdownTimeout) case Terminated(ref) if ref == listener ⇒ failStage(new IllegalStateException("IO Listener actor terminated unexpectedly")) } @@ -89,10 +92,19 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, override def onDownstreamFinish(): Unit = tryUnbind() }) + // because when we tryUnbind, we must wait for the Ubound signal before terminating + override def keepGoingAfterAllPortsClosed = true + private def connectionFor(connected: Connected, connection: ActorRef): StreamTcp.IncomingConnection = { + connectionFlowsAwaitingInitialization.incrementAndGet() + val tcpFlow = Flow.fromGraph(new IncomingConnectionStage(connection, connected.remoteAddress, halfClose)) .via(new Detacher[ByteString]) // must read ahead for proper completions + .mapMaterializedValue { m ⇒ + connectionFlowsAwaitingInitialization.decrementAndGet() + m + } // FIXME: Previous code was wrong, must add new tests val handler = idleTimeout match { @@ -107,8 +119,15 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, } private def tryUnbind(): Unit = { - if (listener ne null) listener ! Unbind - else completeStage() + if (listener ne null) { + self.unwatch(listener) + listener ! Unbind + } + } + + override def onTimer(timerKey: Any): Unit = timerKey match { + case BindShutdownTimer ⇒ + completeStage() // TODO need to manually shut down instead right? } override def postStop(): Unit = { @@ -122,6 +141,11 @@ private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, } +private[stream] object ConnectionSourceStage { + val BindTimer = "BindTimer" + val BindShutdownTimer = "BindTimer" +} + /** * INTERNAL API */ @@ -194,11 +218,9 @@ private[stream] object TcpConnectionStage { val sender = evt._1 val msg = evt._2 msg match { - case Terminated(_) ⇒ - failStage(new StreamTcpException("The connection actor has terminated. Stopping now.")) + case Terminated(_) ⇒ failStage(new StreamTcpException("The connection actor has terminated. Stopping now.")) case CommandFailed(cmd) ⇒ failStage(new StreamTcpException(s"Tcp command [$cmd] failed")) - - case ErrorClosed(cause) ⇒ failStage(new StreamTcpException(s"The connection closed with error $cause")) + case ErrorClosed(cause) ⇒ failStage(new StreamTcpException(s"The connection closed with error: $cause")) case Aborted ⇒ failStage(new StreamTcpException("The connection has been aborted")) case Closed ⇒ completeStage() case ConfirmedClosed ⇒ completeStage() @@ -253,7 +275,7 @@ private[stream] object TcpConnectionStage { case Outbound(_, _, localAddressPromise, _) ⇒ // Fail if has not been completed with an address eariler localAddressPromise.tryFailure(new StreamTcpException("Connection failed.")) - case _ ⇒ + case _ ⇒ // do nothing... } } } @@ -299,7 +321,6 @@ private[stream] class OutgoingConnectionStage(manager: ActorRef, val shape: FlowShape[ByteString, ByteString] = FlowShape(bytesIn, bytesOut) override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[StreamTcp.OutgoingConnection]) = { - // FIXME: A method like this would make soo much sense on Duration (i.e. toOption) val connTimeout = connectTimeout match { case x: FiniteDuration ⇒ Some(x) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala index 37a1a2b681..0a182d4828 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala @@ -62,6 +62,9 @@ object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider { class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { import Tcp._ + // TODO maybe this should be a new setting, like `akka.stream.tcp.bind.timeout` / `shutdown-timeout` instead? + val bindShutdownTimeout = ActorMaterializer()(system).settings.subscriptionTimeoutSettings.timeout + /** * Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`. * @@ -95,7 +98,8 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { backlog, options, halfClose, - idleTimeout)) + idleTimeout, + bindShutdownTimeout)) /** * Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint` diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 57aab1a543..32f845c159 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -200,6 +200,7 @@ object GraphStageLogic { protected def sendTerminated(): Unit = { val watchedBy = _watchedBy.getAndSet(StageTerminatedTombstone) if (!(watchedBy == StageTerminatedTombstone) && !watchedBy.isEmpty) { + watchedBy foreach sendTerminated(ifLocal = false) watchedBy foreach sendTerminated(ifLocal = true) }