=htc move Sec-WebSocket-Key creation/validation to header model

This commit is contained in:
Johannes Rudolph 2015-10-19 09:17:06 +02:00
parent f6732f3369
commit ddc8cd804b
2 changed files with 19 additions and 11 deletions

View file

@ -6,14 +6,10 @@ package akka.http.impl.engine.ws
import java.util.Random import java.util.Random
import akka.http.impl.engine.parsing.ParserOutput.MessageStartError
import scala.collection.immutable import scala.collection.immutable
import scala.collection.immutable.Seq import scala.collection.immutable.Seq
import scala.reflect.ClassTag import scala.reflect.ClassTag
import akka.parboiled2.util.Base64
import akka.stream.scaladsl.Flow import akka.stream.scaladsl.Flow
import akka.http.impl.util._ import akka.http.impl.util._
@ -91,12 +87,10 @@ private[http] object Handshake {
// FIXME See #18709 // FIXME See #18709
// val extensions = find[`Sec-WebSocket-Extensions`] // val extensions = find[`Sec-WebSocket-Extensions`]
def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16
if (upgrade.exists(_.hasWebsocket) && if (upgrade.exists(_.hasWebsocket) &&
connection.exists(_.hasUpgrade) && connection.exists(_.hasUpgrade) &&
version.exists(_.hasVersion(CurrentWebsocketVersion)) && version.exists(_.hasVersion(CurrentWebsocketVersion)) &&
key.exists(k isValidKey(k.key))) { key.exists(k k.isValid)) {
val header = new UpgradeToWebsocketLowLevel { val header = new UpgradeToWebsocketLowLevel {
def requestedProtocols: Seq[String] = clientSupportedSubprotocols def requestedProtocols: Seq[String] = clientSupportedSubprotocols
@ -156,7 +150,7 @@ private[http] object Handshake {
def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = { def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = {
val keyBytes = new Array[Byte](16) val keyBytes = new Array[Byte](16)
random.nextBytes(keyBytes) random.nextBytes(keyBytes)
val key = `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) val key = `Sec-WebSocket-Key`(keyBytes)
val protocol = val protocol =
if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil
else Nil else Nil

View file

@ -8,15 +8,18 @@ import java.lang.Iterable
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.security.MessageDigest import java.security.MessageDigest
import java.util import java.util
import scala.reflect.ClassTag
import scala.util.Try
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.collection.immutable import scala.collection.immutable
import akka.parboiled2.util.Base64 import akka.parboiled2.util.Base64
import akka.http.impl.util._ import akka.http.impl.util._
import akka.http.javadsl.{ model jm } import akka.http.javadsl.{ model jm }
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import scala.reflect.ClassTag
sealed abstract class ModeledCompanion[T: ClassTag] extends Renderable { sealed abstract class ModeledCompanion[T: ClassTag] extends Renderable {
val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $ val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $
val lowercaseName = name.toRootLowerCase val lowercaseName = name.toRootLowerCase
@ -627,7 +630,12 @@ private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable.
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] {
def apply(keyBytes: Array[Byte]): `Sec-WebSocket-Key` = {
require(keyBytes.length == 16, s"Sec-WebSocket-Key keyBytes must have length 16 but had ${keyBytes.length}")
`Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false))
}
}
/** /**
* INTERNAL API * INTERNAL API
*/ */
@ -635,6 +643,12 @@ private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledH
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key
protected def companion = `Sec-WebSocket-Key` protected def companion = `Sec-WebSocket-Key`
/**
* Checks if the key value is valid according to the Websocket specification, i.e.
* if the String is a Base64 representation of 16 bytes.
*/
def isValid: Boolean = Try(Base64.rfc2045().decode(key)).toOption.exists(_.length == 16)
} }
// http://tools.ietf.org/html/rfc6455#section-4.3 // http://tools.ietf.org/html/rfc6455#section-4.3