diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala index ed29f99736..e1951d8323 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala @@ -10,7 +10,7 @@ import scala.concurrent.duration.FiniteDuration import akka.stream.stage._ import akka.http.impl.util.Timestamp -import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer } +import akka.http.impl.engine.ws.FrameHandler._ import Websocket.Tick /** diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala new file mode 100644 index 0000000000..65c5548f13 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import akka.http.impl.engine.ws.Protocol.Opcode +import akka.http.impl.engine.ws.WSTestUtils._ +import akka.util.ByteString +import org.scalatest.Matchers + +import scala.annotation.tailrec +import scala.util.Random + +trait WSTestSetupBase extends Matchers { + def send(bytes: ByteString): Unit + def expectNextChunk(): ByteString + + def sendWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Boolean = false, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = { + val (theMask, theData) = + if (mask) { + val m = Random.nextInt() + (Some(m), maskedBytes(data, m)._1) + } else (None, data) + send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData) + } + + def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + send(closeFrame(closeCode, mask)) + + def expectWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = + expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data + + def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + expectNextChunk() shouldEqual closeFrame(closeCode, mask) + + var inBuffer = ByteString.empty + @tailrec final def expectNetworkData(bytes: Int): ByteString = + if (inBuffer.size >= bytes) { + val res = inBuffer.take(bytes) + inBuffer = inBuffer.drop(bytes) + res + } else { + inBuffer ++= expectNextChunk() + expectNetworkData(bytes) + } + + def expectNetworkData(data: ByteString): Unit = + expectNetworkData(data.size) shouldEqual data + + def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + expectFrameHeaderOnNetwork(opcode, data.size, fin) + expectNetworkData(data) + } + def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin) + val masked = maskedBytes(data, mask)._1 + expectNetworkData(masked) + } + + def expectMaskedCloseFrame(closeCode: Int): Unit = + expectMaskedFrameOnNetwork(Protocol.Opcode.Close, closeFrameData(closeCode), fin = true) + + /** Returns the mask if any is available */ + def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = { + val (op, l, f, m) = expectFrameHeaderOnNetwork() + op shouldEqual opcode + l shouldEqual length + f shouldEqual fin + m + } + def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = { + val header = expectNetworkData(2) + + val fin = (header(0) & Protocol.FIN_MASK) != 0 + val op = header(0) & Protocol.OP_MASK + + val hasMask = (header(1) & Protocol.MASK_MASK) != 0 + val length7 = header(1) & Protocol.LENGTH_MASK + val length = length7 match { + case 126 ⇒ + val length16Bytes = expectNetworkData(2) + (length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0 + case 127 ⇒ + val length64Bytes = expectNetworkData(8) + (length64Bytes(0) & 0xff).toLong << 56 | + (length64Bytes(1) & 0xff).toLong << 48 | + (length64Bytes(2) & 0xff).toLong << 40 | + (length64Bytes(3) & 0xff).toLong << 32 | + (length64Bytes(4) & 0xff).toLong << 24 | + (length64Bytes(5) & 0xff).toLong << 16 | + (length64Bytes(6) & 0xff).toLong << 8 | + (length64Bytes(7) & 0xff).toLong << 0 + case x ⇒ x + } + val mask = + if (hasMask) { + val maskBytes = expectNetworkData(4) + val mask = + (maskBytes(0) & 0xff) << 24 | + (maskBytes(1) & 0xff) << 16 | + (maskBytes(2) & 0xff) << 8 | + (maskBytes(3) & 0xff) << 0 + Some(mask) + } else None + + (Opcode.forCode(op.toByte), length, fin, mask) + } +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala index 47f9556f2a..1821beae2b 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala @@ -43,14 +43,19 @@ object WSTestUtils { val lengthByte = lengthByteComponent | maskMask ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes } - def closeFrame(closeCode: Int, mask: Boolean): ByteString = + def frame(opcode: Opcode, data: ByteString, fin: Boolean, mask: Boolean): ByteString = if (mask) { val mask = Random.nextInt() - frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ - maskedBytes(shortBE(closeCode), mask)._1 + frameHeader(opcode, data.size, fin, mask = Some(mask)) ++ + maskedBytes(data, mask)._1 } else - frameHeader(Opcode.Close, 2, fin = true) ++ - shortBE(closeCode) + frameHeader(opcode, data.size, fin, mask = None) ++ data + + def closeFrame(closeCode: Int, mask: Boolean): ByteString = + frame(Opcode.Close, closeFrameData(closeCode), fin = true, mask) + + def closeFrameData(closeCode: Int): ByteString = + shortBE(closeCode) def maskedASCII(str: String, mask: Int): (ByteString, Int) = FrameEventParser.mask(ByteString(str, "ASCII"), mask) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala index 5a7a507099..fd5e297970 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -55,13 +55,9 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp | |""".stripMarginWithNewline("\r\n") - expectWSFrame(Protocol.Opcode.Text, - - ByteString("Message 1"), fin = true) - expectWSFrame(Protocol. - Opcode.Text, ByteString("Message 2"), fin = true) - expectWSFrame( - Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) expectWSCloseFrame(Protocol.CloseCodes.Regular) @@ -131,42 +127,13 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp } } - class TestSetup extends HttpServerTestSetupBase { + class TestSetup extends HttpServerTestSetupBase with WSTestSetupBase { implicit def system = spec.system implicit def materializer = spec.materializer - def sendWSFrame(opcode: Opcode, - data: ByteString, - fin: Boolean, - mask: Boolean = false, - rsv1: Boolean = false, - rsv2: Boolean = false, - rsv3: Boolean = false): Unit = { - val (theMask, theData) = - if (mask) { - val m = Random.nextInt() - (Some(m), maskedBytes(data, m)._1) - } else (None, data) - send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData) - } - - def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = - send(closeFrame(closeCode, mask)) - def expectNextChunk(): ByteString = { netOutSub.request(1) netOut.expectNext() } - - def expectWSFrame(opcode: Opcode, - data: ByteString, - fin: Boolean, - mask: Option[Int] = None, - rsv1: Boolean = false, - rsv2: Boolean = false, - rsv3: Boolean = false): Unit = - expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data - def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = - expectNextChunk() shouldEqual closeFrame(closeCode, mask) } }