diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index 014fca5a24..e0909b08f8 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -6,14 +6,10 @@ package akka.http.impl.engine.ws import java.util.Random -import akka.http.impl.engine.parsing.ParserOutput.MessageStartError - import scala.collection.immutable import scala.collection.immutable.Seq import scala.reflect.ClassTag -import akka.parboiled2.util.Base64 - import akka.stream.scaladsl.Flow import akka.http.impl.util._ @@ -91,12 +87,10 @@ private[http] object Handshake { // FIXME See #18709 // val extensions = find[`Sec-WebSocket-Extensions`] - def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 - if (upgrade.exists(_.hasWebsocket) && connection.exists(_.hasUpgrade) && version.exists(_.hasVersion(CurrentWebsocketVersion)) && - key.exists(k ⇒ isValidKey(k.key))) { + key.exists(k ⇒ k.isValid)) { val header = new UpgradeToWebsocketLowLevel { def requestedProtocols: Seq[String] = clientSupportedSubprotocols @@ -156,7 +150,7 @@ private[http] object Handshake { def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = { val keyBytes = new Array[Byte](16) random.nextBytes(keyBytes) - val key = `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) + val key = `Sec-WebSocket-Key`(keyBytes) val protocol = if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil else Nil diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala index 41092abab4..46fbb83522 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala @@ -8,15 +8,18 @@ import java.lang.Iterable import java.net.InetSocketAddress import java.security.MessageDigest import java.util + +import scala.reflect.ClassTag +import scala.util.Try import scala.annotation.tailrec import scala.collection.immutable + import akka.parboiled2.util.Base64 + import akka.http.impl.util._ import akka.http.javadsl.{ model ⇒ jm } import akka.http.scaladsl.model._ -import scala.reflect.ClassTag - sealed abstract class ModeledCompanion[T: ClassTag] extends Renderable { val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $ val lowercaseName = name.toRootLowerCase @@ -627,7 +630,12 @@ private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable. /** * INTERNAL API */ -private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] +private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] { + def apply(keyBytes: Array[Byte]): `Sec-WebSocket-Key` = { + require(keyBytes.length == 16, s"Sec-WebSocket-Key keyBytes must have length 16 but had ${keyBytes.length}") + `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) + } +} /** * INTERNAL API */ @@ -635,6 +643,12 @@ private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledH protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key protected def companion = `Sec-WebSocket-Key` + + /** + * Checks if the key value is valid according to the Websocket specification, i.e. + * if the String is a Base64 representation of 16 bytes. + */ + def isValid: Boolean = Try(Base64.rfc2045().decode(key)).toOption.exists(_.length == 16) } // http://tools.ietf.org/html/rfc6455#section-4.3