From 2cf1c41eefe5f5899a1d34b1a3f1155eb7e76135 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Wed, 18 Feb 2015 17:00:33 +0100 Subject: [PATCH 01/13] +htc #16887 implement Websocket header parsing/rendering --- .../model/headers/WebsocketExtension.scala | 24 +++++ .../akka/http/model/headers/headers.scala | 100 ++++++++++++++++++ .../akka/http/model/parser/HeaderParser.scala | 8 +- .../http/model/parser/WebsocketHeaders.scala | 61 +++++++++++ .../main/scala/akka/http/util/Rendering.scala | 3 + .../http/model/parser/HttpHeaderSpec.scala | 38 +++++++ 6 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 akka-http-core/src/main/scala/akka/http/model/headers/WebsocketExtension.scala create mode 100644 akka-http-core/src/main/scala/akka/http/model/parser/WebsocketHeaders.scala diff --git a/akka-http-core/src/main/scala/akka/http/model/headers/WebsocketExtension.scala b/akka-http-core/src/main/scala/akka/http/model/headers/WebsocketExtension.scala new file mode 100644 index 0000000000..5fcdf6bd8e --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/model/headers/WebsocketExtension.scala @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.model.headers + +import akka.http.util.{ Rendering, ValueRenderable } + +import scala.collection.immutable + +/** + * A websocket extension as defined in http://tools.ietf.org/html/rfc6455#section-4.3 + */ +final case class WebsocketExtension(name: String, params: immutable.Map[String, String] = Map.empty) extends ValueRenderable { + def render[R <: Rendering](r: R): r.type = { + r ~~ name + if (params.nonEmpty) + params.foreach { + case (k, "") ⇒ r ~~ "; " ~~ k + case (k, v) ⇒ r ~~ "; " ~~ k ~~ '=' ~~# v + } + r + } +} diff --git a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala index 383489bf44..0f0e29eade 100644 --- a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala @@ -7,7 +7,10 @@ package headers import java.lang.Iterable import java.net.InetSocketAddress +import java.security.MessageDigest import java.util +import akka.parboiled2.util.Base64 + import scala.annotation.tailrec import scala.collection.immutable import akka.http.util._ @@ -563,6 +566,103 @@ final case class Referer(uri: Uri) extends japi.headers.Referer with ModeledHead def getUri: akka.http.model.japi.Uri = uri.asJava } +/** + * INTERNAL API + */ +// http://tools.ietf.org/html/rfc6455#section-4.3 +private[http] object `Sec-WebSocket-Accept` extends ModeledCompanion { + // Defined at http://tools.ietf.org/html/rfc6455#section-4.2.2 + val MagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + /** Generates the matching accept header for this key */ + def forKey(key: `Sec-WebSocket-Key`): `Sec-WebSocket-Accept` = { + val sha1 = MessageDigest.getInstance("sha1") + val salted = key.key + MagicGuid + val hash = sha1.digest(salted.asciiBytes) + val acceptKey = Base64.rfc2045().encodeToString(hash, false) + `Sec-WebSocket-Accept`(acceptKey) + } +} +/** + * INTERNAL API + */ +private[http] final case class `Sec-WebSocket-Accept`(key: String) extends ModeledHeader { + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key + + protected def companion = `Sec-WebSocket-Accept` +} + +/** + * INTERNAL API + */ +// http://tools.ietf.org/html/rfc6455#section-4.3 +private[http] object `Sec-WebSocket-Extensions` extends ModeledCompanion { + implicit val extensionsRenderer = Renderer.defaultSeqRenderer[WebsocketExtension] +} +/** + * INTERNAL API + */ +private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable.Seq[WebsocketExtension]) extends ModeledHeader { + require(extensions.nonEmpty, "Sec-WebSocket-Extensions.extensions must not be empty") + import `Sec-WebSocket-Extensions`.extensionsRenderer + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ extensions + + protected def companion = `Sec-WebSocket-Extensions` +} + +// http://tools.ietf.org/html/rfc6455#section-4.3 +/** + * INTERNAL API + */ +private[http] object `Sec-WebSocket-Key` extends ModeledCompanion +/** + * INTERNAL API + */ +private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledHeader { + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key + + protected def companion = `Sec-WebSocket-Key` +} + +// http://tools.ietf.org/html/rfc6455#section-4.3 +/** + * INTERNAL API + */ +private[http] object `Sec-WebSocket-Protocol` extends ModeledCompanion { + implicit val protocolsRenderer = Renderer.defaultSeqRenderer[String] +} +/** + * INTERNAL API + */ +private[http] final case class `Sec-WebSocket-Protocol`(protocols: immutable.Seq[String]) extends ModeledHeader { + require(protocols.nonEmpty, "Sec-WebSocket-Protocol.protocols must not be empty") + import `Sec-WebSocket-Protocol`.protocolsRenderer + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ protocols + + protected def companion = `Sec-WebSocket-Protocol` +} + +// http://tools.ietf.org/html/rfc6455#section-4.3 +/** + * INTERNAL API + */ +private[http] object `Sec-WebSocket-Version` extends ModeledCompanion { + implicit val versionsRenderer = Renderer.defaultSeqRenderer[Int] +} +/** + * INTERNAL API + */ +private[http] final case class `Sec-WebSocket-Version`(versions: immutable.Seq[Int]) extends ModeledHeader { + require(versions.nonEmpty, "Sec-WebSocket-Version.versions must not be empty") + require(versions.forall(v ⇒ v >= 0 && v <= 255), s"Sec-WebSocket-Version.versions must be in the range 0 <= version <= 255 but were $versions") + import `Sec-WebSocket-Version`.versionsRenderer + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ versions + + def hasVersion(versionNumber: Int): Boolean = versions.exists(_ == versionNumber) + + protected def companion = `Sec-WebSocket-Version` +} + // http://tools.ietf.org/html/rfc7231#section-7.4.2 object Server extends ModeledCompanion { def apply(products: String): Server = apply(ProductVersion.parseMultiple(products)) diff --git a/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala b/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala index 1c1f822aea..f7d8407cec 100644 --- a/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala +++ b/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala @@ -26,7 +26,8 @@ private[http] class HeaderParser(val input: ParserInput) extends Parser with Dyn with IpAddressParsing with LinkHeader with SimpleHeaders - with StringBuilding { + with StringBuilding + with WebsocketHeaders { import CharacterClasses._ // http://www.rfc-editor.org/errata_search.php?rfc=7230 errata id 4189 @@ -111,6 +112,11 @@ private[http] object HeaderParser { "range", "referer", "server", + "sec-websocket-accept", + "sec-websocket-extensions", + "sec-websocket-key", + "sec-websocket-protocol", + "sec-websocket-version", "set-cookie", "transfer-encoding", "user-agent", diff --git a/akka-http-core/src/main/scala/akka/http/model/parser/WebsocketHeaders.scala b/akka-http-core/src/main/scala/akka/http/model/parser/WebsocketHeaders.scala new file mode 100644 index 0000000000..3df0bee907 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/model/parser/WebsocketHeaders.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.model.parser + +import akka.http.model.headers._ +import akka.parboiled2._ + +// see grammar at http://tools.ietf.org/html/rfc6455#section-4.3 +private[parser] trait WebsocketHeaders { this: Parser with CommonRules with CommonActions ⇒ + import CharacterClasses._ + import Base64Parsing.rfc2045Alphabet + + def `sec-websocket-accept` = rule { + `base64-value-non-empty` ~ EOI ~> (`Sec-WebSocket-Accept`(_)) + } + + def `sec-websocket-extensions` = rule { + oneOrMore(extension).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Extensions`(_)) + } + + def `sec-websocket-key` = rule { + `base64-value-non-empty` ~ EOI ~> (`Sec-WebSocket-Key`(_)) + } + + def `sec-websocket-protocol` = rule { + oneOrMore(token).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Protocol`(_)) + } + + def `sec-websocket-version` = rule { + oneOrMore(version).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Version`(_)) + } + + private def `base64-value-non-empty` = rule { + capture(oneOrMore(`base64-data`) ~ optional(`base64-padding`) | `base64-padding`) + } + private def `base64-data` = rule { 4.times(`base64-character`) } + private def `base64-padding` = rule { + 2.times(`base64-character`) ~ "==" | + 3.times(`base64-character`) ~ "=" + } + private def `base64-character` = rfc2045Alphabet + + private def extension = rule { + `extension-token` ~ zeroOrMore(ws(";") ~ `extension-param`) ~> + ((name, params) ⇒ WebsocketExtension(name, Map(params: _*))) + } + private def `extension-token`: Rule1[String] = token + private def `extension-param`: Rule1[(String, String)] = + rule { + token ~ optional(ws("=") ~ word) ~> ((name: String, value: Option[String]) ⇒ (name, value.getOrElse(""))) + } + + private def version = rule { + capture( + NZDIGIT ~ optional(DIGIT ~ optional(DIGIT)) | + DIGIT) ~> (_.toInt) + } + private def NZDIGIT = DIGIT19 +} diff --git a/akka-http-core/src/main/scala/akka/http/util/Rendering.scala b/akka-http-core/src/main/scala/akka/http/util/Rendering.scala index 9e1333e78e..39eba4ab39 100644 --- a/akka-http-core/src/main/scala/akka/http/util/Rendering.scala +++ b/akka-http-core/src/main/scala/akka/http/util/Rendering.scala @@ -84,6 +84,9 @@ private[http] object Renderer { implicit object CharRenderer extends Renderer[Char] { def render[R <: Rendering](r: R, value: Char): r.type = r ~~ value } + implicit object IntRenderer extends Renderer[Int] { + def render[R <: Rendering](r: R, value: Int): r.type = r ~~ value + } implicit object StringRenderer extends Renderer[String] { def render[R <: Rendering](r: R, value: String): r.type = r ~~ value } diff --git a/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala b/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala index 28b4655560..399a4e8c20 100644 --- a/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala @@ -369,6 +369,44 @@ class HttpHeaderSpec extends FreeSpec with Matchers { "Range: bytes=0-1, 2-3, -99" =!= Range(ByteRange(0, 1), ByteRange(2, 3), ByteRange.suffix(99)) } + "Sec-WebSocket-Accept" in { + "Sec-WebSocket-Accept: ZGgwOTM0Z2owcmViamRvcGcK" =!= `Sec-WebSocket-Accept`("ZGgwOTM0Z2owcmViamRvcGcK") + } + "Sec-WebSocket-Extensions" in { + "Sec-WebSocket-Extensions: abc" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc"))) + "Sec-WebSocket-Extensions: abc, def" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc"), WebsocketExtension("def"))) + "Sec-WebSocket-Extensions: abc; param=2; use_y, def" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc", Map("param" -> "2", "use_y" -> "")), WebsocketExtension("def"))) + "Sec-WebSocket-Extensions: abc; param=\",xyz\", def" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc", Map("param" -> ",xyz")), WebsocketExtension("def"))) + + // real examples from https://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-19 + "Sec-WebSocket-Extensions: permessage-deflate" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("permessage-deflate"))) + "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits; server_max_window_bits=10" =!= + `Sec-WebSocket-Extensions`(Vector(WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> "", "server_max_window_bits" -> "10")))) + "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits" =!= + `Sec-WebSocket-Extensions`(Vector( + WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> "", "server_max_window_bits" -> "10")), + WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> "")))) + } + "Sec-WebSocket-Key" in { + "Sec-WebSocket-Key: c2Zxb3JpbmgyMzA5dGpoMDIzOWdlcm5vZ2luCg==" =!= `Sec-WebSocket-Key`("c2Zxb3JpbmgyMzA5dGpoMDIzOWdlcm5vZ2luCg==") + } + "Sec-WebSocket-Protocol" in { + "Sec-WebSocket-Protocol: chat" =!= `Sec-WebSocket-Protocol`(Vector("chat")) + "Sec-WebSocket-Protocol: chat, superchat" =!= `Sec-WebSocket-Protocol`(Vector("chat", "superchat")) + } + "Sec-WebSocket-Version" in { + "Sec-WebSocket-Version: 25" =!= `Sec-WebSocket-Version`(Vector(25)) + "Sec-WebSocket-Version: 13, 8, 7" =!= `Sec-WebSocket-Version`(Vector(13, 8, 7)) + + "Sec-WebSocket-Version: 255" =!= `Sec-WebSocket-Version`(Vector(255)) + "Sec-WebSocket-Version: 0" =!= `Sec-WebSocket-Version`(Vector(0)) + } + "Set-Cookie" in { "Set-Cookie: SID=\"31d4d96e407aad42\"" =!= `Set-Cookie`(HttpCookie("SID", "31d4d96e407aad42")).renderedTo("SID=31d4d96e407aad42") From 4e3c1db4bb5b26dfc73fa821471c9d3c4d81ee8c Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Wed, 18 Feb 2015 17:26:59 +0100 Subject: [PATCH 02/13] +htc #16887 add Upgrade header parsing/rendering --- .../akka/http/model/headers/UpgradeProtocol.scala | 15 +++++++++++++++ .../scala/akka/http/model/headers/headers.scala | 14 ++++++++++++++ .../akka/http/model/parser/HeaderParser.scala | 1 + .../akka/http/model/parser/SimpleHeaders.scala | 9 +++++++++ .../akka/http/model/parser/HttpHeaderSpec.scala | 7 +++++++ 5 files changed, 46 insertions(+) create mode 100644 akka-http-core/src/main/scala/akka/http/model/headers/UpgradeProtocol.scala diff --git a/akka-http-core/src/main/scala/akka/http/model/headers/UpgradeProtocol.scala b/akka-http-core/src/main/scala/akka/http/model/headers/UpgradeProtocol.scala new file mode 100644 index 0000000000..f506c1bcda --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/model/headers/UpgradeProtocol.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.model.headers + +import akka.http.util.{ Rendering, ValueRenderable } + +final case class UpgradeProtocol(name: String, version: Option[String] = None) extends ValueRenderable { + def render[R <: Rendering](r: R): r.type = { + r ~~ name + version.foreach(v ⇒ r ~~ '/' ~~ v) + r + } +} diff --git a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala index 0f0e29eade..55d928b5c4 100644 --- a/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/model/headers/headers.scala @@ -55,6 +55,7 @@ final case class Connection(tokens: immutable.Seq[String]) extends ModeledHeader def renderValue[R <: Rendering](r: R): r.type = r ~~ tokens def hasClose = has("close") def hasKeepAlive = has("keep-alive") + def hasUpgrade = has("upgrade") def append(tokens: immutable.Seq[String]) = Connection(this.tokens ++ tokens) @tailrec private def has(item: String, ix: Int = 0): Boolean = if (ix < tokens.length) @@ -711,6 +712,19 @@ final case class `Transfer-Encoding`(encodings: immutable.Seq[TransferEncoding]) def getEncodings: Iterable[japi.TransferEncoding] = encodings.asJava } +// http://tools.ietf.org/html/rfc7230#section-6.7 +object Upgrade extends ModeledCompanion { + implicit val protocolsRenderer = Renderer.defaultSeqRenderer[UpgradeProtocol] +} +final case class Upgrade(protocols: immutable.Seq[UpgradeProtocol]) extends ModeledHeader { + import Upgrade.protocolsRenderer + protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ protocols + + protected def companion: ModeledCompanion = Upgrade + + def hasWebsocket: Boolean = protocols.exists(_.name equalsIgnoreCase "websocket") +} + // http://tools.ietf.org/html/rfc7231#section-5.5.3 object `User-Agent` extends ModeledCompanion { def apply(products: String): `User-Agent` = apply(ProductVersion.parseMultiple(products)) diff --git a/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala b/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala index f7d8407cec..4c87d634c8 100644 --- a/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala +++ b/akka-http-core/src/main/scala/akka/http/model/parser/HeaderParser.scala @@ -119,6 +119,7 @@ private[http] object HeaderParser { "sec-websocket-version", "set-cookie", "transfer-encoding", + "upgrade", "user-agent", "www-authenticate", "x-forwarded-for") diff --git a/akka-http-core/src/main/scala/akka/http/model/parser/SimpleHeaders.scala b/akka-http-core/src/main/scala/akka/http/model/parser/SimpleHeaders.scala index f7af9acefe..0a0142dc46 100644 --- a/akka-http-core/src/main/scala/akka/http/model/parser/SimpleHeaders.scala +++ b/akka-http-core/src/main/scala/akka/http/model/parser/SimpleHeaders.scala @@ -187,6 +187,15 @@ private[parser] trait SimpleHeaders { this: Parser with CommonRules with CommonA `cookie-pair` ~ zeroOrMore(ws(';') ~ `cookie-av`) ~ EOI ~> (`Set-Cookie`(_)) } + // http://tools.ietf.org/html/rfc7230#section-6.7 + def upgrade = rule { + oneOrMore(protocol).separatedBy(listSep) ~> (Upgrade(_)) + } + + def protocol = rule { + token ~ optional(ws("/") ~ token) ~> (UpgradeProtocol(_, _)) + } + // http://tools.ietf.org/html/rfc7231#section-5.5.3 def `user-agent` = rule { products ~> (`User-Agent`(_)) } diff --git a/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala b/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala index 399a4e8c20..3406f51953 100644 --- a/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/model/parser/HttpHeaderSpec.scala @@ -483,6 +483,13 @@ class HttpHeaderSpec extends FreeSpec with Matchers { .renderedTo("PLAY_FLASH=; Expires=Sun, 07 Dec 2014 22:48:47 GMT; Path=/; HttpOnly") } + "Upgrade" in { + "Upgrade: abc, def" =!= Upgrade(Vector(UpgradeProtocol("abc"), UpgradeProtocol("def"))) + "Upgrade: abc, def/38.1" =!= Upgrade(Vector(UpgradeProtocol("abc"), UpgradeProtocol("def", Some("38.1")))) + + "Upgrade: websocket" =!= Upgrade(Vector(UpgradeProtocol("websocket"))) + } + "User-Agent" in { "User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_3) AppleWebKit/537.31" =!= `User-Agent`(ProductVersion("Mozilla", "5.0", "Macintosh; Intel Mac OS X 10_8_3"), ProductVersion("AppleWebKit", "537.31")) From 7b72fca3c2d023db4c067950f7e336bd863cfc6d Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 14:29:30 +0200 Subject: [PATCH 03/13] +htc introduce big endian read methods to ByteReader --- .../src/main/scala/akka/http/util/ByteReader.scala | 12 ++++++++++-- akka-http/src/main/scala/akka/http/coding/Gzip.scala | 8 ++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala b/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala index e68b867719..d9043e4946 100644 --- a/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala +++ b/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala @@ -18,6 +18,8 @@ private[akka] class ByteReader(input: ByteString) { private[this] var off = 0 + def hasRemaining: Boolean = off < input.size + def currentOffset: Int = off def remainingData: ByteString = input.drop(off) def fromStartToHere: ByteString = input.take(currentOffset) @@ -28,8 +30,14 @@ private[akka] class ByteReader(input: ByteString) { off += 1 x.toInt & 0xFF } else throw NeedMoreData - def readShort(): Int = readByte() | (readByte() << 8) - def readInt(): Int = readShort() | (readShort() << 16) + def readShortLE(): Int = readByte() | (readByte() << 8) + def readIntLE(): Int = readShortLE() | (readShortLE() << 16) + def readLongLE(): Long = (readIntBE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) + + def readShortBE(): Int = (readByte() << 8) | readByte() + def readIntBE(): Int = (readShortBE() << 16) | readShortBE() + def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL) + def skip(numBytes: Int): Unit = if (off + numBytes <= input.length) off += numBytes else throw NeedMoreData diff --git a/akka-http/src/main/scala/akka/http/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/coding/Gzip.scala index bb7cda0dae..f755502f42 100644 --- a/akka-http/src/main/scala/akka/http/coding/Gzip.scala +++ b/akka-http/src/main/scala/akka/http/coding/Gzip.scala @@ -88,10 +88,10 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method val flags = readByte() skip(6) // skip MTIME, XFL and OS fields - if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields + if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment - if ((flags & 2) > 0 && crc16(fromStartToHere) != readShort()) fail("Corrupt GZIP header") + if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header") inflater.reset() crc32.reset() @@ -107,8 +107,8 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = { import reader._ - if (readInt() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") - if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") + if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") + if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") becomeWithRemaining(Initial, remainingData, ctx) } From d67b5823e629469fe0ec5b926a2990989f754d17 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 14:32:31 +0200 Subject: [PATCH 04/13] =htc #17129 Preliminary fix until this is fixed with the new features from #16168 --- akka-http-core/src/main/scala/akka/http/Http.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/akka-http-core/src/main/scala/akka/http/Http.scala b/akka-http-core/src/main/scala/akka/http/Http.scala index 2ff409ce32..f7707ac3f0 100644 --- a/akka-http-core/src/main/scala/akka/http/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/Http.scala @@ -40,10 +40,10 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E val connections: Source[StreamTcp.IncomingConnection, Future[StreamTcp.ServerBinding]] = StreamTcp().bind(endpoint, backlog, options, effectiveSettings.timeouts.idleTimeout) - val layer = serverLayer(effectiveSettings, log) connections.map { case StreamTcp.IncomingConnection(localAddress, remoteAddress, flow) ⇒ + val layer = serverLayer(effectiveSettings, log) IncomingConnection(localAddress, remoteAddress, layer join flow) }.mapMaterialized { tcpBindingFuture ⇒ import system.dispatcher From a58859c77c98fca3e041b5480ddf8cd3a6467ab6 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 14:42:36 +0200 Subject: [PATCH 05/13] +str #17145 add new Flow.wrap overload to create flow from sink and source --- akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 89aefea5fc..072f0cc965 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -300,6 +300,11 @@ object Flow extends FlowApply { */ def wrap[I, O, M](g: Graph[FlowShape[I, O], M]): Flow[I, O, M] = new Flow(g.module) + /** + * Helper to create `Flow` from a pair of sink and source. + */ + def wrap[I, O, M1, M2, M](sink: Sink[I, M1], source: Source[O, M2])(f: (M1, M2) ⇒ M): Flow[I, O, M] = + Flow(sink, source)(f) { implicit b ⇒ (in, out) ⇒ (in.inlet, out.outlet) } } /** From 2402bedc6a7e8bf058b4e49d36e7cba39f39e8ef Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 14:54:00 +0200 Subject: [PATCH 06/13] +htc #16887 add server-side websocket API and add example to TestServer --- .../scala/akka/http/model/ws/Message.scala | 29 ++++++++++ .../http/model/ws/UpgradeToWebsocket.scala | 28 ++++++++++ .../src/test/scala/akka/http/TestServer.scala | 56 +++++++++++++++---- 3 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/model/ws/Message.scala create mode 100644 akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala diff --git a/akka-http-core/src/main/scala/akka/http/model/ws/Message.scala b/akka-http-core/src/main/scala/akka/http/model/ws/Message.scala new file mode 100644 index 0000000000..a8b73116f7 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/model/ws/Message.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.model.ws + +import akka.stream.scaladsl.Source +import akka.util.ByteString + +/** + * The ADT for Websocket messages. A message can either be binary or a text message. Each of + * those can either be strict or streamed. + */ +sealed trait Message +sealed trait TextMessage extends Message +object TextMessage { + final case class Strict(text: String) extends TextMessage { + override def toString: String = s"TextMessage.Strict($text)" + } + final case class Streamed(textStream: Source[String, _]) extends TextMessage +} + +sealed trait BinaryMessage extends Message +object BinaryMessage { + final case class Strict(data: ByteString) extends BinaryMessage { + override def toString: String = s"BinaryMessage.Strict($data)" + } + final case class Streamed(dataStream: Source[ByteString, _]) extends BinaryMessage +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala new file mode 100644 index 0000000000..c63b4ea2a4 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.model.ws + +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +import akka.http.model.{ HttpHeader, HttpResponse } + +/** + * A custom header that will be added to an Websocket upgrade HttpRequest that + * enables a request handler to upgrade this connection to a Websocket connection and + * registers a Websocket handler. + * + * FIXME: needs to be able to choose subprotocols as possibly agreed on in the websocket handshake + */ +trait UpgradeToWebsocket extends HttpHeader { + /** + * The high-level interface to create a Websocket server based on "messages". + * + * Returns a response to return in a request handler that will signal the + * low-level HTTP implementation to upgrade the connection to Websocket and + * use the supplied handler to handle incoming Websocket messages. + */ + def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse +} diff --git a/akka-http-core/src/test/scala/akka/http/TestServer.scala b/akka-http-core/src/test/scala/akka/http/TestServer.scala index c22be6dee9..07d29179b0 100644 --- a/akka-http-core/src/test/scala/akka/http/TestServer.scala +++ b/akka-http-core/src/test/scala/akka/http/TestServer.scala @@ -4,12 +4,17 @@ package akka.http +import scala.concurrent.duration._ import akka.actor.ActorSystem import akka.http.model._ +import akka.http.model.ws._ import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl.{ Source, Flow } import com.typesafe.config.{ ConfigFactory, Config } import HttpMethods._ +import scala.concurrent.Await + object TestServer extends App { val testConf: Config = ConfigFactory.parseString(""" akka.loglevel = INFO @@ -18,18 +23,36 @@ object TestServer extends App { implicit val system = ActorSystem("ServerTest", testConf) implicit val fm = ActorFlowMaterializer() - val binding = Http().bindAndHandleSync({ - case HttpRequest(GET, Uri.Path("/"), _, _, _) ⇒ index - case HttpRequest(GET, Uri.Path("/ping"), _, _, _) ⇒ HttpResponse(entity = "PONG!") - case HttpRequest(GET, Uri.Path("/crash"), _, _, _) ⇒ sys.error("BOOM!") - case _: HttpRequest ⇒ HttpResponse(404, entity = "Unknown resource!") - }, interface = "localhost", port = 8080) + try { + val binding = Http().bindAndHandleSync({ + case req @ HttpRequest(GET, Uri.Path("/"), _, _, _) if req.header[UpgradeToWebsocket].isDefined ⇒ + req.header[UpgradeToWebsocket] match { + case Some(upgrade) ⇒ upgrade.handleMessages(echoWebsocketService) // needed for running the autobahn test suite + case None ⇒ HttpResponse(400, entity = "Not a valid websocket request!") + } + case req @ HttpRequest(GET, Uri.Path("/ws-greeter"), _, _, _) ⇒ + req.header[UpgradeToWebsocket] match { + case Some(upgrade) ⇒ upgrade.handleMessages(greeterWebsocketService) + case None ⇒ HttpResponse(400, entity = "Not a valid websocket request!") + } + case HttpRequest(GET, Uri.Path("/"), _, _, _) ⇒ index + case HttpRequest(GET, Uri.Path("/ping"), _, _, _) ⇒ HttpResponse(entity = "PONG!") + case HttpRequest(GET, Uri.Path("/crash"), _, _, _) ⇒ sys.error("BOOM!") + case req @ HttpRequest(GET, Uri.Path("/ws-greeter"), _, _, _) ⇒ + req.header[UpgradeToWebsocket] match { + case Some(upgrade) ⇒ upgrade.handleMessages(greeterWebsocketService) + case None ⇒ HttpResponse(400, entity = "Not a valid websocket request!") + } + case _: HttpRequest ⇒ HttpResponse(404, entity = "Unknown resource!") + }, interface = "localhost", port = 9001) - println(s"Server online at http://localhost:8080") - println("Press RETURN to stop...") - Console.readLine() - - system.shutdown() + Await.result(binding, 1.second) // throws if binding fails + println("Server online at http://localhost:9001") + println("Press RETURN to stop...") + Console.readLine() + } finally { + system.shutdown() + } ////////////// helpers ////////////// @@ -45,4 +68,15 @@ object TestServer extends App { | | |""".stripMargin)) + + def echoWebsocketService: Flow[Message, Message, Unit] = + Flow[Message] // just let message flow directly to the output + + def greeterWebsocketService: Flow[Message, Message, Unit] = + Flow[Message] + .collect { + case TextMessage.Strict(name) ⇒ TextMessage.Strict(s"Hello '$name'") + case TextMessage.Streamed(nameStream) ⇒ TextMessage.Streamed(Source.single("Hello ") ++ nameStream mapMaterialized (_ ⇒ ())) + // ignore binary messages + } } From 23a5dadee059930ae9cf08a043709896c6acc548 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:02:54 +0200 Subject: [PATCH 07/13] =htc #16887 a Utf8 encoder and decoder --- .../akka/http/engine/ws/Utf8Decoder.scala | 121 ++++++++++++++++++ .../akka/http/engine/ws/Utf8Encoder.scala | 83 ++++++++++++ .../akka/http/engine/ws/Utf8CodingSpecs.scala | 60 +++++++++ .../http/engine/ws/WithMaterializerSpec.scala | 20 +++ 4 files changed, 284 insertions(+) create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Decoder.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Encoder.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/Utf8CodingSpecs.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/WithMaterializerSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Decoder.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Decoder.scala new file mode 100644 index 0000000000..daa5dbc9f2 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Decoder.scala @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.util.ByteString + +import scala.annotation.tailrec +import scala.util.Try + +/** + * A Utf8 -> Utf16 (= Java char) decoder. + * + * This decoder is based on the one of Bjoern Hoehrmann from + * + * http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + * + * which is licensed under this license: + * + * Copyright (c) 2008-2010 Bjoern Hoehrmann + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * FIXME: reviewers: it may be necessary to distribute this notice in the license file, is it? + * + * INTERNAL API + */ +private[http] object Utf8Decoder extends StreamingCharsetDecoder { + private[this] val Utf8Accept = 0 + private[this] val Utf8Reject = 12 + + val characterClasses = + Array[Byte]( + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8) + + val states = + Array[Byte]( + 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, + 12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12) + + def create(): StreamingCharsetDecoderInstance = + new StreamingCharsetDecoderInstance { + var currentCodePoint = 0 + var currentState = Utf8Accept + + def decode(bytes: ByteString, endOfInput: Boolean): Try[String] = Try { + val result = new StringBuilder(bytes.size) + val length = bytes.size + + def step(byte: Int): Unit = { + val chClass = characterClasses(byte) + currentCodePoint = + if (currentState == Utf8Accept) // first byte + (0xff >> chClass) & byte // take as much bits as the characterClass says + else // continuation byte + (0x3f & byte) | (currentCodePoint << 6) // take 6 bits + currentState = states(currentState + chClass) + + currentState match { + case Utf8Accept ⇒ + if (currentCodePoint <= 0xffff) + // fits in single UTF-16 char + result.append(currentCodePoint.toChar) + else { + // create surrogate pair + result.append((0xD7C0 + (currentCodePoint >> 10)).toChar) + result.append((0xDC00 + (currentCodePoint & 0x3FF)).toChar) + } + case Utf8Reject ⇒ fail("Invalid UTF-8 input") + case _ ⇒ // valid intermediate state, need more input + } + } + + var offset = 0 + while (offset < length) { + step(bytes(offset) & 0xff) + offset += 1 + } + + if (endOfInput && currentState != Utf8Accept) fail("Truncated UTF-8 input") + else + result.toString() + } + + def fail(msg: String): Nothing = throw new IllegalArgumentException(msg) + } +} + +private[http] trait StreamingCharsetDecoder { + def create(): StreamingCharsetDecoderInstance + def decode(bytes: ByteString): Try[String] = create().decode(bytes, endOfInput = true) +} +private[http] trait StreamingCharsetDecoderInstance { + def decode(bytes: ByteString, endOfInput: Boolean): Try[String] +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Encoder.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Encoder.scala new file mode 100644 index 0000000000..a4426a7874 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Utf8Encoder.scala @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.stage._ +import akka.util.{ ByteStringBuilder, ByteString } + +import scala.annotation.tailrec + +/** + * A utf16 (= Java char) to utf8 encoder. + * + * INTERNAL API + */ +private[http] class Utf8Encoder extends PushStage[String, ByteString] { + import Utf8Encoder._ + + var surrogateValue: Int = 0 + def inSurrogatePair: Boolean = surrogateValue != 0 + + def onPush(input: String, ctx: Context[ByteString]): SyncDirective = { + val builder = new ByteStringBuilder + + def b(v: Int): Unit = { + assert((v & 0xff) == v) + builder += v.toByte + } + + def step(char: Int): Unit = + if (!inSurrogatePair) + if (char <= Utf8OneByteLimit) builder += char.toByte + else if (char <= Utf8TwoByteLimit) { + b(0xc0 | ((char & 0x7c0) >> 6)) // upper 5 bits + b(0x80 | (char & 0x3f)) // lower 6 bits + } else if (char >= SurrogateFirst && char < SurrogateSecond) + surrogateValue = 0x10000 | ((char ^ SurrogateFirst) << 10) + else if (char >= SurrogateSecond && char < 0xdfff) + throw new IllegalArgumentException(f"Unexpected UTF-16 surrogate continuation") + else if (char <= Utf8ThreeByteLimit) { + b(0xe0 | ((char & 0xf000) >> 12)) // upper 4 bits + b(0x80 | ((char & 0x0fc0) >> 6)) // middle 6 bits + b(0x80 | (char & 0x3f)) // lower 6 bits + } else + throw new IllegalStateException("Char cannot be >= 2^16") // char value was converted from 16bit value + else if (char >= SurrogateSecond && char <= 0xdfff) { + surrogateValue |= (char & 0x3ff) + b(0xf0 | ((surrogateValue & 0x1c0000) >> 18)) // upper 3 bits + b(0x80 | ((surrogateValue & 0x3f000) >> 12)) // first middle 6 bits + b(0x80 | ((surrogateValue & 0x0fc0) >> 6)) // second middle 6 bits + b(0x80 | (surrogateValue & 0x3f)) // lower 6 bits + surrogateValue = 0 + } else throw new IllegalArgumentException(f"Expected UTF-16 surrogate continuation") + + var offset = 0 + while (offset < input.length) { + step(input(offset)) + offset += 1 + } + + if (builder.length > 0) ctx.push(builder.result()) + else ctx.pull() + } + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = + if (inSurrogatePair) ctx.fail(new IllegalArgumentException("Truncated String input (ends in the middle of surrogate pair)")) + else super.onUpstreamFinish(ctx) +} + +/** + * INTERNAL API + */ +private[http] object Utf8Encoder { + val SurrogateFirst = 0xd800 + val SurrogateSecond = 0xdc00 + + val Utf8OneByteLimit = lowerNBitsSet(7) + val Utf8TwoByteLimit = lowerNBitsSet(11) + val Utf8ThreeByteLimit = lowerNBitsSet(16) + + def lowerNBitsSet(n: Int): Long = (1L << n) - 1 +} \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/Utf8CodingSpecs.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/Utf8CodingSpecs.scala new file mode 100644 index 0000000000..df99981e2e --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/Utf8CodingSpecs.scala @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import org.scalacheck.Gen + +import scala.concurrent.duration._ + +import akka.stream.scaladsl.Source +import akka.util.ByteString +import akka.http.util._ +import org.scalatest.prop.PropertyChecks +import org.scalatest.{ FreeSpec, Matchers } + +class Utf8CodingSpecs extends FreeSpec with Matchers with PropertyChecks with WithMaterializerSpec { + "Utf8 decoding/encoding" - { + "work for all codepoints" in { + def isSurrogate(cp: Int): Boolean = + cp >= Utf8Encoder.SurrogateFirst && cp <= 0xdfff + + val cps = + Gen.choose(0, 0x10ffff) + .filter(!isSurrogate(_)) + + def codePointAsString(cp: Int): String = { + if (cp < 0x10000) new String(Array(cp.toChar)) + else { + val part0 = 0xd7c0 + (cp >> 10) // constant has 0x10000 subtracted already + val part1 = 0xdc00 + (cp & 0x3ff) + new String(Array(part0.toChar, part1.toChar)) + } + } + + forAll(cps) { (cp: Int) ⇒ + val utf16 = codePointAsString(cp) + decodeUtf8(encodeUtf8(utf16)) === utf16 + } + } + } + + def encodeUtf8(str: String): ByteString = + Source(str.map(ch ⇒ new String(Array(ch)))) // chunk in smallest chunks possible + .transform(() ⇒ new Utf8Encoder) + .runFold(ByteString.empty)(_ ++ _).awaitResult(1.second) + + def decodeUtf8(bytes: ByteString): String = { + val builder = new StringBuilder + val decoder = Utf8Decoder.create() + bytes + .map(b ⇒ ByteString(b)) // chunk in smallest chunks possible + .foreach { bs ⇒ + builder append decoder.decode(bs, endOfInput = false).get + } + + builder append decoder.decode(ByteString.empty, endOfInput = true).get + builder.toString() + } +} diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/WithMaterializerSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/WithMaterializerSpec.scala new file mode 100644 index 0000000000..ac165086fe --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/WithMaterializerSpec.scala @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.actor.ActorSystem +import akka.stream.ActorFlowMaterializer +import com.typesafe.config.{ ConfigFactory, Config } +import org.scalatest.{ Suite, BeforeAndAfterAll } + +trait WithMaterializerSpec extends BeforeAndAfterAll { _: Suite ⇒ + lazy val testConf: Config = ConfigFactory.parseString(""" + akka.event-handlers = ["akka.testkit.TestEventListener"] + akka.loglevel = WARNING""") + implicit lazy val system = ActorSystem(getClass.getSimpleName, testConf) + + implicit lazy val materializer = ActorFlowMaterializer() + override def afterAll() = system.shutdown() +} \ No newline at end of file From 73722c425d03d6cb0638119426c3d73278d47011 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:06:18 +0200 Subject: [PATCH 08/13] =htc #16887 websocket framing implementation --- .../akka/http/engine/ws/FrameEvent.scala | 72 ++++ .../http/engine/ws/FrameEventParser.scala | 162 +++++++++ .../http/engine/ws/FrameEventRenderer.scala | 100 ++++++ .../scala/akka/http/engine/ws/Protocol.scala | 84 +++++ .../akka/http/engine/ws/BitBuilder.scala | 103 ++++++ .../akka/http/engine/ws/FramingSpec.scala | 326 ++++++++++++++++++ 6 files changed, 847 insertions(+) create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameEvent.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventParser.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventRenderer.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Protocol.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/BitBuilder.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/FramingSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEvent.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEvent.scala new file mode 100644 index 0000000000..e4911a2844 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEvent.scala @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.http.engine.ws.Protocol.Opcode +import akka.util.ByteString + +/** + * The low-level Websocket framing model. + * + * INTERNAL API + */ +private[http] sealed trait FrameEvent { + def data: ByteString + def lastPart: Boolean + def withData(data: ByteString): FrameEvent +} + +/** + * Starts a frame. Contains the frame's headers. May contain all the data of the frame if `lastPart == true`. Otherwise, + * following events will be `FrameData` events that contain the remaining data of the frame. + */ +private[http] final case class FrameStart(header: FrameHeader, data: ByteString) extends FrameEvent { + def lastPart: Boolean = data.size == header.length + def withData(data: ByteString): FrameStart = copy(data = data) + + def isFullMessage: Boolean = header.fin && header.length == data.length +} + +/** + * Frame data that was received after the start of the frame.. + */ +private[http] final case class FrameData(data: ByteString, lastPart: Boolean) extends FrameEvent { + def withData(data: ByteString): FrameData = copy(data = data) +} + +/** Model of the frame header */ +private[http] final case class FrameHeader(opcode: Protocol.Opcode, + mask: Option[Int], + length: Long, + fin: Boolean, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false) + +private[http] object FrameEvent { + def empty(opcode: Protocol.Opcode, + fin: Boolean, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): FrameStart = + fullFrame(opcode, None, ByteString.empty, fin, rsv1, rsv2, rsv3) + def fullFrame(opcode: Protocol.Opcode, mask: Option[Int], data: ByteString, + fin: Boolean, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): FrameStart = + FrameStart(FrameHeader(opcode, mask, data.length, fin, rsv1, rsv2, rsv3), data) + val emptyLastContinuationFrame: FrameStart = + empty(Protocol.Opcode.Continuation, fin = true) + + def closeFrame(closeCode: Int, reason: String = "", mask: Option[Int] = None): FrameStart = { + require(closeCode >= 1000, s"Invalid close code: $closeCode") + val body = ByteString( + ((closeCode & 0xff00) >> 8).toByte, + (closeCode & 0xff).toByte) ++ ByteString(reason, "UTF8") + + fullFrame(Opcode.Close, mask, FrameEventParser.mask(body, mask), fin = true) + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventParser.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventParser.scala new file mode 100644 index 0000000000..3992d51767 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventParser.scala @@ -0,0 +1,162 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.http.util.{ ByteReader, ByteStringParserStage } +import akka.stream.stage.{ StageState, SyncDirective, Context } +import akka.util.ByteString + +import scala.annotation.tailrec + +/** + * Streaming parser for the Websocket framing protocol as defined in RFC6455 + * + * http://tools.ietf.org/html/rfc6455 + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-------+-+-------------+-------------------------------+ + * |F|R|R|R| opcode|M| Payload len | Extended payload length | + * |I|S|S|S| (4) |A| (7) | (16/64) | + * |N|V|V|V| |S| | (if payload len==126/127) | + * | |1|2|3| |K| | | + * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + * | Extended payload length continued, if payload len == 127 | + * + - - - - - - - - - - - - - - - +-------------------------------+ + * | |Masking-key, if MASK set to 1 | + * +-------------------------------+-------------------------------+ + * | Masking-key (continued) | Payload Data | + * +-------------------------------- - - - - - - - - - - - - - - - + + * : Payload Data continued ... : + * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + * | Payload Data continued ... | + * +---------------------------------------------------------------+ + * + * INTERNAL API + */ +private[http] class FrameEventParser extends ByteStringParserStage[FrameEvent] { + protected def onTruncation(ctx: Context[FrameEvent]): SyncDirective = + ctx.fail(new ProtocolException("Data truncated")) + + def initial: StageState[ByteString, FrameEvent] = ReadFrameHeader + + object ReadFrameHeader extends ByteReadingState { + def read(reader: ByteReader, ctx: Context[FrameEvent]): SyncDirective = { + import Protocol._ + + val flagsAndOp = reader.readByte() + val maskAndLength = reader.readByte() + + val flags = flagsAndOp & FLAGS_MASK + val op = flagsAndOp & OP_MASK + + val maskBit = (maskAndLength & MASK_MASK) != 0 + val length7 = maskAndLength & LENGTH_MASK + + val length = + length7 match { + case 126 ⇒ reader.readShortBE().toLong + case 127 ⇒ reader.readLongBE() + case x ⇒ x.toLong + } + + if (length < 0) ctx.fail(new ProtocolException("Highest bit of 64bit length was set")) + + val mask = + if (maskBit) Some(reader.readIntBE()) + else None + + def isFlagSet(mask: Int): Boolean = (flags & mask) != 0 + val header = + FrameHeader(Opcode.forCode(op.toByte), + mask, + length, + fin = isFlagSet(FIN_MASK), + rsv1 = isFlagSet(RSV1_MASK), + rsv2 = isFlagSet(RSV2_MASK), + rsv3 = isFlagSet(RSV3_MASK)) + + val data = reader.remainingData + val takeNow = (header.length min Int.MaxValue).toInt + val thisFrameData = data.take(takeNow) + val remaining = data.drop(takeNow) + + val nextState = + if (thisFrameData.length == length) ReadFrameHeader + else readData(length - thisFrameData.length) + + pushAndBecomeWithRemaining(FrameStart(header, thisFrameData), nextState, remaining, ctx) + } + } + + def readData(_remaining: Long): State = + new State { + var remaining = _remaining + def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = + if (elem.size < remaining) { + remaining -= elem.size + ctx.push(FrameData(elem, lastPart = false)) + } else { + assert(remaining <= Int.MaxValue) // safe because, remaining <= elem.size <= Int.MaxValue + val frameData = elem.take(remaining.toInt) + val remainingData = elem.drop(remaining.toInt) + + pushAndBecomeWithRemaining(FrameData(frameData, lastPart = true), ReadFrameHeader, remainingData, ctx) + } + } + + def becomeWithRemaining(nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = { + become(nextState) + nextState.onPush(remainingData, ctx) + } + def pushAndBecomeWithRemaining(elem: FrameEvent, nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = + if (remainingData.isEmpty) { + become(nextState) + ctx.push(elem) + } else { + become(waitForPull(nextState, remainingData)) + ctx.push(elem) + } + + def waitForPull(nextState: State, remainingData: ByteString): State = + new State { + def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = + throw new IllegalStateException("Mustn't push in this state") + + override def onPull(ctx: Context[FrameEvent]): SyncDirective = { + become(nextState) + nextState.onPush(remainingData, ctx) + } + } +} + +object FrameEventParser { + def mask(bytes: ByteString, _mask: Option[Int]): ByteString = + _mask match { + case Some(m) ⇒ mask(bytes, m)._1 + case None ⇒ bytes + } + def mask(bytes: ByteString, mask: Int): (ByteString, Int) = { + @tailrec def rec(bytes: Array[Byte], offset: Int, mask: Int): Int = + if (offset >= bytes.length) mask + else { + val newMask = Integer.rotateLeft(mask, 8) // we cycle through the mask in BE order + bytes(offset) = (bytes(offset) ^ (newMask & 0xff)).toByte + rec(bytes, offset + 1, newMask) + } + + val buffer = bytes.toArray[Byte] + val newMask = rec(buffer, 0, mask) + (ByteString(buffer), newMask) + } + + def parseCloseCode(data: ByteString): Option[Int] = + if (data.length >= 2) { + val code = ((data(0) & 0xff) << 8) | (data(1) & 0xff) + if (Protocol.CloseCodes.isValid(code)) Some(code) + else Some(Protocol.CloseCodes.ProtocolError) + } else if (data.length == 1) Some(Protocol.CloseCodes.ProtocolError) // must be >= length 2 if not empty + else None +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventRenderer.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventRenderer.scala new file mode 100644 index 0000000000..85c7555c69 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameEventRenderer.scala @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.util.ByteString +import akka.stream.stage.{ TerminationDirective, StatefulStage, SyncDirective, Context } + +import scala.annotation.tailrec + +/** + * Renders FrameEvents to ByteString. + * + * INTERNAL API + */ +private[http] class FrameEventRenderer extends StatefulStage[FrameEvent, ByteString] { + def initial: State = Idle + + object Idle extends State { + def onPush(elem: FrameEvent, ctx: Context[ByteString]): SyncDirective = elem match { + case start @ FrameStart(header, data) ⇒ + assert(header.length >= data.size) + if (!start.lastPart && header.length > 0) become(renderData(header.length - data.length, this)) + + ctx.push(renderStart(start)) + } + } + + def renderData(initialRemaining: Long, nextState: State): State = + new State { + var remaining: Long = initialRemaining + + def onPush(elem: FrameEvent, ctx: Context[ByteString]): SyncDirective = elem match { + case FrameData(data, lastPart) ⇒ + if (data.size > remaining) + throw new IllegalStateException(s"Expected $remaining frame bytes but got ${data.size}") + else if (data.size == remaining) { + if (!lastPart) throw new IllegalStateException(s"Frame data complete but `lastPart` flag not set") + become(nextState) + ctx.push(data) + } else { + remaining -= data.size + ctx.push(data) + } + } + } + + def renderStart(start: FrameStart): ByteString = renderHeader(start.header) ++ start.data + def renderHeader(header: FrameHeader): ByteString = { + import Protocol._ + + val length = header.length + val (lengthBits, extraLengthBytes) = length match { + case x if x < 126 ⇒ (x.toInt, 0) + case x if x <= 0xFFFF ⇒ (126, 2) + case _ ⇒ (127, 8) + } + + val maskBytes = if (header.mask.isDefined) 4 else 0 + val totalSize = 2 + extraLengthBytes + maskBytes + + val data = new Array[Byte](totalSize) + + def bool(b: Boolean, mask: Int): Int = if (b) mask else 0 + val flags = + bool(header.fin, FIN_MASK) | + bool(header.rsv1, RSV1_MASK) | + bool(header.rsv2, RSV2_MASK) | + bool(header.rsv3, RSV3_MASK) + + data(0) = (flags | header.opcode.code).toByte + data(1) = (bool(header.mask.isDefined, MASK_MASK) | lengthBits).toByte + + extraLengthBytes match { + case 0 ⇒ + case 2 ⇒ + data(2) = ((length & 0xFF00) >> 8).toByte + data(3) = ((length & 0x00FF) >> 0).toByte + case 8 ⇒ + @tailrec def addLongBytes(l: Long, writtenBytes: Int): Unit = + if (writtenBytes < 8) { + data(2 + writtenBytes) = (l & 0xff).toByte + addLongBytes(java.lang.Long.rotateLeft(l, 8), writtenBytes + 1) + } + + addLongBytes(java.lang.Long.rotateLeft(length, 8), 0) + } + + val maskOffset = 2 + extraLengthBytes + header.mask.foreach { mask ⇒ + data(maskOffset + 0) = ((mask & 0xFF000000) >> 24).toByte + data(maskOffset + 1) = ((mask & 0x00FF0000) >> 16).toByte + data(maskOffset + 2) = ((mask & 0x0000FF00) >> 8).toByte + data(maskOffset + 3) = ((mask & 0x000000FF) >> 0).toByte + } + + ByteString(data) + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Protocol.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Protocol.scala new file mode 100644 index 0000000000..f9cd40843f --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Protocol.scala @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +/** + * Contains Websocket protocol constants + * + * INTERNAL API + */ +private[http] object Protocol { + val FIN_MASK = 0x80 + val RSV1_MASK = 0x40 + val RSV2_MASK = 0x20 + val RSV3_MASK = 0x10 + + val FLAGS_MASK = 0xF0 + val OP_MASK = 0x0F + + val MASK_MASK = 0x80 + val LENGTH_MASK = 0x7F + + sealed trait Opcode { + def code: Byte + def isControl: Boolean + } + object Opcode { + def forCode(code: Byte): Opcode = code match { + case 0x0 ⇒ Continuation + case 0x1 ⇒ Text + case 0x2 ⇒ Binary + + case 0x8 ⇒ Close + case 0x9 ⇒ Ping + case 0xA ⇒ Pong + + case b if (b & 0xf0) == 0 ⇒ Other(code) + case _ ⇒ throw new IllegalArgumentException(f"Opcode must be 4bit long but was 0x$code%02X") + } + + sealed abstract class AbstractOpcode private[Opcode] (val code: Byte) extends Opcode { + def isControl: Boolean = (code & 0x8) != 0 + } + + case object Continuation extends AbstractOpcode(0x0) + case object Text extends AbstractOpcode(0x1) + case object Binary extends AbstractOpcode(0x2) + + case object Close extends AbstractOpcode(0x8) + case object Ping extends AbstractOpcode(0x9) + case object Pong extends AbstractOpcode(0xA) + + case class Other(override val code: Byte) extends AbstractOpcode(code) + } + + /** + * Close status codes as defined at http://tools.ietf.org/html/rfc6455#section-7.4.1 + */ + object CloseCodes { + def isError(code: Int): Boolean = !(code == Regular || code == GoingAway) + def isValid(code: Int): Boolean = + ((code >= 1000) && (code <= 1003)) || + (code >= 1007) && (code <= 1011) || + (code >= 3000) && (code <= 4999) + + val Regular = 1000 + val GoingAway = 1001 + val ProtocolError = 1002 + val Unacceptable = 1003 + // Reserved = 1004 + // NoCodePresent = 1005 + val ConnectionAbort = 1006 + val InconsistentData = 1007 + val PolicyViolated = 1008 + val TooBig = 1009 + val ClientRejectsExtension = 1010 + val UnexpectedCondition = 1011 + val TLSHandshakeFailure = 1015 + } +} + +/** INTERNAL API */ +private[http] case class ProtocolException(cause: String) extends RuntimeException(cause) \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/BitBuilder.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/BitBuilder.scala new file mode 100644 index 0000000000..4edd0181ee --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/BitBuilder.scala @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.parboiled2 +import parboiled2._ +import akka.util.ByteString + +import scala.annotation.tailrec +import scala.util.{ Try, Failure, Success } + +object BitBuilder { + implicit class BitBuilderContext(val ctx: StringContext) { + def b(args: Any*): ByteString = { + val parser = new BitSpecParser(ctx.parts.mkString) + val bits = parser.parseBits() + //println(bits) + //println(bits.get.toByteString.map(_ formatted "%02x").mkString(" ")) + bits.get.toByteString + } + } +} + +final case class Bits(elements: Seq[Bits.BitElement]) { + def toByteString: ByteString = { + import Bits._ + + val bits = elements.map(_.bits).sum + + require(bits % 8 == 0) + val data = new Array[Byte](bits / 8) + @tailrec def rec(byteIdx: Int, bitIdx: Int, remaining: Seq[Bits.BitElement]): Unit = + if (bitIdx >= 8) rec(byteIdx + 1, bitIdx - 8, remaining) + else remaining match { + case Zero +: rest ⇒ + // zero by default + rec(byteIdx, bitIdx + 1, rest) + case One +: rest ⇒ + data(byteIdx) = (data(byteIdx) | (1 << (7 - bitIdx))).toByte + rec(byteIdx, bitIdx + 1, rest) + case Multibit(bits, value) +: rest ⇒ + val numBits = math.min(8 - bitIdx, bits) + val remainingBits = bits - numBits + val highestNBits = value >> remainingBits + val lowestNBitMask = (~(0xff << numBits) & 0xff) + data(byteIdx) = (data(byteIdx) | (highestNBits & lowestNBitMask)).toByte + + if (remainingBits > 0) + rec(byteIdx + 1, 0, Multibit(remainingBits, value) +: rest) + else + rec(byteIdx, bitIdx + numBits, rest) + case Nil ⇒ + require(bitIdx == 0 && byteIdx == bits / 8) + } + rec(0, 0, elements) + + ByteString(data) // this could be ByteString1C + } +} +object Bits { + sealed trait BitElement { + def bits: Int + } + sealed abstract class SingleBit extends BitElement { + def bits: Int = 1 + } + case object Zero extends SingleBit + case object One extends SingleBit + case class Multibit(bits: Int, value: Long) extends BitElement +} + +class BitSpecParser(val input: ParserInput) extends parboiled2.Parser { + import Bits._ + def parseBits(): Try[Bits] = + bits.run() match { + case s: Success[Bits] ⇒ s + case Failure(e: ParseError) ⇒ Failure(new RuntimeException(formatError(e, showTraces = true))) + } + + def bits: Rule1[Bits] = rule { zeroOrMore(element) ~ EOI ~> (Bits(_)) } + + val WSChar = CharPredicate(' ', '\t', '\n') + def ws = rule { zeroOrMore(wsElement) } + def wsElement = rule { WSChar | comment } + def comment = + rule { + '#' ~ zeroOrMore(!'\n' ~ ANY) ~ '\n' + } + + def element: Rule1[BitElement] = rule { + zero | one | multi + } + def zero: Rule1[BitElement] = rule { '0' ~ push(Zero) ~ ws } + def one: Rule1[BitElement] = rule { '1' ~ push(One) ~ ws } + def multi: Rule1[Multibit] = rule { + capture(oneOrMore('x' ~ ws)) ~> (_.count(_ == 'x')) ~ '=' ~ value ~ ws ~> Multibit + } + def value: Rule1[Long] = rule { + capture(oneOrMore(CharPredicate.HexDigit)) ~> ((str: String) ⇒ java.lang.Long.parseLong(str, 16)) + } +} \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/FramingSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/FramingSpec.scala new file mode 100644 index 0000000000..14c7e3a444 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/FramingSpec.scala @@ -0,0 +1,326 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import scala.collection.immutable +import scala.concurrent.duration._ + +import org.scalatest.matchers.Matcher +import org.scalatest.{ FreeSpec, Matchers } + +import akka.util.ByteString + +import akka.stream.scaladsl.Source +import akka.stream.stage.Stage + +import akka.http.util._ + +import Protocol.Opcode + +class FramingSpec extends FreeSpec with Matchers with WithMaterializerSpec { + import BitBuilder._ + + "The Websocket parser/renderer round-trip should work for" - { + "the frame header" - { + "interpret flags correctly" - { + "FIN" in { + b"""1000 # flags + 0000 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = true)) + } + "RSV1" in { + b"""0100 # flags + 0000 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv1 = true)) + } + "RSV2" in { + b"""0010 # flags + 0000 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv2 = true)) + } + "RSV3" in { + b"""0001 # flags + 0000 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv3 = true)) + } + } + "interpret opcode correctly" - { + "Continuation" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false)) + } + "Text" in { + b"""0000 # flags + xxxx=1 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Text, None, 0, fin = false)) + } + "Binary" in { + b"""0000 # flags + xxxx=2 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Binary, None, 0, fin = false)) + } + + "Close" in { + b"""0000 # flags + xxxx=8 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Close, None, 0, fin = false)) + } + "Ping" in { + b"""0000 # flags + xxxx=9 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Ping, None, 0, fin = false)) + } + "Pong" in { + b"""0000 # flags + xxxx=a # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Pong, None, 0, fin = false)) + } + + "Other" in { + b"""0000 # flags + xxxx=6 # opcode + 0 # mask? + 0000000 # length + """ should parseTo(FrameHeader(Opcode.Other(6), None, 0, fin = false)) + } + } + "read mask correctly" in { + b"""0000 # flags + 0000 # opcode + 1 # mask? + 0000000 # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=a1b2c3d4 + """ should parseTo(FrameHeader(Opcode.Continuation, Some(0xa1b2c3d4), 0, fin = false)) + } + "read length" - { + "< 126" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=5 # length + """ should parseTo(FrameHeader(Opcode.Continuation, None, 5, fin = false)) + } + "126" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7e # length + xxxxxxxx + xxxxxxxx=007e # length16 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 126, fin = false)) + } + "127" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7e # length + xxxxxxxx + xxxxxxxx=007f # length16 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 127, fin = false)) + } + "127 < length < 65536" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7e # length + xxxxxxxx + xxxxxxxx=d28e # length16 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0xd28e, fin = false)) + } + "65535" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7e # length + xxxxxxxx + xxxxxxxx=ffff # length16 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0xffff, fin = false)) + } + "65536" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=0000000000010000 # length64 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0x10000, fin = false)) + } + "> 65536" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=0000000123456789 # length64 + """ should parseTo(FrameHeader(Opcode.Continuation, None, 0x123456789L, fin = false)) + } + "Long.MaxValue" in { + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=7fffffffffffffff # length64 + """ should parseTo(FrameHeader(Opcode.Continuation, None, Long.MaxValue, fin = false)) + } + } + } + + "a partial frame" in { + val header = + b"""0000 # flags + xxxx=1 # opcode + 0 # mask? + xxxxxxx=5 # length + """ + val data = ByteString("abc") + + (header ++ data) should parseTo( + FrameStart( + FrameHeader(Opcode.Text, None, 5, fin = false), + data)) + } + "a partial frame of total size > Int.MaxValue" in { + val header = + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=00000000ffffffff # length64 + """ + val data = ByteString("abc", "ASCII") + + Seq(header, data) should parseMultipleTo( + FrameStart(FrameHeader(Opcode.Continuation, None, 0xFFFFFFFFL, fin = false), ByteString.empty), + FrameData(data, lastPart = false)) + } + "a full frame" in { + val header = + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=5 # length + """ + val data = ByteString("abcde") + + (header ++ data) should parseTo( + FrameStart(FrameHeader(Opcode.Continuation, None, 5, fin = false), data)) + } + "a full frame in chunks" in { + val header = + b"""0000 # flags + xxxx=1 # opcode + 0 # mask? + xxxxxxx=5 # length + """ + val data1 = ByteString("abc") + val data2 = ByteString("de") + + val expectedHeader = FrameHeader(Opcode.Text, None, 5, fin = false) + Seq(header, data1, data2) should parseMultipleTo( + FrameStart(expectedHeader, ByteString.empty), + FrameData(data1, lastPart = false), + FrameData(data2, lastPart = true)) + } + "several frames" in { + val header1 = + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=5 # length + """ + + val header2 = + b"""0000 # flags + xxxx=0 # opcode + 0 # mask? + xxxxxxx=7 # length + """ + + val data1 = ByteString("abcde") + val data2 = ByteString("abc") + + (header1 ++ data1 ++ header2 ++ data2) should parseTo( + FrameStart(FrameHeader(Opcode.Continuation, None, 5, fin = false), data1), + FrameStart(FrameHeader(Opcode.Continuation, None, 7, fin = false), data2)) + } + } + + def parseTo(events: FrameEvent*): Matcher[ByteString] = + parseMultipleTo(events: _*).compose(Seq(_)) + + def parseMultipleTo(events: FrameEvent*): Matcher[Seq[ByteString]] = + equal(events).matcher[Seq[FrameEvent]].compose { + (chunks: Seq[ByteString]) ⇒ + val result = parseToEvents(chunks) + result shouldEqual events + val rendered = renderToByteString(result) + rendered shouldEqual chunks.reduce(_ ++ _) + result + } + + def parseToEvents(bytes: Seq[ByteString]): immutable.Seq[FrameEvent] = + Source(bytes.toVector).transform(newParser).runFold(Vector.empty[FrameEvent])(_ :+ _) + .awaitResult(1.second) + def renderToByteString(events: immutable.Seq[FrameEvent]): ByteString = + Source(events).transform(newRenderer).runFold(ByteString.empty)(_ ++ _) + .awaitResult(1.second) + + protected def newParser(): Stage[ByteString, FrameEvent] = new FrameEventParser + protected def newRenderer(): Stage[FrameEvent, ByteString] = new FrameEventRenderer + + import scala.language.implicitConversions + implicit def headerToEvent(header: FrameHeader): FrameEvent = + FrameStart(header, ByteString.empty) +} From 880733eb3d4cbe96c5346928f49e815af567f66e Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:19:32 +0200 Subject: [PATCH 09/13] =htc make headAndTail work if prefixAndTail returns an empty prefix --- .../src/main/scala/akka/http/util/package.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/util/package.scala b/akka-http-core/src/main/scala/akka/http/util/package.scala index 8cc9d5ecbe..0b179b8976 100644 --- a/akka-http-core/src/main/scala/akka/http/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/util/package.scala @@ -43,13 +43,21 @@ package object util { private[http] implicit class SourceWithHeadAndTail[T, Mat](val underlying: Source[Source[T, Any], Mat]) extends AnyVal { def headAndTail: Source[(T, Source[T, Unit]), Mat] = - underlying.map { _.prefixAndTail(1).map { case (prefix, tail) ⇒ (prefix.head, tail) } } + underlying.map { + _.prefixAndTail(1) + .filter(_._1.nonEmpty) + .map { case (prefix, tail) ⇒ (prefix.head, tail) } + } .flatten(FlattenStrategy.concat) } private[http] implicit class FlowWithHeadAndTail[In, Out, Mat](val underlying: Flow[In, Source[Out, Any], Mat]) extends AnyVal { def headAndTail: Flow[In, (Out, Source[Out, Unit]), Mat] = - underlying.map { _.prefixAndTail(1).map { case (prefix, tail) ⇒ (prefix.head, tail) } } + underlying.map { + _.prefixAndTail(1) + .filter(_._1.nonEmpty) + .map { case (prefix, tail) ⇒ (prefix.head, tail) } + } .flatten(FlattenStrategy.concat) } From 6dafa445deb7d4645cf65c80a1b35bf871f764ed Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:21:22 +0200 Subject: [PATCH 10/13] =htc #16887 implement high-level server-side Websocket API --- .../akka/http/engine/ws/FrameHandler.scala | 200 ++++ .../akka/http/engine/ws/FrameOutHandler.scala | 132 +++ .../scala/akka/http/engine/ws/Masking.scala | 71 ++ .../engine/ws/MessageToFrameRenderer.scala | 37 + .../scala/akka/http/engine/ws/Websocket.scala | 179 ++++ .../akka/http/engine/ws/MessageSpec.scala | 966 ++++++++++++++++++ 6 files changed, 1585 insertions(+) create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala new file mode 100644 index 0000000000..bde765f3d7 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.scaladsl.Flow +import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, StatefulStage } +import akka.util.ByteString +import Protocol.Opcode + +import scala.util.control.NonFatal + +/** + * The frame handler validates frames, multiplexes data to the user handler or to the bypass and + * UTF-8 decodes text frames. + * + * INTERNAL API + */ +private[http] object FrameHandler { + def create(server: Boolean): Flow[FrameEvent, Either[BypassEvent, MessagePart], Unit] = + Flow[FrameEvent].transform(() ⇒ new HandlerStage(server)) + + class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Either[BypassEvent, MessagePart]] { + type Ctx = Context[Either[BypassEvent, MessagePart]] + def initial: State = Idle + + object Idle extends StateWithControlFrameHandling { + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + (start.header.opcode, start.isFullMessage) match { + case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true)) + case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start) + case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start) + case x ⇒ protocolError() + } + } + + class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) { + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last) + } + class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) { + val decoder = Utf8Decoder.create() + + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = + TextMessagePart(decoder.decode(data, endOfInput = last).get, last) + } + + abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling { + var expectFirstHeader = true + var finSeen = false + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart + + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = { + if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected + || start.header.opcode == Opcode.Continuation) { // further ones continuations + expectFirstHeader = false + + if (start.header.fin) finSeen = true + publish(start) + } else protocolError() + } + override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data) + + private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective = + try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart)) + catch { + case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData) + } + } + class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState { + var data = _data + + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = { + this.data ++= data.data + if (data.lastPart) handleControlFrame(opcode, this.data, nextState) + else ctx.pull() + } + } + object Closed extends State { + def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective = + ctx.pull() // ignore + } + + def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { + become(newState) + current.onPush(part, ctx) + } + + /** Returns a SyncDirective if it handled the message */ + def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match { + case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError()) + case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError()) + case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError()) + case _ ⇒ None + } + + def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = { + become(nextState) + opcode match { + case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true)) + case Opcode.Pong ⇒ + // ignore unsolicited Pong frame + ctx.pull() + case Opcode.Close ⇒ + val closeCode = FrameEventParser.parseCloseCode(data) + emit(Iterator(Left(PeerClosed(closeCode)), Right(PeerClosed(closeCode))), ctx, WaitForPeerTcpClose) + case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode") + } + } + private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = { + assert(!start.isFullMessage) + become(new CollectingControlFrame(start.header.opcode, start.data, nextState)) + ctx.pull() + } + + private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective = + if (part.last) emit(Iterator(Right(part), Right(MessageEnd)), ctx, Idle) + else ctx.push(Right(part)) + private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective = + ctx.push(Left(DirectAnswer(frame))) + + private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective = + closeWithCode(Protocol.CloseCodes.ProtocolError, reason) + + private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = + emit( + Iterator( + Left(ActivelyCloseWithCode(Some(closeCode), reason)), + Right(ActivelyCloseWithCode(Some(closeCode), reason))), ctx, CloseAfterPeerClosed) + + object CloseAfterPeerClosed extends State { + def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective = + elem match { + case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ + become(WaitForPeerTcpClose) + ctx.push(Left(PeerClosed(FrameEventParser.parseCloseCode(data)))) + + case _ ⇒ ctx.pull() // ignore all other data + } + } + object WaitForPeerTcpClose extends State { + def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective = + ctx.pull() // ignore + } + + abstract class StateWithControlFrameHandling extends BetweenFrameState { + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective + + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + validateHeader(start.header).getOrElse { + if (start.header.opcode.isControl) + if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this) + else collectControlFrame(start, this) + else handleRegularFrameStart(start) + } + } + abstract class BetweenFrameState extends ImplicitContextState { + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = + throw new IllegalStateException("Expected FrameStart") + } + abstract class InFrameState extends ImplicitContextState { + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + throw new IllegalStateException("Expected FrameData") + } + abstract class ImplicitContextState extends State { + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective + + def onPush(part: FrameEvent, ctx: Ctx): SyncDirective = + part match { + case data: FrameData ⇒ handleFrameData(data)(ctx) + case start: FrameStart ⇒ handleFrameStart(start)(ctx) + } + } + } + + sealed trait MessagePart { + def isMessageEnd: Boolean + } + sealed trait MessageDataPart extends MessagePart { + def isMessageEnd = false + def last: Boolean + } + final case class TextMessagePart(data: String, last: Boolean) extends MessageDataPart + final case class BinaryMessagePart(data: ByteString, last: Boolean) extends MessageDataPart + case object MessageEnd extends MessagePart { + def isMessageEnd: Boolean = true + } + final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent { + def isMessageEnd: Boolean = true + } + + sealed trait BypassEvent + final case class DirectAnswer(frame: FrameStart) extends BypassEvent + final case class ActivelyCloseWithCode(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent { + def isMessageEnd: Boolean = true + } + case object UserHandlerCompleted extends BypassEvent + case class UserHandlerErredOut(cause: Throwable) extends BypassEvent +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala new file mode 100644 index 0000000000..4f560cca33 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import scala.concurrent.duration.FiniteDuration + +import akka.stream.stage._ +import akka.http.util.Timestamp +import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer } +import Websocket.Tick + +/** + * Implements the transport connection close handling at the end of the pipeline. + * + * INTERNAL API + */ +private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[AnyRef, FrameStart] { + def initial: StageState[AnyRef, FrameStart] = Idle + def closeTimeout: Timestamp = Timestamp.now + _closeTimeout + + object Idle extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case start: FrameStart ⇒ ctx.push(start) + case DirectAnswer(frame) ⇒ ctx.push(frame) + case PeerClosed(code, reason) if !code.exists(Protocol.CloseCodes.isError) ⇒ + // let user complete it, FIXME: maybe make configurable? immediately, or timeout + become(new WaitingForUserHandlerClosed(FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason))) + ctx.pull() + case PeerClosed(code, reason) ⇒ + val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason) + if (serverSide) ctx.pushAndFinish(closeFrame) + else { + become(new WaitingForTransportClose) + ctx.push(closeFrame) + } + case ActivelyCloseWithCode(code, reason) ⇒ + val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason) + become(new WaitingForPeerCloseFrame()) + ctx.push(closeFrame) + case UserHandlerCompleted ⇒ + become(new WaitingForPeerCloseFrame()) + ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular)) + case Tick ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = { + become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))) + ctx.absorbTermination() + } + } + + /** + * peer has closed, we want to wait for user handler to close as well + */ + class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case UserHandlerCompleted ⇒ + if (serverSide) ctx.pushAndFinish(closeFrame) + else { + become(new WaitingForTransportClose()) + ctx.push(closeFrame) + } + case start: FrameStart ⇒ ctx.push(start) + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = + ctx.fail(new IllegalStateException("Mustn't complete before user has completed")) + } + + /** + * we have sent out close frame and wait for peer to sent its close frame + */ + class WaitingForPeerCloseFrame(timeout: Timestamp = closeTimeout) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case Tick ⇒ + if (timeout.isPast) ctx.finish() + else ctx.pull() + case PeerClosed(code, reason) ⇒ + if (serverSide) ctx.finish() + else { + become(new WaitingForTransportClose()) + ctx.pull() + } + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish() + } + + /** + * Both side have sent their close frames, server should close the connection first + */ + class WaitingForTransportClose(timeout: Timestamp = closeTimeout) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case Tick ⇒ + if (timeout.isPast) ctx.finish() + else ctx.pull() + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish() + } + + /** If upstream has already failed we just wait to be able to deliver our close frame and complete */ + class SendOutCloseFrameAndComplete(closeFrame: FrameStart) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = + ctx.fail(new IllegalStateException("Didn't expect push after completion")) + + override def onPull(ctx: Context[FrameStart]): SyncDirective = + ctx.pushAndFinish(closeFrame) + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = + ctx.absorbTermination() + } + + trait CompletionHandlingState extends State { + def onComplete(ctx: Context[FrameStart]): TerminationDirective + } + + override def onUpstreamFinish(ctx: Context[FrameStart]): TerminationDirective = + current.asInstanceOf[CompletionHandlingState].onComplete(ctx) + + override def onUpstreamFailure(cause: scala.Throwable, ctx: Context[FrameStart]): TerminationDirective = cause match { + case p: ProtocolException ⇒ + become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.ProtocolError))) + ctx.absorbTermination() + case _ ⇒ super.onUpstreamFailure(cause, ctx) + } +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala new file mode 100644 index 0000000000..dbf7d85338 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.scaladsl.Flow +import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage } + +import scala.util.Random + +/** + * Implements Websocket Frame masking. + * + * INTERNAL API + */ +private[http] object Masking { + def maskIf(condition: Boolean, maskRandom: () ⇒ Random): Flow[FrameEvent, FrameEvent, Unit] = + if (condition) Flow[FrameEvent].transform(() ⇒ new Masking(maskRandom())) // new random per materialization + else Flow[FrameEvent] + def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] = + if (condition) Flow[FrameEvent].transform(() ⇒ new Unmasking()) + else Flow[FrameEvent] + + class Masking(random: Random) extends Masker { + def extractMask(header: FrameHeader): Int = random.nextInt() + def setNewMask(header: FrameHeader, mask: Int): FrameHeader = { + if (header.mask.isDefined) throw new ProtocolException("Frame mustn't already be masked") + header.copy(mask = Some(mask)) + } + } + class Unmasking extends Masker { + def extractMask(header: FrameHeader): Int = header.mask match { + case Some(mask) ⇒ mask + case None ⇒ throw new ProtocolException("Frame wasn't masked") + } + def setNewMask(header: FrameHeader, mask: Int): FrameHeader = header.copy(mask = None) + } + + /** Implements both masking and unmasking which is mostly symmetric (because of XOR) */ + abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] { + def extractMask(header: FrameHeader): Int + def setNewMask(header: FrameHeader, mask: Int): FrameHeader + + def initial: State = Idle + + object Idle extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = + part match { + case start @ FrameStart(header, data) ⇒ + if (header.length == 0) ctx.push(part) + else { + val mask = extractMask(header) + become(new Running(mask)) + current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + } + } + } + class Running(initialMask: Int) extends State { + var mask = initialMask + + def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = { + if (part.lastPart) become(Idle) + + val (masked, newMask) = FrameEventParser.mask(part.data, mask) + mask = newMask + ctx.push(part.withData(data = masked)) + } + } + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala new file mode 100644 index 0000000000..3d3c4c8d6b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.util.ByteString +import akka.stream.scaladsl.{ FlattenStrategy, Source, Flow } + +import Protocol.Opcode +import akka.http.model.ws._ + +/** + * Renders messages to full frames. + * + * INTERNAL API + */ +private[http] object MessageToFrameRenderer { + def create(serverSide: Boolean): Flow[Message, FrameStart, Unit] = { + def strictFrames(opcode: Opcode, data: ByteString): Source[FrameStart, _] = + // FIXME: fragment? + Source.single(FrameEvent.fullFrame(opcode, None, data, fin = true)) + + def streamedFrames(opcode: Opcode, data: Source[ByteString, _]): Source[FrameStart, _] = + Source.single(FrameEvent.empty(opcode, fin = false)) ++ + data.map(FrameEvent.fullFrame(Opcode.Continuation, None, _, fin = false)) ++ + Source.single(FrameEvent.emptyLastContinuationFrame) + + Flow[Message] + .map { + case BinaryMessage.Strict(data) ⇒ strictFrames(Opcode.Binary, data) + case BinaryMessage.Streamed(data) ⇒ streamedFrames(Opcode.Binary, data) + case TextMessage.Strict(text) ⇒ strictFrames(Opcode.Text, ByteString(text, "UTF-8")) + case TextMessage.Streamed(text) ⇒ streamedFrames(Opcode.Text, text.transform(() ⇒ new Utf8Encoder)) + }.flatten(FlattenStrategy.Concat()) + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala new file mode 100644 index 0000000000..e19827d0d1 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import java.security.SecureRandom + +import scala.concurrent.duration._ + +import akka.stream.{ OperationAttributes, FanOutShape2, FanInShape3, Inlet } +import akka.stream.scaladsl._ +import akka.stream.stage._ +import FlexiRoute.{ DemandFrom, DemandFromAny, RouteLogic } +import FlexiMerge.MergeLogic + +import akka.http.util._ +import akka.http.model.ws._ + +/** + * INTERNAL API + */ +private[http] object Websocket { + import FrameHandler._ + + def handleMessages[T](messageHandler: Flow[Message, Message, T], + serverSide: Boolean = true, + closeTimeout: FiniteDuration = 3.seconds): Flow[FrameEvent, FrameEvent, Unit] = { + /** Completes this branch of the flow if no more messages are expected and converts close codes into errors */ + class PrepareForUserHandler extends PushStage[MessagePart, MessagePart] { + def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match { + case PeerClosed(code, reason) ⇒ + if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code")) + else ctx.finish() + case ActivelyCloseWithCode(code, reason) ⇒ + if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Closing connection with error code $code")) + else ctx.finish() + case x ⇒ ctx.push(x) + } + } + + /** Collects user-level API messages from MessageDataParts */ + val collectMessage: Flow[Source[MessageDataPart, Unit], Message, Unit] = + Flow[Source[MessageDataPart, Unit]] + .headAndTail + .map { + case (TextMessagePart(text, true), remaining) ⇒ + TextMessage.Strict(text) + case (first @ TextMessagePart(text, false), remaining) ⇒ + TextMessage.Streamed( + (Source.single(first) ++ remaining) + .collect { + case t: TextMessagePart if t.data.nonEmpty ⇒ t.data + }) + case (BinaryMessagePart(data, true), remaining) ⇒ + BinaryMessage.Strict(data) + case (first @ BinaryMessagePart(data, false), remaining) ⇒ + BinaryMessage.Streamed( + (Source.single(first) ++ remaining) + .collect { + case t: BinaryMessagePart if t.data.nonEmpty ⇒ t.data + }) + } + + /** Lifts onComplete and onError into events to be processed in the FlexiMerge */ + class LiftCompletions extends StatefulStage[FrameStart, AnyRef] { + def initial: StageState[FrameStart, AnyRef] = SteadyState + + object SteadyState extends State { + def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = ctx.push(elem) + } + class CompleteWith(last: AnyRef) extends State { + def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = + ctx.fail(new IllegalStateException("No push expected")) + + override def onPull(ctx: Context[AnyRef]): SyncDirective = ctx.pushAndFinish(last) + } + + override def onUpstreamFinish(ctx: Context[AnyRef]): TerminationDirective = { + become(new CompleteWith(UserHandlerCompleted)) + ctx.absorbTermination() + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[AnyRef]): TerminationDirective = { + become(new CompleteWith(UserHandlerErredOut(cause))) + ctx.absorbTermination() + } + } + + lazy val userFlow = + Flow[MessagePart] + .transform(() ⇒ new PrepareForUserHandler) + .splitWhen(_.isMessageEnd) // FIXME using splitAfter from #16885 would simplify protocol a lot + .map(_.collect { + case m: MessageDataPart ⇒ m + }) + .via(collectMessage) + .via(messageHandler) + .via(MessageToFrameRenderer.create(serverSide)) + .transform(() ⇒ new LiftCompletions) + + /** + * Distributes output from the FrameHandler into bypass and userFlow. + */ + object BypassRouter + extends FlexiRoute[Either[BypassEvent, MessagePart], FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]](new FanOutShape2("bypassRouter"), OperationAttributes.name("bypassRouter")) { + def createRouteLogic(s: FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]): RouteLogic[Either[BypassEvent, MessagePart]] = + new RouteLogic[Either[BypassEvent, MessagePart]] { + def initialState: State[_] = State(DemandFromAny(s)) { (ctx, out, ev) ⇒ + ev match { + case Left(_) ⇒ + State(DemandFrom(s.out0)) { (ctx, _, ev) ⇒ // FIXME: #17004 + ctx.emit(s.out0)(ev.left.get) + initialState + } + case Right(_) ⇒ + State(DemandFrom(s.out1)) { (ctx, _, ev) ⇒ + ctx.emit(s.out1)(ev.right.get) + initialState + } + } + } + + override def initialCompletionHandling: CompletionHandling = super.initialCompletionHandling.copy( + onDownstreamFinish = { (ctx, out) ⇒ + if (out == s.out0) ctx.finish() + SameState + }) + } + } + /** + * Merges bypass, user flow and tick source for consumption in the FrameOutHandler. + */ + object BypassMerge extends FlexiMerge[AnyRef, FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]](new FanInShape3("bypassMerge"), OperationAttributes.name("bypassMerge")) { + def createMergeLogic(s: FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]): MergeLogic[AnyRef] = + new MergeLogic[AnyRef] { + def initialState: State[_] = Idle + + lazy val Idle = State[AnyRef](FlexiMerge.ReadAny(s.in0.asInstanceOf[Inlet[AnyRef]], s.in1.asInstanceOf[Inlet[AnyRef]], s.in2.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, elem) ⇒ + ctx.emit(elem) + SameState + } + + override def initialCompletionHandling: CompletionHandling = + CompletionHandling( + onUpstreamFinish = { (ctx, in) ⇒ + if (in == s.in0) ctx.finish() + SameState + }, + onUpstreamFailure = { (ctx, in, cause) ⇒ + if (in == s.in0) ctx.fail(cause) + SameState + }) + } + } + + lazy val bypassAndUserHandler: Flow[Either[BypassEvent, MessagePart], AnyRef, Unit] = + Flow(BypassRouter, Source(closeTimeout, closeTimeout, Tick), BypassMerge)((_, _, _) ⇒ ()) { implicit b ⇒ + (split, tick, merge) ⇒ + import FlowGraph.Implicits._ + + split.out0 ~> merge.in0 + split.out1 ~> userFlow ~> merge.in1 + tick.outlet ~> merge.in2 + + (split.in, merge.out) + } + + Flow[FrameEvent] + .via(Masking.unmaskIf(serverSide)) + .via(FrameHandler.create(server = serverSide)) + .mapConcat(x ⇒ x :: x :: Nil) // FIXME: #17004 + .via(bypassAndUserHandler) + .transform(() ⇒ new FrameOutHandler(serverSide, closeTimeout)) + .via(Masking.maskIf(!serverSide, () ⇒ new SecureRandom())) + } + + object Tick + case object SwitchToWebsocketToken +} diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala new file mode 100644 index 0000000000..bbe635db35 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala @@ -0,0 +1,966 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import scala.annotation.tailrec +import scala.concurrent.duration._ + +import akka.stream.FlowShape +import akka.stream.scaladsl._ +import akka.stream.testkit.StreamTestKit +import akka.util.ByteString +import org.scalatest.{ Matchers, FreeSpec } + +import scala.util.Random + +import akka.http.model.ws._ +import Protocol.Opcode + +class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { + "The Websocket implementation should" - { + "collect messages from frames" - { + "for binary messages" - { + "for an empty message" in new ClientTestSetup { + val input = frameHeader(Opcode.Binary, 0, fin = true) + + pushInput(input) + expectMessage(BinaryMessage.Strict(ByteString.empty)) + } + "for one complete, strict, single frame message" in new ClientTestSetup { + val data = ByteString("abcdef", "ASCII") + val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data + + pushInput(input) + expectMessage(BinaryMessage.Strict(data)) + } + "for a partial frame" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header = frameHeader(Opcode.Binary, 6, fin = true) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + } + "for a frame split up into parts" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header = frameHeader(Opcode.Binary, 6, fin = true) + + pushInput(header) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + pushInput(data1) + sub.expectNext(data1) + + val data2 = ByteString("def", "ASCII") + pushInput(data2) + sub.expectNext(data2) + sub.expectComplete() + } + + "for a message split into several frames" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header1 = frameHeader(Opcode.Binary, 3, fin = false) + + pushInput(header1 ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + + val header2 = frameHeader(Opcode.Continuation, 4, fin = true) + val data2 = ByteString("defg", "ASCII") + pushInput(header2 ++ data2) + sub.expectNext(data2) + sub.expectComplete() + } + "for several messages" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header1 = frameHeader(Opcode.Binary, 3, fin = false) + + pushInput(header1 ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + + val header2 = frameHeader(Opcode.Continuation, 4, fin = true) + val header3 = frameHeader(Opcode.Binary, 2, fin = true) + val data2 = ByteString("defg", "ASCII") + val data3 = ByteString("h") + pushInput(header2 ++ data2 ++ header3 ++ data3) + sub.expectNext(data2) + sub.expectComplete() + + val BinaryMessage.Streamed(dataSource2) = expectMessage() + val sub2 = StreamTestKit.SubscriberProbe[ByteString] + dataSource2.runWith(Sink(sub2)) + val s2 = sub2.expectSubscription() + s2.request(2) + sub2.expectNext(data3) + + val data4 = ByteString("i") + pushInput(data4) + sub2.expectNext(data4) + sub2.expectComplete() + } + "unmask masked input on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val (data, _) = maskedASCII("abcdef", mask) + val data1 = data.take(3) + val data2 = data.drop(3) + val header = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("abc", "ASCII")) + + pushInput(data2) + sub.expectNext(ByteString("def", "ASCII")) + sub.expectComplete() + } + } + "for text messages" - { + "empty message" in new ClientTestSetup { + val input = frameHeader(Opcode.Text, 0, fin = true) + + pushInput(input) + expectMessage(TextMessage.Strict("")) + } + "decode complete, strict frame from utf8" in new ClientTestSetup { + val msg = "äbcdef€\uffff" + val data = ByteString(msg, "UTF-8") + val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data + + pushInput(input) + expectMessage(TextMessage.Strict(msg)) + } + "decode utf8 as far as possible for partial frame" in new ClientTestSetup { + val msg = "bäcdef€" + val data = ByteString(msg, "UTF-8") + val data0 = data.slice(0, 2) + val data1 = data.slice(2, 5) + val data2 = data.slice(5, data.size) + val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data0 + + pushInput(input) + val TextMessage.Streamed(parts) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[String] + parts.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(4) + sub.expectNext("b") + + pushInput(data1) + sub.expectNext("äcd") + } + "decode utf8 with code point split across frames" in new ClientTestSetup { + val msg = "äbcdef€" + val data = ByteString(msg, "UTF-8") + val data0 = data.slice(0, 1) + val data1 = data.slice(1, data.size) + val header0 = frameHeader(Opcode.Text, data0.size, fin = false) + + pushInput(header0 ++ data0) + val TextMessage.Streamed(parts) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[String] + parts.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(4) + sub.expectNoMsg(100.millis) + + val header1 = frameHeader(Opcode.Continuation, data1.size, fin = true) + pushInput(header1 ++ data1) + sub.expectNext("äbcdef€") + } + "unmask masked input on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val (data, _) = maskedUTF8("äbcdef€", mask) + val data1 = data.take(3) + val data2 = data.drop(3) + val header = frameHeader(Opcode.Binary, data.size, fin = true, mask = Some(mask)) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("äb", "UTF-8")) + + pushInput(data2) + sub.expectNext(ByteString("cdef€", "UTF-8")) + sub.expectComplete() + } + } + } + "render frames from messages" - { + "for binary messages" - { + "for a short strict message" in new ServerTestSetup { + val data = ByteString("abcdef", "ASCII") + val msg = BinaryMessage.Strict(data) + netOutSub.request(5) + pushMessage(msg) + + expectFrameOnNetwork(Opcode.Binary, data, fin = true) + } + "for a strict message larger than configured maximum frame size" in pending + "for a streamed message" in new ServerTestSetup { + val data = ByteString("abcdefg", "ASCII") + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + netOutSub.request(6) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val data1 = data.take(3) + val data2 = data.drop(3) + + sub.sendNext(data1) + expectFrameOnNetwork(Opcode.Continuation, data1, fin = false) + + sub.sendNext(data2) + expectFrameOnNetwork(Opcode.Continuation, data2, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + "for a streamed message with a chunk being larger than configured maximum frame size" in pending + "and mask input on the client side" in new ClientTestSetup { + val data = ByteString("abcdefg", "ASCII") + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + netOutSub.request(7) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val data1 = data.take(3) + val data2 = data.drop(3) + + sub.sendNext(data1) + expectMaskedFrameOnNetwork(Opcode.Continuation, data1, fin = false) + + sub.sendNext(data2) + expectMaskedFrameOnNetwork(Opcode.Continuation, data2, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + } + "for text messages" - { + "for a short strict message" in new ServerTestSetup { + val text = "äbcdef" + val msg = TextMessage.Strict(text) + netOutSub.request(5) + pushMessage(msg) + + expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true) + } + "for a strict message larger than configured maximum frame size" in pending + "for a streamed message" in new ServerTestSetup { + val text = "äbcd€fg" + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(6) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false) + + val text1 = text.take(3) + val text1Bytes = ByteString(text1, "UTF-8") + val text2 = text.drop(3) + val text2Bytes = ByteString(text2, "UTF-8") + + sub.sendNext(text1) + expectFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false) + + sub.sendNext(text2) + expectFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + "for a streamed message don't convert half surrogate pairs naively" in new ServerTestSetup { + val gclef = "𝄞" + gclef.size shouldEqual 2 + + // split up the code point + val half1 = gclef.take(1) + val half2 = gclef.drop(1) + println(half1(0).toInt.toHexString) + println(half2(0).toInt.toHexString) + + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(6) + + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false) + sub.sendNext(half1) + + expectNoNetworkData() + sub.sendNext(half2) + expectFrameOnNetwork(Opcode.Continuation, ByteString(gclef, "utf8"), fin = false) + } + "for a streamed message with a chunk being larger than configured maximum frame size" in pending + "and mask input on the client side" in new ClientTestSetup { + val text = "abcdefg" + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(5) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameOnNetwork(Opcode.Text, ByteString.empty, fin = false) + + val text1 = text.take(3) + val text1Bytes = ByteString(text1, "UTF-8") + val text2 = text.drop(3) + val text2Bytes = ByteString(text2, "UTF-8") + + sub.sendNext(text1) + expectMaskedFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false) + + sub.sendNext(text2) + expectMaskedFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + } + } + "supply automatic low-level websocket behavior" - { + "respond to ping frames unmasking them on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1 + + pushInput(input) + netOutSub.request(5) + expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true) + } + "respond to ping frames masking them on the client side" in new ClientTestSetup { + val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef") + + pushInput(input) + netOutSub.request(5) + expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true) + } + "respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup { + // receive multi-frame message + // receive and handle interleaved ping frame + // concurrently send out messages from handler + val mask1 = Random.nextInt() + val input1 = frameHeader(Opcode.Binary, 3, fin = false, mask = Some(mask1)) ++ maskedASCII("123", mask1)._1 + pushInput(input1) + + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("123", "ASCII")) + + val outPub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(outPub)) + netOutSub.request(10) + pushMessage(msg) + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val outSub = outPub.expectSubscription() + val outData1 = ByteString("abc", "ASCII") + outSub.sendNext(outData1) + expectFrameOnNetwork(Opcode.Continuation, outData1, fin = false) + + val pingMask = Random.nextInt() + val pingData = maskedASCII("pling", pingMask)._1 + val pingData0 = pingData.take(3) + val pingData1 = pingData.drop(3) + pushInput(frameHeader(Opcode.Ping, 5, fin = true, mask = Some(pingMask)) ++ pingData0) + expectNoNetworkData() + pushInput(pingData1) + expectFrameOnNetwork(Opcode.Pong, ByteString("pling", "ASCII"), fin = true) + + val outData2 = ByteString("def", "ASCII") + outSub.sendNext(outData2) + expectFrameOnNetwork(Opcode.Continuation, outData2, fin = false) + + outSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + val mask2 = Random.nextInt() + val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1 + pushInput(input2) + sub.expectNext(ByteString("456", "ASCII")) + sub.expectComplete() + } + "don't respond to unsolicited pong frames" in new ClientTestSetup { + val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef") + pushInput(data) + netOutSub.request(5) + expectNoNetworkData() + } + } + "provide close behavior" - { + "after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup { + netInSub.expectRequest() + netOutSub.request(20) + messageOutSub.request(20) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + + netIn.expectNoMsg(1.second) // especially the cancellation not yet + expectNoNetworkData() + messageOutSub.sendComplete() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + netInSub.expectCancellation() + } + "after receiving close frame without close code" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 0, fin = true)) + messageIn.expectComplete() + + messageOutSub.sendComplete() + // especially mustn't be Procotol.CloseCodes.NoCodePresent + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + netInSub.expectCancellation() + } + "after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup { + netOutSub.request(20) + messageOutSub.request(20) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + + // sending another message is allowed before closing (inherently racy) + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + "after receiving regular close frame when fragmented message is still open" in pendingUntilFixed { + new ServerTestSetup { + netOutSub.request(10) + messageInSub.request(10) + + pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false)) + val BinaryMessage.Streamed(dataSource) = messageIn.expectNext() + val inSubscriber = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(inSubscriber)) + val inSub = inSubscriber.expectSubscription() + + val outData = ByteString("def", "ASCII") + val mask = Random.nextInt() + pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1) + inSub.request(5) + inSubscriber.expectNext(outData) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + inSubscriber.expectError() + // truncation of open message + + // sending another message is allowed before closing (inherently racy) + + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + } + "after receiving error close frame" in pending + "after peer closes connection without sending a close frame" in new ServerTestSetup { + netInSub.expectRequest() + netInSub.sendComplete() + + messageIn.expectComplete() + messageOutSub.sendComplete() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + "when user handler closes (simple)" in new ServerTestSetup { + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectNoMsg(1.second) // wait for peer to close regularly + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + netOut.expectComplete() + netInSub.expectCancellation() + } + "when user handler closes main stream and substream only afterwards" in new ServerTestSetup { + netOutSub.request(10) + messageInSub.request(10) + + // send half a message + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + messageOutSub.sendComplete() + expectNoNetworkData() // need to wait for substream to close + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectNoMsg(1.second) // wait for peer to close regularly + + val mask = Random.nextInt() + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + netOut.expectComplete() + netInSub.expectCancellation() + } + "if user handler fails" in pending + "if peer closes with invalid close frame" - { + "close code outside of the valid range" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x")) + + val error = messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + netOut.expectComplete() + netInSub.expectCancellation() + } + "close data of size 1" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x")) + + val error = messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + netOut.expectComplete() + netInSub.expectCancellation() + } + "reason is no valid utf8 data" in pending + } + "timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup { + netInSub.expectRequest() + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectComplete() + netInSub.expectCancellation() + } + "timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup { + netInSub.expectRequest() + + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true)) + expectProtocolErrorOnNetwork() + messageOutSub.sendComplete() + + netOut.expectComplete() + netInSub.expectCancellation() + } + "ignore frames peer sends after close frame" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + + pushInput(frameHeader(Opcode.Binary, 0, fin = true)) + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectComplete() + netInSub.expectCancellation() + } + } + "reject unexpected frames" - { + "reserved bits set" - { + "rsv1" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true)) + expectProtocolErrorOnNetwork() + } + "rsv2" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv2 = true)) + expectProtocolErrorOnNetwork() + } + "rsv3" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv3 = true)) + expectProtocolErrorOnNetwork() + } + } + "highest bit of 64-bit length is set" in new ServerTestSetup { + import BitBuilder._ + + val header = + b"""0000 # flags + xxxx=1 # opcode + 1 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=ffffffff + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=ffffffff # length64 + 00000000 + 00000000 + 00000000 + 00000000 # empty mask + """ + + pushInput(header) + expectProtocolErrorOnNetwork() + } + "control frame bigger than 125 bytes" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Ping, 126, fin = true, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "fragmented control frame" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Ping, 0, fin = false, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "unexpected continuation frame" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Continuation, 0, fin = false, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "unexpected data frame when waiting for continuation" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = false) ++ + frameHeader(Opcode.Binary, 0, fin = false)) + expectProtocolErrorOnNetwork() + } + "invalid utf8 encoding for single frame message" in new ClientTestSetup { + val data = ByteString( + (128 + 64).toByte, // start two byte sequence + 0 // but don't finish it + ) + + pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "invalid utf8 encoding for streamed frame" in new ClientTestSetup { + val data = ByteString( + (128 + 64).toByte, // start two byte sequence + 0 // but don't finish it + ) + + pushInput(frameHeader(Opcode.Text, 0, fin = false) ++ + frameHeader(Opcode.Continuation, 2, fin = true) ++ + data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "truncated utf8 encoding for single frame message" in new ClientTestSetup { + val data = ByteString("€", "UTF-8").take(1) // half a euro + pushInput(frameHeader(Opcode.Text, 1, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "truncated utf8 encoding for streamed frame" in new ClientTestSetup { + val data = ByteString("€", "UTF-8").take(1) // half a euro + pushInput(frameHeader(Opcode.Text, 0, fin = false) ++ + frameHeader(Opcode.Continuation, 1, fin = true) ++ + data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "half a surrogate pair in utf8 encoding for a strict frame" in new ClientTestSetup { + val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8 + pushInput(frameHeader(Opcode.Text, 3, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "half a surrogate pair in utf8 encoding for a streamed frame" in new ClientTestSetup { + val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8 + pushInput(frameHeader(Opcode.Text, 0, fin = false)) + pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data) + + messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "unmasked input on the server side" in new ServerTestSetup { + val data = ByteString("abcdef", "ASCII") + val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data + + pushInput(input) + expectProtocolErrorOnNetwork() + } + "masked input on the client side" in new ClientTestSetup { + val mask = Random.nextInt() + val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1 + + pushInput(input) + expectProtocolErrorOnNetwork() + } + } + "support per-message-compression extension" in pending + } + + class ServerTestSetup extends TestSetup { + protected def serverSide: Boolean = true + } + class ClientTestSetup extends TestSetup { + protected def serverSide: Boolean = false + } + abstract class TestSetup { + protected def serverSide: Boolean + protected def closeTimeout: FiniteDuration = 1.second + + val netIn = StreamTestKit.PublisherProbe[ByteString] + val netOut = StreamTestKit.SubscriberProbe[ByteString] + + val messageIn = StreamTestKit.SubscriberProbe[Message] + val messageOut = StreamTestKit.PublisherProbe[Message] + + val messageHandler: Flow[Message, Message, Unit] = + Flow.wrap { + FlowGraph.partial() { implicit b ⇒ + val in = b.add(Sink(messageIn)) + val out = b.add(Source(messageOut)) + + FlowShape[Message, Message](in, out) + } + } + + Source(netIn) + .via(printEvent("netIn")) + .transform(() ⇒ new FrameEventParser) + .via(Websocket.handleMessages(messageHandler, serverSide, closeTimeout = closeTimeout)) + .via(printEvent("frameRendererIn")) + .transform(() ⇒ new FrameEventRenderer) + .via(printEvent("frameRendererOut")) + .to(Sink(netOut)) + .run() + + val netInSub = netIn.expectSubscription() + val netOutSub = netOut.expectSubscription() + val messageOutSub = messageOut.expectSubscription() + val messageInSub = messageIn.expectSubscription() + + def pushInput(data: ByteString): Unit = { + // TODO: expect/handle request? + netInSub.sendNext(data) + } + def pushMessage(msg: Message): Unit = { + messageOutSub.sendNext(msg) + } + + def expectMessage(message: Message): Unit = { + messageInSub.request(1) + messageIn.expectNext(message) + } + def expectMessage(): Message = { + messageInSub.request(1) + messageIn.expectNext() + } + + var inBuffer = ByteString.empty + @tailrec final def expectNetworkData(bytes: Int): ByteString = + if (inBuffer.size >= bytes) { + val res = inBuffer.take(bytes) + inBuffer = inBuffer.drop(bytes) + res + } else { + netOutSub.request(1) + inBuffer ++= netOut.expectNext() + expectNetworkData(bytes) + } + + def expectNetworkData(data: ByteString): Unit = + expectNetworkData(data.size) shouldEqual data + + def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + expectFrameHeaderOnNetwork(opcode, data.size, fin) + expectNetworkData(data) + } + def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin) + val masked = maskedBytes(data, mask)._1 + expectNetworkData(masked) + } + + /** Returns the mask if any is available */ + def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = { + val (op, l, f, m) = expectFrameHeaderOnNetwork() + op shouldEqual opcode + l shouldEqual length + f shouldEqual fin + m + } + def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = { + val header = expectNetworkData(2) + + val fin = (header(0) & Protocol.FIN_MASK) != 0 + val op = header(0) & Protocol.OP_MASK + + val hasMask = (header(1) & Protocol.MASK_MASK) != 0 + val length7 = header(1) & Protocol.LENGTH_MASK + val length = length7 match { + case 126 ⇒ + val length16Bytes = expectNetworkData(2) + (length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0 + case 127 ⇒ + val length64Bytes = expectNetworkData(8) + (length64Bytes(0) & 0xff).toLong << 56 | + (length64Bytes(1) & 0xff).toLong << 48 | + (length64Bytes(2) & 0xff).toLong << 40 | + (length64Bytes(3) & 0xff).toLong << 32 | + (length64Bytes(4) & 0xff).toLong << 24 | + (length64Bytes(5) & 0xff).toLong << 16 | + (length64Bytes(6) & 0xff).toLong << 8 | + (length64Bytes(7) & 0xff).toLong << 0 + case x ⇒ x + } + val mask = + if (hasMask) { + val maskBytes = expectNetworkData(4) + val mask = + (maskBytes(0) & 0xff) << 24 | + (maskBytes(1) & 0xff) << 16 | + (maskBytes(2) & 0xff) << 8 | + (maskBytes(3) & 0xff) << 0 + Some(mask) + } else None + + (Opcode.forCode(op.toByte), length, fin, mask) + } + + def expectProtocolErrorOnNetwork(): Unit = expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + def expectCloseCodeOnNetwork(expectedCode: Int): Unit = { + val (opcode, length, true, mask) = expectFrameHeaderOnNetwork() + opcode shouldEqual Opcode.Close + length should be >= 2.toLong + + val rawData = expectNetworkData(length.toInt) + val data = mask match { + case Some(m) ⇒ FrameEventParser.mask(rawData, m)._1 + case None ⇒ rawData + } + + val code = ((data(0) & 0xff) << 8) | ((data(1) & 0xff) << 0) + code shouldEqual expectedCode + } + + def expectNoNetworkData(): Unit = + netOut.expectNoMsg(100.millis) + } + + def frameHeader( + opcode: Opcode, + length: Long, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): ByteString = { + def set(should: Boolean, mask: Int): Int = + if (should) mask else 0 + + val flags = + set(fin, Protocol.FIN_MASK) | + set(rsv1, Protocol.RSV1_MASK) | + set(rsv2, Protocol.RSV2_MASK) | + set(rsv3, Protocol.RSV3_MASK) + + val opcodeByte = opcode.code | flags + + require(length >= 0) + val (lengthByteComponent, lengthBytes) = + if (length < 126) (length.toByte, ByteString.empty) + else if (length < 65536) (126.toByte, shortBE(length.toInt)) + else throw new IllegalArgumentException("Only lengths < 65536 allowed in test") + + val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0 + val maskBytes = mask match { + case Some(mask) ⇒ intBE(mask) + case None ⇒ ByteString.empty + } + val lengthByte = lengthByteComponent | maskMask + ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes + } + def closeFrame(closeCode: Int, mask: Boolean): ByteString = + if (mask) { + val mask = Random.nextInt() + frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ + maskedBytes(shortBE(closeCode), mask)._1 + } else + frameHeader(Opcode.Close, 2, fin = true) ++ + shortBE(closeCode) + + def maskedASCII(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "ASCII"), mask) + def maskedUTF8(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "UTF-8"), mask) + def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) = + FrameEventParser.mask(bytes, mask) + + def shortBE(value: Int): ByteString = { + require(value >= 0 && value < 65536, s"Value wasn't in short range: $value") + ByteString( + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) + } + def intBE(value: Int): ByteString = + ByteString( + ((value >> 24) & 0xff).toByte, + ((value >> 16) & 0xff).toByte, + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) + + val trace = false // set to `true` for debugging purposes + def printEvent[T](marker: String): Flow[T, T, Unit] = + if (trace) akka.http.util.printEvent(marker) + else Flow[T] +} From 23b995214955eb4cc50cef23ff85978c27d58485 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:23:00 +0200 Subject: [PATCH 11/13] =htc #16887 integrate websocket pipeline into http server --- .../engine/parsing/HttpRequestParser.scala | 13 +- .../HttpResponseRendererFactory.scala | 10 +- .../engine/server/HttpServerBluePrint.scala | 146 +++++++++++++++++- .../scala/akka/http/engine/ws/Handshake.scala | 117 ++++++++++++++ .../ws/UpgradeToWebsocketLowLevel.scala | 32 ++++ .../UpgradeToWebsocketsResponseHeader.scala | 19 +++ .../akka/http/engine/ws/WebsocketSwitch.scala | 13 ++ .../scala/akka/http/util/StreamUtils.scala | 83 +++++++++- .../rendering/ResponseRendererSpec.scala | 2 +- 9 files changed, 426 insertions(+), 9 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketsResponseHeader.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/WebsocketSwitch.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala index b54cb89e6c..5b10f15ea5 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpRequestParser.scala @@ -5,6 +5,8 @@ package akka.http.engine.parsing import java.lang.{ StringBuilder ⇒ JStringBuilder } +import akka.http.engine.ws.Handshake + import scala.annotation.tailrec import akka.actor.ActorRef import akka.stream.OperationAttributes._ @@ -121,9 +123,18 @@ private[http] class HttpRequestParser(_settings: ParserSettings, if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) { def emitRequestStart(createEntity: Source[RequestOutput, Unit] ⇒ RequestEntity, headers: List[HttpHeader] = headers) = { - val allHeaders = + val allHeaders0 = if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers else headers + + val allHeaders = + if (method == HttpMethods.GET) { + Handshake.isWebsocketUpgrade(headers, hostHeaderPresent) match { + case Some(upgrade) ⇒ upgrade :: allHeaders0 + case None ⇒ allHeaders0 + } + } else allHeaders0 + emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion)) } diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala index 109a5ebfb7..2c6c6411d7 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala @@ -4,6 +4,8 @@ package akka.http.engine.rendering +import akka.http.engine.ws.{ WebsocketSwitch, UpgradeToWebsocketResponseHeader, Handshake } + import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.util.ByteString @@ -21,7 +23,8 @@ import headers._ */ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Server], responseHeaderSizeHint: Int, - log: LoggingAdapter) { + log: LoggingAdapter, + websocketSwitch: Option[WebsocketSwitch] = None) { private val renderDefaultServerHeader: Rendering ⇒ Unit = serverHeader match { @@ -149,6 +152,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser if (renderConnectionHeader) r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf + else if (connHeader != null && connHeader.hasUpgrade && websocketSwitch.isDefined) { + r ~~ connHeader ~~ CrLf + val websocketHeader = headers.collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } + websocketHeader.foreach(header ⇒ websocketSwitch.get.switchToWebsocket(header.handlerFlow)(header.mat)) + } if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } diff --git a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerBluePrint.scala index 32867ecd92..dad93cde5e 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerBluePrint.scala @@ -4,6 +4,13 @@ package akka.http.engine.server +import akka.stream.scaladsl.FlexiMerge.{ ReadAny, MergeLogic } +import akka.stream.scaladsl._ + +import akka.http.engine.ws._ +import akka.stream.scaladsl.FlexiRoute.{ DemandFrom, RouteLogic } +import org.reactivestreams.{ Subscriber, Publisher } + import scala.util.control.NonFatal import akka.util.ByteString import akka.event.LoggingAdapter @@ -18,6 +25,7 @@ import akka.http.engine.TokenSourceActor import akka.http.model._ import akka.http.util._ import ParserOutput._ +import akka.http.engine.ws.Websocket.{ SwitchToWebsocketToken } /** * INTERNAL API @@ -37,7 +45,8 @@ private[http] object HttpServerBluePrint { logParsingError(info withSummaryPrepended "Illegal request header", log, parserSettings.errorLoggingVerbosity) }) - val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log) + val ws = websocketPipeline + val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log, Some(ws)) @volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168 val oneHundredContinueSource = Source.actorPublisher[OneHundredContinue.type] { @@ -94,17 +103,43 @@ private[http] object HttpServerBluePrint { val bypassOneHundredContinueInput = bypassMerge.in1 val bypassApplicationInput = bypassMerge.in2 + // HTTP pipeline requestParsing.outlet ~> bypassFanout.in bypassMerge.out ~> renderer.inlet val requestsIn = (bypassFanout.out(0) ~> requestPreparation).outlet bypassFanout.out(1) ~> bypass ~> bypassInput oneHundredContinueSource ~> bypassOneHundredContinueInput + val http = FlowShape(requestParsing.inlet, renderer.outlet) + + // Websocket pipeline + val websocket = b.add(ws.flow) + + // protocol routing + val protocolRouter = b.add(new WebsocketSwitchRouter()) + val protocolMerge = b.add(new WebsocketMerge) + + protocolRouter.out0 ~> http ~> protocolMerge.in0 + protocolRouter.out1 ~> websocket ~> protocolMerge.in1 + + // protocol switching + val wsSwitchTokenMerge = b.add(new StreamUtils.EagerCloseMerge2[AnyRef]("protocolSwitchWsTokenMerge")) + val switchTokenBroadcast = b.add(Broadcast[SwitchToWebsocketToken.type](2)) + ws.switchSource ~> switchTokenBroadcast.in + switchTokenBroadcast.out(0) ~> wsSwitchTokenMerge.in1 + wsSwitchTokenMerge.out /*~> printEvent[AnyRef]("netIn")*/ ~> protocolRouter.in + switchTokenBroadcast.out(1) ~> protocolMerge.in2 + val netIn = wsSwitchTokenMerge.in0 + + val netOutPrint = b.add( /*printEvent[ByteString]("netOut")*/ Flow[ByteString]) + protocolMerge.out ~> netOutPrint.inlet + val netOut = netOutPrint.outlet BidiShape[HttpResponse, ByteString, ByteString, HttpRequest]( bypassApplicationInput, - renderer.outlet, - requestParsing.inlet, + netOut, + netIn, + requestsIn) } } @@ -240,4 +275,109 @@ private[http] object HttpServerBluePrint { case _ ⇒ ctx.fail(error) } } + + case class WebsocketSetup( + flow: Flow[ByteString, ByteString, Any], + publisherKey: StreamUtils.ReadableCell[Publisher[FrameEvent]], + subscriberKey: StreamUtils.ReadableCell[Subscriber[FrameEvent]], + switchSource: Source[SwitchToWebsocketToken.type, Any]) extends WebsocketSwitch { + @volatile var switchToWebsocketRef: Option[ActorRef] = None + + def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = { + // 1. fill processing hole in the websocket pipeline with user-provided handler + Source(publisherKey.value) + .via(handlerFlow) + .to(Sink(subscriberKey.value)) + .run() + + // 1. and 2. could be racy in which case incoming data could arrive because of 2. before + // the pipeline in 1. has been established. The `PublisherSink`, should then, however, backpressure + // until the subscriber has connected (i.e. 1. has run). + + // 2. flip the switch + switchToWebsocketRef.get ! TokenSourceActor.Trigger + } + } + def websocketPipeline: WebsocketSetup = { + val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]] + val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]] + + val sink = StreamUtils.oneTimePublisherSink[FrameEvent](sinkCell, "frameHandler.in") + val source = StreamUtils.oneTimeSubscriberSource[FrameEvent](sourceCell, "frameHandler.out") + + val flow = + Flow[ByteString] + .transform[FrameEvent](() ⇒ new FrameEventParser) + .via(Flow.wrap(sink, source)((_, _) ⇒ ())) + .transform(() ⇒ new FrameEventRenderer) + + lazy val setup = WebsocketSetup(flow, sinkCell, sourceCell, switchToWebsocketSource) + lazy val switchToWebsocketSource: Source[SwitchToWebsocketToken.type, ActorRef] = + Source.actorPublisher[SwitchToWebsocketToken.type] { + Props { + val actor = new TokenSourceActor(SwitchToWebsocketToken) + setup.switchToWebsocketRef = Some(actor.context.self) + actor + } + } + setup + } + class WebsocketSwitchRouter + extends FlexiRoute[AnyRef, FanOutShape2[AnyRef, ByteString, ByteString]](new FanOutShape2("websocketSplit"), OperationAttributes.name("websocketSplit")) { + + override def createRouteLogic(shape: FanOutShape2[AnyRef, ByteString, ByteString]): RouteLogic[AnyRef] = + new RouteLogic[AnyRef] { + def initialState: State[_] = http + + def http: State[_] = State[Any](DemandFrom(shape.out0)) { (ctx, _, element) ⇒ + element match { + case b: ByteString ⇒ + // route to HTTP processing + ctx.emit(shape.out0)(b) + SameState + + case SwitchToWebsocketToken ⇒ + // switch to websocket protocol + websockets + } + } + def websockets: State[_] = State[Any](DemandFrom(shape.out1)) { (ctx, _, element) ⇒ + // route to Websocket processing + ctx.emit(shape.out1)(element.asInstanceOf[ByteString]) + SameState + } + } + } + class WebsocketMerge extends FlexiMerge[ByteString, FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]](new FanInShape3("websocketMerge"), OperationAttributes.name("websocketMerge")) { + def createMergeLogic(s: FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]): MergeLogic[ByteString] = + new MergeLogic[ByteString] { + def httpIn = s.in0 + def wsIn = s.in1 + def tokenIn = s.in2 + + def initialState: State[_] = http + + def http: State[_] = State[AnyRef](ReadAny(httpIn.asInstanceOf[Inlet[AnyRef]], tokenIn.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, element) ⇒ + element match { + case b: ByteString ⇒ + ctx.emit(b); SameState + case SwitchToWebsocketToken ⇒ + ctx.changeCompletionHandling(closeWhenInCloses(wsIn)) + websockets + } + } + def websockets: State[_] = State[ByteString](ReadAny(httpIn /* otherwise we won't read the websocket upgrade response */ , wsIn)) { (ctx, _, element) ⇒ + ctx.emit(element) + SameState + } + + def closeWhenInCloses(in: Inlet[_]): CompletionHandling = + defaultCompletionHandling.copy(onUpstreamFinish = { (ctx, closingIn) ⇒ + if (closingIn == in) ctx.finish() + SameState + }) + + override def initialCompletionHandling: CompletionHandling = closeWhenInCloses(httpIn) + } + } } diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala new file mode 100644 index 0000000000..e57d07ffde --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.http.model.headers._ +import akka.http.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.model.{ StatusCodes, HttpResponse, HttpProtocol, HttpHeader } +import akka.parboiled2.util.Base64 +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +import scala.reflect.ClassTag + +/** + * Server-side implementation of the Websocket handshake + * + * INTERNAL API + */ +private[http] object Handshake { + val CurrentWebsocketVersion = 13 + + /* + From: http://tools.ietf.org/html/rfc6455#section-4.2.1 + + 1. An HTTP/1.1 or higher GET request, including a "Request-URI" + [RFC2616] that should be interpreted as a /resource name/ + defined in Section 3 (or an absolute HTTP/HTTPS URI containing + the /resource name/). + + 2. A |Host| header field containing the server's authority. + + 3. An |Upgrade| header field containing the value "websocket", + treated as an ASCII case-insensitive value. + + 4. A |Connection| header field that includes the token "Upgrade", + treated as an ASCII case-insensitive value. + + 5. A |Sec-WebSocket-Key| header field with a base64-encoded (see + Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + length. + + 6. A |Sec-WebSocket-Version| header field, with a value of 13. + + 7. Optionally, an |Origin| header field. This header field is sent + by all browser clients. A connection attempt lacking this + header field SHOULD NOT be interpreted as coming from a browser + client. + + 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list + of values indicating which protocols the client would like to + speak, ordered by preference. + + 9. Optionally, a |Sec-WebSocket-Extensions| header field, with a + list of values indicating which extensions the client would like + to speak. The interpretation of this header field is discussed + in Section 9.1. + */ + def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = { + def find[T <: HttpHeader: ClassTag]: Option[T] = + headers.collectFirst { + case t: T ⇒ t + } + + val host = find[Host] + val upgrade = find[Upgrade] + val connection = find[Connection] + val key = find[`Sec-WebSocket-Key`] + val version = find[`Sec-WebSocket-Version`] + val origin = find[Origin] + val protocol = find[`Sec-WebSocket-Protocol`] + val extensions = find[`Sec-WebSocket-Extensions`] + + def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 + + if (upgrade.exists(_.hasWebsocket) && + connection.exists(_.hasUpgrade) && + version.exists(_.hasVersion(CurrentWebsocketVersion)) && + key.exists(k ⇒ isValidKey(k.key))) { + + val header = new UpgradeToWebsocketLowLevel { + def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = + buildResponse(key.get, handlerFlow) + } + Some(header) + } else None + } + + /* + From: http://tools.ietf.org/html/rfc6455#section-4.2.2 + + 1. A Status-Line with a 101 response code as per RFC 2616 + [RFC2616]. Such a response could look like "HTTP/1.1 101 + Switching Protocols". + + 2. An |Upgrade| header field with value "websocket" as per RFC + 2616 [RFC2616]. + + 3. A |Connection| header field with value "Upgrade". + + 4. A |Sec-WebSocket-Accept| header field. The value of this + header field is constructed by concatenating /key/, defined + above in step 4 in Section 4.2.2, with the string "258EAFA5- + E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + concatenated value to obtain a 20-byte value and base64- + encoding (see Section 4 of [RFC4648]) this 20-byte hash. + */ + def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = + HttpResponse( + StatusCodes.SwitchingProtocols, + List( + Upgrade(List(UpgradeProtocol("websocket"))), + Connection(List("upgrade")), + `Sec-WebSocket-Accept`.forKey(key), + UpgradeToWebsocketResponseHeader(handlerFlow))) +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala new file mode 100644 index 0000000000..1c40e5d489 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.http.model.HttpResponse +import akka.http.model.ws.{ Message, UpgradeToWebsocket } +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +/** + * Currently internal API to handle FrameEvents directly. + * + * INTERNAL API + */ +private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHeader("UpgradeToWebsocket") with UpgradeToWebsocket { + /** + * The low-level interface to create Websocket server based on "frames". + * The user needs to handle control frames manually in this case. + * + * Returns a response to return in a request handler that will signal the + * low-level HTTP implementation to upgrade the connection to Websocket and + * use the supplied handler to handle incoming Websocket frames. + * + * INTERNAL API (for now) + */ + private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse + + override def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = + handleFrames(Websocket.handleMessages(handlerFlow)) +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketsResponseHeader.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketsResponseHeader.scala new file mode 100644 index 0000000000..594f06c7f4 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketsResponseHeader.scala @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.http.model.headers.CustomHeader +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +private[http] case class UpgradeToWebsocketResponseHeader(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit val mat: FlowMaterializer) + extends InternalCustomHeader("UpgradeToWebsocketResponseHeader") { +} + +private[http] abstract class InternalCustomHeader(val name: String) extends CustomHeader { + override def suppressRendering: Boolean = true + + def value(): String = "" +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/WebsocketSwitch.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/WebsocketSwitch.scala new file mode 100644 index 0000000000..91f81bd227 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/WebsocketSwitch.scala @@ -0,0 +1,13 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +/** Internal interface between the handshake and the stream setup to evoke the switch to the websocket protocol */ +private[http] trait WebsocketSwitch { + def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit +} diff --git a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala index db8a9e84fc..6503ab9fa1 100644 --- a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala @@ -5,13 +5,17 @@ package akka.http.util import java.io.InputStream -import java.util.concurrent.atomic.AtomicBoolean -import org.reactivestreams.Publisher +import java.util.concurrent.atomic.{ AtomicReference, AtomicBoolean } +import akka.stream.impl.StreamLayout.Module +import akka.stream.impl.{ SourceModule, SinkModule, ActorFlowMaterializerImpl, PublisherSink } +import akka.stream.scaladsl.FlexiMerge._ +import org.reactivestreams.{ Subscription, Processor, Subscriber, Publisher } +import scala.annotation.unchecked.uncheckedVariance import scala.collection.immutable import scala.concurrent.{ ExecutionContext, Future } import akka.util.ByteString import akka.http.model.RequestEntity -import akka.stream.{ FlowMaterializer, impl, OperationAttributes, ActorOperationAttributes } +import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage._ @@ -193,6 +197,79 @@ private[http] object StreamUtils { elem } } + + def oneTimePublisherSink[In](cell: OneTimeWriteCell[Publisher[In]], name: String): Sink[In, Publisher[In]] = + new Sink[In, Publisher[In]](new OneTimePublisherSink(none, SinkShape(new Inlet(name)), cell)) + def oneTimeSubscriberSource[Out](cell: OneTimeWriteCell[Subscriber[Out]], name: String): Source[Out, Subscriber[Out]] = + new Source[Out, Subscriber[Out]](new OneTimeSubscriberSource(none, SourceShape(new Outlet(name)), cell)) + + /** A copy of PublisherSink that allows access to the publisher through the cell but can only materialized once */ + private class OneTimePublisherSink[In](attributes: OperationAttributes, shape: SinkShape[In], cell: OneTimeWriteCell[Publisher[In]]) + extends PublisherSink[In](attributes, shape) { + override def create(context: MaterializationContext): (Subscriber[In], Publisher[In]) = { + val results = super.create(context) + cell.set(results._2) + results + } + override protected def newInstance(shape: SinkShape[In]): SinkModule[In, Publisher[In]] = + new OneTimePublisherSink[In](attributes, shape, cell) + + override def withAttributes(attr: OperationAttributes): Module = + new OneTimePublisherSink[In](attr, amendShape(attr), cell) + } + /** A copy of SubscriberSource that allows access to the subscriber through the cell but can only materialized once */ + private class OneTimeSubscriberSource[Out](val attributes: OperationAttributes, shape: SourceShape[Out], cell: OneTimeWriteCell[Subscriber[Out]]) + extends SourceModule[Out, Subscriber[Out]](shape) { + + override def create(context: MaterializationContext): (Publisher[Out], Subscriber[Out]) = { + val processor = new Processor[Out, Out] { + @volatile private var subscriber: Subscriber[_ >: Out] = null + + override def subscribe(s: Subscriber[_ >: Out]): Unit = subscriber = s + + override def onError(t: Throwable): Unit = subscriber.onError(t) + override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s) + override def onComplete(): Unit = subscriber.onComplete() + override def onNext(t: Out): Unit = subscriber.onNext(t) + } + cell.setValue(processor) + + (processor, processor) + } + + override protected def newInstance(shape: SourceShape[Out]): SourceModule[Out, Subscriber[Out]] = + new OneTimeSubscriberSource[Out](attributes, shape, cell) + override def withAttributes(attr: OperationAttributes): Module = + new OneTimeSubscriberSource[Out](attr, amendShape(attr), cell) + } + + trait ReadableCell[+T] { + def value: T + } + /** A one time settable cell */ + class OneTimeWriteCell[T <: AnyRef] extends AtomicReference[T] with ReadableCell[T] { + def value: T = { + val value = get() + require(value != null, "Value wasn't set yet") + value + } + + def setValue(value: T): Unit = + if (!compareAndSet(null.asInstanceOf[T], value)) + throw new IllegalStateException("Value can be only set once.") + } + + /** A merge for two streams that just forwards all elements and closes the connection eagerly. */ + class EagerCloseMerge2[T](name: String) extends FlexiMerge[T, FanInShape2[T, T, T]](new FanInShape2(name), OperationAttributes.name(name)) { + def createMergeLogic(s: FanInShape2[T, T, T]): MergeLogic[T] = + new MergeLogic[T] { + def initialState: State[T] = State[T](ReadAny(s.in0, s.in1)) { + case (ctx, port, in) ⇒ ctx.emit(in); SameState + } + + override def initialCompletionHandling: CompletionHandling = eagerClose + } + } } /** diff --git a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala index 9e2a4b5739..3077ec6cea 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala @@ -540,7 +540,7 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll override def afterAll() = system.shutdown() class TestSetup(val serverHeader: Option[Server] = Some(Server("akka-http/1.0.0"))) - extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging) { + extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging, None) { def renderTo(expected: String): Matcher[HttpResponse] = renderTo(expected, close = false) compose (ResponseRenderingContext(_)) From cd87dadf54c0903f2257cae6151bf4c3478bf936 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 16:28:56 +0200 Subject: [PATCH 12/13] +htp #16887 add simple websocket support directive to akka-http --- .../directives/WebsocketDirectivesSpec.scala | 45 +++++++++++++++++++ .../scala/akka/http/server/Directives.scala | 1 + .../scala/akka/http/server/Rejection.scala | 5 +++ .../akka/http/server/RejectionHandler.scala | 1 + .../directives/WebsocketDirectives.scala | 27 +++++++++++ 5 files changed, 79 insertions(+) create mode 100644 akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala create mode 100644 akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala new file mode 100644 index 0000000000..2de7c9ca95 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.server.directives + +import akka.http.engine.ws.InternalCustomHeader +import akka.http.model +import akka.http.model.headers.{ Connection, UpgradeProtocol, Upgrade } +import akka.http.model.{ HttpRequest, StatusCodes, HttpResponse } +import akka.http.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.server.{ Route, RoutingSpec } +import akka.http.util.Rendering +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +class WebsocketDirectivesSpec extends RoutingSpec { + "the handleWebsocketMessages directive" should { + "handle websocket requests" in { + Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~> + emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> + check { + status shouldEqual StatusCodes.SwitchingProtocols + } + } + "reject non-websocket requests" in { + Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check { + status shouldEqual StatusCodes.BadRequest + responseAs[String] shouldEqual "Expected Websocket Upgrade request" + } + } + } + + /** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */ + def emulateHttpCore(req: HttpRequest): HttpRequest = + req.header[Upgrade] match { + case Some(upgrade) if upgrade.hasWebsocket ⇒ req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock) + case _ ⇒ req + } + def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = + new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { + def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = + HttpResponse(StatusCodes.SwitchingProtocols) + } +} diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index 5b877d6025..24064c6c0c 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -32,5 +32,6 @@ trait Directives extends RouteConcatenation with RouteDirectives with SchemeDirectives with SecurityDirectives + with WebsocketDirectives object Directives extends Directives diff --git a/akka-http/src/main/scala/akka/http/server/Rejection.scala b/akka-http/src/main/scala/akka/http/server/Rejection.scala index 636d1854e2..ed89828dc7 100644 --- a/akka-http/src/main/scala/akka/http/server/Rejection.scala +++ b/akka-http/src/main/scala/akka/http/server/Rejection.scala @@ -163,6 +163,11 @@ case object AuthorizationFailedRejection extends Rejection */ case class MissingCookieRejection(cookieName: String) extends Rejection +/** + * Rejection created when a websocket request was expected but none was found. + */ +case object ExpectedWebsocketRequestRejection extends Rejection + /** * Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions` * thrown by domain model constructors (e.g. via `require`). diff --git a/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala index 828df9b73e..a467ce6b72 100644 --- a/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala +++ b/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala @@ -206,6 +206,7 @@ object RejectionHandler { val supported = rejections.map(_.supported.value).mkString(" or ") complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported) } + .handle { case ExpectedWebsocketRequestRejection ⇒ complete(BadRequest, "Expected Websocket Upgrade request") } .handle { case ValidationRejection(msg, _) ⇒ complete(BadRequest, msg) } .handle { case x ⇒ sys.error("Unhandled rejection: " + x) } .handleNotFound { complete(NotFound, "The requested resource could not be found.") } diff --git a/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala new file mode 100644 index 0000000000..7871a0f39e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model.ws.{ UpgradeToWebsocket, Message } +import akka.stream.scaladsl.Flow + +trait WebsocketDirectives { + import BasicDirectives._ + import RouteDirectives._ + import HeaderDirectives._ + + /** + * Handles websocket requests with the given handler and rejects other requests with a + * [[ExpectedWebsocketRequestRejection]]. + */ + def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route = + extractFlowMaterializer { implicit mat ⇒ + optionalHeaderValueByType[UpgradeToWebsocket]() { + case Some(upgrade) ⇒ complete(upgrade.handleMessages(handler)) + case None ⇒ reject(ExpectedWebsocketRequestRejection) + } + } +} From 6fef5d534c42ba6ed297239c4cababd285c35c39 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Wed, 22 Apr 2015 10:40:07 +0200 Subject: [PATCH 13/13] +htc #16887 add support for WS application-level subprotocol negotiation --- .../scala/akka/http/engine/ws/Handshake.scala | 24 ++++++++++++------- .../ws/UpgradeToWebsocketLowLevel.scala | 6 ++--- .../http/model/ws/UpgradeToWebsocket.scala | 15 ++++++++---- .../directives/WebsocketDirectivesSpec.scala | 6 ++++- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala index e57d07ffde..b6113d7ef5 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala @@ -11,6 +11,7 @@ import akka.parboiled2.util.Base64 import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow +import scala.collection.immutable.Seq import scala.reflect.ClassTag /** @@ -70,6 +71,7 @@ private[http] object Handshake { val version = find[`Sec-WebSocket-Version`] val origin = find[Origin] val protocol = find[`Sec-WebSocket-Protocol`] + val supportedProtocols = protocol.toList.flatMap(_.protocols) val extensions = find[`Sec-WebSocket-Extensions`] def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 @@ -80,8 +82,13 @@ private[http] object Handshake { key.exists(k ⇒ isValidKey(k.key))) { val header = new UpgradeToWebsocketLowLevel { - def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = - buildResponse(key.get, handlerFlow) + def requestedProtocols: Seq[String] = supportedProtocols + + def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = { + require(subprotocol.forall(chosen ⇒ supportedProtocols.contains(chosen)), + s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") + buildResponse(key.get, handlerFlow, subprotocol) + } } Some(header) } else None @@ -106,12 +113,13 @@ private[http] object Handshake { concatenated value to obtain a 20-byte value and base64- encoding (see Section 4 of [RFC4648]) this 20-byte hash. */ - def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = + def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = HttpResponse( StatusCodes.SwitchingProtocols, - List( - Upgrade(List(UpgradeProtocol("websocket"))), - Connection(List("upgrade")), - `Sec-WebSocket-Accept`.forKey(key), - UpgradeToWebsocketResponseHeader(handlerFlow))) + subprotocol.map(p ⇒ `Sec-WebSocket-Protocol`(Seq(p))).toList ::: + List( + Upgrade(List(UpgradeProtocol("websocket"))), + Connection(List("upgrade")), + `Sec-WebSocket-Accept`.forKey(key), + UpgradeToWebsocketResponseHeader(handlerFlow))) } diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala index 1c40e5d489..1408d41be6 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala @@ -25,8 +25,8 @@ private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHe * * INTERNAL API (for now) */ - private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse + private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse - override def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = - handleFrames(Websocket.handleMessages(handlerFlow)) + override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse = + handleFrames(Websocket.handleMessages(handlerFlow), subprotocol) } diff --git a/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala index c63b4ea2a4..7325d46835 100644 --- a/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala +++ b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala @@ -4,6 +4,7 @@ package akka.http.model.ws +import scala.collection.immutable import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow @@ -13,16 +14,22 @@ import akka.http.model.{ HttpHeader, HttpResponse } * A custom header that will be added to an Websocket upgrade HttpRequest that * enables a request handler to upgrade this connection to a Websocket connection and * registers a Websocket handler. - * - * FIXME: needs to be able to choose subprotocols as possibly agreed on in the websocket handshake */ trait UpgradeToWebsocket extends HttpHeader { + /** + * A sequence of protocols the client accepts. + * + * See http://tools.ietf.org/html/rfc6455#section-1.9 + */ + def requestedProtocols: immutable.Seq[String] + /** * The high-level interface to create a Websocket server based on "messages". * * Returns a response to return in a request handler that will signal the * low-level HTTP implementation to upgrade the connection to Websocket and - * use the supplied handler to handle incoming Websocket messages. + * use the supplied handler to handle incoming Websocket messages. Optionally, + * a subprotocol out of the ones requested by the client can be chosen. */ - def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse } diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala index 2de7c9ca95..e05c6ab17e 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala @@ -14,6 +14,8 @@ import akka.http.util.Rendering import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow +import scala.collection.immutable.Seq + class WebsocketDirectivesSpec extends RoutingSpec { "the handleWebsocketMessages directive" should { "handle websocket requests" in { @@ -39,7 +41,9 @@ class WebsocketDirectivesSpec extends RoutingSpec { } def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { - def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = + def requestedProtocols: Seq[String] = Nil + + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = HttpResponse(StatusCodes.SwitchingProtocols) } }