Merge pull request #18750 from spray/w/18628-errors-from-ws-handler

Fix #18628 and other smallish WS changes
This commit is contained in:
Konrad Malawski 2015-10-22 12:55:43 +02:00
commit 0f99a42df9
16 changed files with 262 additions and 217 deletions

View file

@ -150,7 +150,7 @@ private[http] object HttpServerBluePrint {
// protocol routing
val protocolRouter = b.add(WebsocketSwitchRouter)
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory))
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory, log))
protocolRouter.out0 ~> http ~> protocolMerge.in0
protocolRouter.out1 ~> websocket ~> protocolMerge.in1
@ -360,7 +360,7 @@ private[http] object HttpServerBluePrint {
}
}
private class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] Unit, websocketRandomFactory: () Random) extends GraphStage[FanInShape2[ResponseRenderingOutput, ByteString, ByteString]] {
private class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] Unit, websocketRandomFactory: () Random, log: LoggingAdapter) extends GraphStage[FanInShape2[ResponseRenderingOutput, ByteString, ByteString]] {
private val httpIn = Inlet[ResponseRenderingOutput]("httpIn")
private val wsIn = Inlet[ByteString]("wsIn")
private val out = Outlet[ByteString]("out")
@ -389,7 +389,7 @@ private[http] object HttpServerBluePrint {
val frameHandler = handlerFlow match {
case Left(frameHandler) frameHandler
case Right(messageHandler)
Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory).join(messageHandler)
Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory, log = log).join(messageHandler)
}
installHandler(frameHandler)
websocketHandlerWasInstalled = true

View file

@ -152,11 +152,16 @@ object FrameEventParser {
(ByteString(buffer), newMask)
}
def parseCloseCode(data: ByteString): Option[Int] =
def parseCloseCode(data: ByteString): Option[(Int, String)] = {
def invalid(reason: String) = Some((Protocol.CloseCodes.ProtocolError, s"Peer sent illegal close frame ($reason)."))
if (data.length >= 2) {
val code = ((data(0) & 0xff) << 8) | (data(1) & 0xff)
if (Protocol.CloseCodes.isValid(code)) Some(code)
else Some(Protocol.CloseCodes.ProtocolError)
} else if (data.length == 1) Some(Protocol.CloseCodes.ProtocolError) // must be >= length 2 if not empty
val message = Utf8Decoder.decode(data.drop(2))
if (!Protocol.CloseCodes.isValid(code)) invalid(s"invalid close code '$code'")
else if (message.isFailure) invalid("close reason message is invalid UTF8")
else Some((code, message.get))
} else if (data.length == 1) invalid("close code must be length 2 but was 1") // must be >= length 2 if not empty
else None
}
}

View file

@ -106,8 +106,7 @@ private[http] object FrameHandler {
ctx.pull()
case Opcode.Close
become(WaitForPeerTcpClose)
val closeCode = FrameEventParser.parseCloseCode(data)
ctx.push(PeerClosed(closeCode))
ctx.push(PeerClosed.parse(data))
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
case other ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame"))
}
@ -137,7 +136,7 @@ private[http] object FrameHandler {
elem match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
become(WaitForPeerTcpClose)
ctx.push(PeerClosed(FrameEventParser.parseCloseCode(data)))
ctx.push(PeerClosed.parse(data))
case _ ctx.pull() // ignore all other data
}
}
@ -194,6 +193,13 @@ private[http] object FrameHandler {
final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
def isMessageEnd: Boolean = true
}
object PeerClosed {
def parse(data: ByteString): PeerClosed =
FrameEventParser.parseCloseCode(data) match {
case Some((code, reason)) PeerClosed(Some(code), reason)
case None PeerClosed(None)
}
}
sealed trait BypassEvent extends Output
final case class DirectAnswer(frame: FrameStart) extends BypassEvent

View file

@ -4,6 +4,7 @@
package akka.http.impl.engine.ws
import akka.event.LoggingAdapter
import akka.stream.scaladsl.Flow
import scala.concurrent.duration.FiniteDuration
import akka.stream.stage._
@ -17,7 +18,7 @@ import akka.http.impl.engine.ws.FrameHandler.UserHandlerErredOut
*
* INTERNAL API
*/
private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[FrameOutHandler.Input, FrameStart] {
private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration, log: LoggingAdapter) extends StatefulStage[FrameOutHandler.Input, FrameStart] {
def initial: StageState[AnyRef, FrameStart] = Idle
def closeTimeout: Timestamp = Timestamp.now + _closeTimeout
@ -43,7 +44,8 @@ private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDu
case UserHandlerCompleted
become(new WaitingForPeerCloseFrame())
ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))
case UserHandlerErredOut(ex)
case UserHandlerErredOut(e)
log.error(e, s"Websocket handler failed with ${e.getMessage}")
become(new WaitingForPeerCloseFrame())
ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.UnexpectedCondition, "internal error"))
case Tick ctx.pull() // ignore
@ -60,16 +62,21 @@ private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDu
*/
private class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case UserHandlerCompleted
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose())
ctx.push(closeFrame)
}
case UserHandlerCompleted sendOutLastFrame(ctx)
case UserHandlerErredOut(e)
log.error(e, s"Websocket handler failed while waiting for handler completion with ${e.getMessage}")
sendOutLastFrame(ctx)
case start: FrameStart ctx.push(start)
case _ ctx.pull() // ignore
}
def sendOutLastFrame(ctx: Context[FrameStart]): SyncDirective =
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose())
ctx.push(closeFrame)
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
ctx.fail(new IllegalStateException("Mustn't complete before user has completed"))
}
@ -138,6 +145,6 @@ private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDu
private[http] object FrameOutHandler {
type Input = AnyRef
def create(serverSide: Boolean, closeTimeout: FiniteDuration): Flow[Input, FrameStart, Unit] =
Flow[Input].transform(() new FrameOutHandler(serverSide, closeTimeout))
def create(serverSide: Boolean, closeTimeout: FiniteDuration, log: LoggingAdapter): Flow[Input, FrameStart, Unit] =
Flow[Input].transform(() new FrameOutHandler(serverSide, closeTimeout, log))
}

View file

@ -6,6 +6,7 @@ package akka.http.impl.engine.ws
import java.util.Random
import akka.event.LoggingAdapter
import akka.util.ByteString
import scala.concurrent.duration._
@ -30,9 +31,10 @@ private[http] object Websocket {
*/
def stack(serverSide: Boolean,
maskingRandomFactory: () Random,
closeTimeout: FiniteDuration = 3.seconds): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] =
closeTimeout: FiniteDuration = 3.seconds,
log: LoggingAdapter): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] =
masking(serverSide, maskingRandomFactory) atop
frameHandling(serverSide, closeTimeout) atop
frameHandling(serverSide, closeTimeout, log) atop
messageAPI(serverSide, closeTimeout)
/** The lowest layer that implements the binary protocol */
@ -52,10 +54,11 @@ private[http] object Websocket {
* from frames, decoding text messages, close handling, etc.
*/
def frameHandling(serverSide: Boolean = true,
closeTimeout: FiniteDuration): BidiFlow[FrameEvent, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] =
closeTimeout: FiniteDuration,
log: LoggingAdapter): BidiFlow[FrameEvent, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] =
BidiFlow.wrap(
FrameHandler.create(server = serverSide),
FrameOutHandler.create(serverSide, closeTimeout))(Keep.none)
FrameOutHandler.create(serverSide, closeTimeout, log))(Keep.none)
.named("ws-frame-handling")
/**
@ -68,7 +71,7 @@ private[http] object Websocket {
var inMessage = false
def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match {
case PeerClosed(code, reason)
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code"))
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new PeerClosedConnectionException(code.get, reason))
else if (inMessage) ctx.fail(new ProtocolException(s"Truncated message, peer closed connection in the middle of message."))
else ctx.finish()
case ActivelyCloseWithCode(code, reason)

View file

@ -4,9 +4,8 @@
package akka.http.impl.engine.ws
import akka.http.scaladsl.model.ws.WebsocketRequest
import akka.http.scaladsl.model.ws._
import scala.collection.immutable
import scala.concurrent.{ Future, Promise }
import akka.util.ByteString
@ -19,8 +18,7 @@ import akka.stream.scaladsl._
import akka.http.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade, WebsocketUpgradeResponse }
import akka.http.scaladsl.model.{ HttpHeader, HttpResponse, HttpMethods, Uri }
import akka.http.scaladsl.model.{ HttpResponse, HttpMethods }
import akka.http.scaladsl.model.headers.Host
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
@ -39,7 +37,7 @@ object WebsocketClientBlueprint {
log: LoggingAdapter): Http.WebsocketClientLayer =
(simpleTls.atopMat(handshake(request, settings, log))(Keep.right) atop
Websocket.framing atop
Websocket.stack(serverSide = false, maskingRandomFactory = settings.websocketRandomFactory)).reversed
Websocket.stack(serverSide = false, maskingRandomFactory = settings.websocketRandomFactory, log = log)).reversed
/**
* A bidi flow that injects and inspects the WS handshake and then goes out of the way. This BidiFlow

View file

@ -628,7 +628,7 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension {
JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message]
.atopMat(wsLayer)((_, s) adaptWsUpgradeResponse(s)))
private def adaptWsFlow(wsLayer: stream.scaladsl.Flow[sm.ws.Message, sm.ws.Message, Future[scaladsl.Http.WebsocketUpgradeResponse]]): Flow[Message, Message, Future[WebsocketUpgradeResponse]] =
private def adaptWsFlow(wsLayer: stream.scaladsl.Flow[sm.ws.Message, sm.ws.Message, Future[scaladsl.model.ws.WebsocketUpgradeResponse]]): Flow[Message, Message, Future[WebsocketUpgradeResponse]] =
Flow.adapt(JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message].joinMat(wsLayer)(Keep.right).mapMaterializedValue(adaptWsUpgradeResponse _))
private def adaptWsFlow[Mat](javaFlow: Flow[Message, Message, Mat]): stream.scaladsl.Flow[scaladsl.model.ws.Message, scaladsl.model.ws.Message, Mat] =
@ -637,10 +637,10 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension {
.viaMat(javaFlow.asScala)(Keep.right)
.map(_.asScala)
private def adaptWsResultTuple[T](result: (Future[scaladsl.Http.WebsocketUpgradeResponse], T)): Pair[Future[WebsocketUpgradeResponse], T] =
private def adaptWsResultTuple[T](result: (Future[scaladsl.model.ws.WebsocketUpgradeResponse], T)): Pair[Future[WebsocketUpgradeResponse], T] =
result match {
case (fut, tMat) Pair(adaptWsUpgradeResponse(fut), tMat)
}
private def adaptWsUpgradeResponse(responseFuture: Future[scaladsl.Http.WebsocketUpgradeResponse]): Future[WebsocketUpgradeResponse] =
private def adaptWsUpgradeResponse(responseFuture: Future[scaladsl.model.ws.WebsocketUpgradeResponse]): Future[WebsocketUpgradeResponse] =
responseFuture.map(WebsocketUpgradeResponse.adapt)(system.dispatcher)
}

View file

@ -0,0 +1,14 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.javadsl.model.ws
/**
* A PeerClosedConnectionException will be reported to the Websocket handler if the peer has closed the connection.
* `closeCode` and `closeReason` contain close messages as reported by the peer.
*/
trait PeerClosedConnectionException extends RuntimeException {
def closeCode: Int
def closeReason: String
}

View file

@ -6,7 +6,7 @@ package akka.http.javadsl.model.ws
import akka.http.javadsl.model.HttpResponse
import akka.http.scaladsl
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade }
import akka.http.scaladsl.model.ws.{ InvalidUpgradeResponse, ValidUpgrade }
import akka.japi.Option
/**
@ -36,12 +36,12 @@ trait WebsocketUpgradeResponse {
object WebsocketUpgradeResponse {
import akka.http.impl.util.JavaMapping.Implicits._
def adapt(scalaResponse: scaladsl.Http.WebsocketUpgradeResponse): WebsocketUpgradeResponse =
def adapt(scalaResponse: scaladsl.model.ws.WebsocketUpgradeResponse): WebsocketUpgradeResponse =
scalaResponse match {
case ValidUpgrade(response, chosen)
new WebsocketUpgradeResponse {
def isValid: Boolean = true
def response: HttpResponse = response
def response: HttpResponse = scalaResponse.response
def chosenSubprotocol: Option[String] = chosen.asJava
def invalidationReason: String =
throw new UnsupportedOperationException("invalidationReason must not be called for valid response")
@ -49,7 +49,7 @@ object WebsocketUpgradeResponse {
case InvalidUpgradeResponse(response, cause)
new WebsocketUpgradeResponse {
def isValid: Boolean = false
def response: HttpResponse = response
def response: HttpResponse = scalaResponse.response
def chosenSubprotocol: Option[String] = throw new UnsupportedOperationException("chosenSubprotocol must not be called for valid response")
def invalidationReason: String = cause
}

View file

@ -19,7 +19,7 @@ import akka.http.impl.util.{ ReadTheDocumentationException, Java6Compat, StreamU
import akka.http.impl.engine.ws.WebsocketClientBlueprint
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.ws.{ WebsocketRequest, Message }
import akka.http.scaladsl.model.ws.{ WebsocketUpgradeResponse, WebsocketRequest, Message }
import akka.http.scaladsl.util.FastFuture
import akka.japi
import akka.stream.Materializer
@ -689,15 +689,6 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
*/
final case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress)
/**
* Represents the response to a websocket upgrade request. Can either be [[ValidUpgrade]] or [[InvalidUpgradeResponse]].
*/
sealed trait WebsocketUpgradeResponse {
def response: HttpResponse
}
final case class ValidUpgrade(response: HttpResponse, chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse
final case class InvalidUpgradeResponse(response: HttpResponse, cause: String) extends WebsocketUpgradeResponse
/**
* Represents a connection pool to a specific target host and pool configuration.
*/

View file

@ -0,0 +1,14 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.scaladsl.model.ws
import akka.http.javadsl
/**
* A PeerClosedConnectionException will be reported to the Websocket handler if the peer has closed the connection.
* `closeCode` and `closeReason` contain close messages as reported by the peer.
*/
class PeerClosedConnectionException(val closeCode: Int, val closeReason: String)
extends RuntimeException(s"Peer closed connection with code $closeCode '$closeReason'") with javadsl.model.ws.PeerClosedConnectionException

View file

@ -0,0 +1,16 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.scaladsl.model.ws
import akka.http.scaladsl.model.HttpResponse
/**
* Represents the response to a websocket upgrade request. Can either be [[ValidUpgrade]] or [[InvalidUpgradeResponse]].
*/
sealed trait WebsocketUpgradeResponse {
def response: HttpResponse
}
final case class ValidUpgrade(response: HttpResponse, chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse
final case class InvalidUpgradeResponse(response: HttpResponse, cause: String) extends WebsocketUpgradeResponse

View file

@ -26,6 +26,8 @@ trait ByteStringSinkProbe {
def expectComplete(): Unit
def expectError(): Throwable
def expectError(cause: Throwable): Unit
def request(n: Long): Unit
}
object ByteStringSinkProbe {
@ -63,5 +65,7 @@ object ByteStringSinkProbe {
def expectComplete(): Unit = probe.expectComplete()
def expectError(): Throwable = probe.expectError()
def expectError(cause: Throwable): Unit = probe.expectError(cause)
def request(n: Long): Unit = probe.request(n)
}
}

View file

@ -4,7 +4,6 @@
package akka.http.impl.engine.ws
import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.util.Random
import org.scalatest.{ Matchers, FreeSpec }
@ -18,6 +17,11 @@ import Protocol.Opcode
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
import WSTestUtils._
val InvalidUtf8TwoByteSequence: ByteString = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
"The Websocket implementation should" - {
"collect messages from frames" - {
"for binary messages" - {
@ -228,7 +232,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"for a short strict message" in new ServerTestSetup {
val data = ByteString("abcdef", "ASCII")
val msg = BinaryMessage.Strict(data)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, data, fin = true)
@ -238,7 +241,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val data = ByteString("abcdefg", "ASCII")
val pub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
@ -261,7 +263,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val data = ByteString("abcdefg", "ASCII")
val pub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(pub))
netOutSub.request(7)
pushMessage(msg)
val sub = pub.expectSubscription()
@ -288,7 +289,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"for a short strict message" in new ServerTestSetup {
val text = "äbcdef"
val msg = TextMessage.Strict(text)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true)
@ -298,7 +298,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val text = "äbcd€fg"
val pub = TestPublisher.manualProbe[String]()
val msg = TextMessage(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
@ -328,7 +327,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val pub = TestPublisher.manualProbe[String]()
val msg = TextMessage(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
@ -345,7 +343,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val text = "abcdefg"
val pub = TestPublisher.manualProbe[String]()
val msg = TextMessage(Source(pub))
netOutSub.request(5)
pushMessage(msg)
val sub = pub.expectSubscription()
@ -377,14 +374,12 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
netOutSub.request(5)
expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames masking them on the client side" in new ClientTestSetup {
val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef")
pushInput(input)
netOutSub.request(5)
expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup {
@ -404,7 +399,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
val outPub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(outPub))
netOutSub.request(10)
pushMessage(msg)
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
@ -439,44 +433,35 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"don't respond to unsolicited pong frames" in new ClientTestSetup {
val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef")
pushInput(data)
netOutSub.request(5)
expectNoNetworkData()
}
}
"provide close behavior" - {
"after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup {
netInSub.expectRequest()
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
expectComplete(messageIn)
netIn.expectNoMsg(100.millis) // especially the cancellation not yet
expectNoNetworkData()
messageOutSub.sendComplete()
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"after receiving close frame without close code" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 0, fin = true, mask = Some(Random.nextInt())))
messageIn.expectComplete()
expectComplete(messageIn)
messageOutSub.sendComplete()
messageOut.sendComplete()
// especially mustn't be Procotol.CloseCodes.NoCodePresent
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup {
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
expectComplete(messageIn)
// sending another message is allowed before closing (inherently racy)
val pub = TestPublisher.manualProbe[ByteString]()
@ -492,82 +477,97 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOutSub.sendComplete()
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"after receiving regular close frame when fragmented message is still open" in {
new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
"after receiving regular close frame when fragmented message is still open" in new ServerTestSetup {
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false, mask = Some(Random.nextInt())))
val dataSource = expectBinaryMessage().dataStream
val inSubscriber = TestSubscriber.manualProbe[ByteString]()
dataSource.runWith(Sink(inSubscriber))
val inSub = inSubscriber.expectSubscription()
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false, mask = Some(Random.nextInt())))
val dataSource = expectBinaryMessage().dataStream
val inSubscriber = TestSubscriber.manualProbe[ByteString]()
dataSource.runWith(Sink(inSubscriber))
val inSub = inSubscriber.expectSubscription()
val outData = ByteString("def", "ASCII")
val mask = Random.nextInt()
pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1)
inSub.request(5)
inSubscriber.expectNext(outData)
val outData = ByteString("def", "ASCII")
val mask = Random.nextInt()
pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1)
inSub.request(5)
inSubscriber.expectNext(outData)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
// This is arguable: we could also just fail the subStream but complete the main message stream regularly.
// However, truncating an ongoing message by closing without sending a `Continuation(fin = true)` first
// could be seen as something being amiss.
expectError(messageIn)
inSubscriber.expectError()
// truncation of open message
// This is arguable: we could also just fail the subStream but complete the main message stream regularly.
// However, truncating an ongoing message by closing without sending a `Continuation(fin = true)` first
// could be seen as something being amiss.
messageIn.expectError()
inSubscriber.expectError()
// truncation of open message
// sending another message is allowed before closing (inherently racy)
// sending another message is allowed before closing (inherently racy)
val pub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val pub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
}
"after receiving error close frame" in pending
"after peer closes connection without sending a close frame" in new ServerTestSetup {
netInSub.expectRequest()
netInSub.sendComplete()
"after receiving error close frame with close code and without reason" in new ServerTestSetup {
pushInput(closeFrame(Protocol.CloseCodes.UnexpectedCondition, mask = true))
val error = expectError(messageIn).asInstanceOf[PeerClosedConnectionException]
error.closeCode shouldEqual Protocol.CloseCodes.UnexpectedCondition
error.closeReason shouldEqual ""
messageIn.expectComplete()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition)
messageOut.sendError(error)
netOut.expectComplete()
netIn.expectCancellation()
}
"after receiving error close frame with close code and with reason" in new ServerTestSetup {
pushInput(closeFrame(Protocol.CloseCodes.UnexpectedCondition, mask = true,
msg = "This alien landing came quite unexpected. Communication has been garbled."))
val error = expectError(messageIn).asInstanceOf[PeerClosedConnectionException]
error.closeCode shouldEqual Protocol.CloseCodes.UnexpectedCondition
error.closeReason shouldEqual "This alien landing came quite unexpected. Communication has been garbled."
expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition)
messageOut.sendError(error)
netOut.expectComplete()
netIn.expectCancellation()
}
"after peer closes connection without sending a close frame" in new ServerTestSetup {
netIn.expectRequest()
netIn.sendComplete()
expectComplete(messageIn)
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"when user handler closes (simple)" in new ServerTestSetup {
messageOutSub.sendComplete()
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(100.millis) // wait for peer to close regularly
expectNoNetworkData() // wait for peer to close regularly
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
expectComplete(messageIn)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"when user handler closes main stream and substream only afterwards" in new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
// send half a message
val pub = TestPublisher.manualProbe[ByteString]()
val msg = BinaryMessage(Source(pub))
@ -579,79 +579,97 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
messageOutSub.sendComplete()
messageOut.sendComplete()
expectNoNetworkData() // need to wait for substream to close
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(100.millis) // wait for peer to close regularly
expectNoNetworkData() // wait for peer to close regularly
val mask = Random.nextInt()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
expectComplete(messageIn)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"if user handler fails" in new ServerTestSetup {
messageOut.sendError(new RuntimeException("Oops, user handler failed!"))
expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition)
expectNoNetworkData() // wait for peer to close regularly
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
expectComplete(messageIn)
netOut.expectComplete()
netIn.expectCancellation()
}
"if user handler fails" in pending
"if peer closes with invalid close frame" - {
"close code outside of the valid range" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
pushInput(closeFrame(5700, mask = true))
val error = messageIn.expectError()
val error = expectError(messageIn).asInstanceOf[PeerClosedConnectionException]
error.closeCode shouldEqual Protocol.CloseCodes.ProtocolError
error.closeReason shouldEqual "Peer sent illegal close frame (invalid close code '5700')."
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"close data of size 1" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
val error = messageIn.expectError()
val error = expectError(messageIn).asInstanceOf[PeerClosedConnectionException]
error.closeCode shouldEqual Protocol.CloseCodes.ProtocolError
error.closeReason shouldEqual "Peer sent illegal close frame (close code must be length 2 but was 1)."
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"close message is invalid UTF8" in new ServerTestSetup {
pushInput(closeFrame(Protocol.CloseCodes.UnexpectedCondition, mask = true, msgBytes = InvalidUtf8TwoByteSequence))
val error = expectError(messageIn).asInstanceOf[PeerClosedConnectionException]
error.closeCode shouldEqual Protocol.CloseCodes.ProtocolError
error.closeReason shouldEqual "Peer sent illegal close frame (close reason message is invalid UTF8)."
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netIn.expectCancellation()
}
"reason is no valid utf8 data" in pending
}
"timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup {
override protected def closeTimeout: FiniteDuration = 100.millis
netInSub.expectRequest()
messageOutSub.sendComplete()
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup {
override protected def closeTimeout: FiniteDuration = 100.millis
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
expectProtocolErrorOnNetwork()
messageOutSub.sendComplete()
messageOut.sendComplete()
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
"ignore frames peer sends after close frame" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
expectComplete(messageIn)
pushInput(frameHeader(Opcode.Binary, 0, fin = true))
messageOutSub.sendComplete()
messageOut.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
netIn.expectCancellation()
}
}
"reject unexpected frames" - {
@ -712,19 +730,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
expectProtocolErrorOnNetwork()
}
"invalid utf8 encoding for single frame message" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
val data = InvalidUtf8TwoByteSequence
pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"invalid utf8 encoding for streamed frame" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
val data = InvalidUtf8TwoByteSequence
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
frameHeader(Opcode.Continuation, 2, fin = true) ++
@ -753,7 +765,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
pushInput(frameHeader(Opcode.Text, 0, fin = false))
pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data)
messageIn.expectError()
expectError(messageIn)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
@ -798,11 +810,11 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
protected def serverSide: Boolean
protected def closeTimeout: FiniteDuration = 1.second
val netIn = TestPublisher.manualProbe[ByteString]()
val netOut = TestSubscriber.manualProbe[ByteString]()
val netIn = TestPublisher.probe[ByteString]()
val netOut = ByteStringSinkProbe()
val messageIn = TestSubscriber.manualProbe[Message]
val messageOut = TestPublisher.manualProbe[Message]()
val messageIn = TestSubscriber.probe[Message]
val messageOut = TestPublisher.probe[Message]()
val messageHandler: Flow[Message, Message, Unit] =
Flow.wrap {
@ -817,60 +829,24 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
Source(netIn)
.via(printEvent("netIn"))
.transform(() new FrameEventParser)
.via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout).join(messageHandler))
.via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout, log = system.log).join(messageHandler))
.via(printEvent("frameRendererIn"))
.transform(() new FrameEventRenderer)
.via(printEvent("frameRendererOut"))
.to(Sink(netOut))
.to(netOut.sink)
.run()
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val messageOutSub = messageOut.expectSubscription()
val messageInSub = messageIn.expectSubscription()
def pushInput(data: ByteString): Unit = netIn.sendNext(data)
def pushMessage(msg: Message): Unit = messageOut.sendNext(msg)
def expectMessage(message: Message): Unit = messageIn.requestNext(message)
def expectMessage(): Message = messageIn.requestNext()
def expectBinaryMessage(): BinaryMessage = expectMessage().asInstanceOf[BinaryMessage]
def expectBinaryMessage(message: BinaryMessage): Unit = expectBinaryMessage() shouldEqual message
def expectTextMessage(): TextMessage = expectMessage().asInstanceOf[TextMessage]
def expectTextMessage(message: TextMessage): Unit = expectTextMessage() shouldEqual message
final def expectNetworkData(bytes: Int): ByteString = netOut.expectBytes(bytes)
def pushInput(data: ByteString): Unit = {
// TODO: expect/handle request?
netInSub.sendNext(data)
}
def pushMessage(msg: Message): Unit = {
messageOutSub.sendNext(msg)
}
def expectMessage(message: Message): Unit = {
messageInSub.request(1)
messageIn.expectNext(message)
}
def expectMessage(): Message = {
messageInSub.request(1)
messageIn.expectNext()
}
def expectBinaryMessage(): BinaryMessage =
expectMessage().asInstanceOf[BinaryMessage]
def expectBinaryMessage(message: BinaryMessage): Unit =
expectBinaryMessage() shouldEqual message
def expectTextMessage(): TextMessage =
expectMessage().asInstanceOf[TextMessage]
def expectTextMessage(message: TextMessage): Unit =
expectTextMessage() shouldEqual message
var inBuffer = ByteString.empty
@tailrec final def expectNetworkData(bytes: Int): ByteString =
if (inBuffer.size >= bytes) {
val res = inBuffer.take(bytes)
inBuffer = inBuffer.drop(bytes)
res
} else {
netOutSub.request(1)
inBuffer ++= netOut.expectNext()
expectNetworkData(bytes)
}
def expectNetworkData(data: ByteString): Unit =
expectNetworkData(data.size) shouldEqual data
def expectNetworkData(data: ByteString): Unit = expectNetworkData(data.size) shouldEqual data
def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
expectFrameHeaderOnNetwork(opcode, data.size, fin)
@ -944,8 +920,16 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
code shouldEqual expectedCode
}
def expectNoNetworkData(): Unit =
netOut.expectNoMsg(100.millis)
def expectNoNetworkData(): Unit = netOut.expectNoBytes(100.millis)
def expectComplete[T](probe: TestSubscriber.Probe[T]): Unit = {
probe.ensureSubscription()
probe.expectComplete()
}
def expectError[T](probe: TestSubscriber.Probe[T]): Throwable = {
probe.ensureSubscription()
probe.expectError()
}
}
val trace = false // set to `true` for debugging purposes

View file

@ -51,11 +51,14 @@ object WSTestUtils {
} else
frameHeader(opcode, data.size, fin, mask = None) ++ data
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
frame(Opcode.Close, closeFrameData(closeCode), fin = true, mask)
def closeFrame(closeCode: Int, mask: Boolean, msg: String = ""): ByteString =
closeFrame(closeCode, mask, ByteString(msg, "UTF-8"))
def closeFrameData(closeCode: Int): ByteString =
shortBE(closeCode)
def closeFrame(closeCode: Int, mask: Boolean, msgBytes: ByteString): ByteString =
frame(Opcode.Close, closeFrameData(closeCode, msgBytes), fin = true, mask)
def closeFrameData(closeCode: Int, msgBytes: ByteString = ByteString.empty): ByteString =
shortBE(closeCode) ++ msgBytes
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)

View file

@ -6,7 +6,7 @@ package akka.http.impl.engine.ws
import java.util.Random
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, WebsocketUpgradeResponse }
import akka.http.scaladsl.model.ws.{ InvalidUpgradeResponse, WebsocketUpgradeResponse }
import scala.concurrent.duration._
@ -14,7 +14,7 @@ import akka.http.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.{ ProductVersion, `User-Agent` }
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl.model.{ HttpResponse, Uri }
import akka.http.scaladsl.model.Uri
import akka.stream.io._
import akka.stream.scaladsl._
import akka.stream.testkit.{ TestSubscriber, TestPublisher }