=htp #17067 port XXE fixes and tests from spray/spray#1024

This commit is contained in:
Johannes Rudolph 2015-06-05 17:05:27 +02:00
parent 71df8f810a
commit 0018bc2cda
4 changed files with 126 additions and 18 deletions

View file

@ -5,6 +5,7 @@
package akka.http.scaladsl.marshallers.xml package akka.http.scaladsl.marshallers.xml
import java.io.{ ByteArrayInputStream, InputStreamReader } import java.io.{ ByteArrayInputStream, InputStreamReader }
import javax.xml.parsers.{ SAXParserFactory, SAXParser }
import scala.collection.immutable import scala.collection.immutable
import scala.xml.{ XML, NodeSeq } import scala.xml.{ XML, NodeSeq }
import akka.stream.FlowMaterializer import akka.stream.FlowMaterializer
@ -26,15 +27,38 @@ trait ScalaXmlSupport {
def nodeSeqUnmarshaller(ranges: ContentTypeRange*)(implicit fm: FlowMaterializer): FromEntityUnmarshaller[NodeSeq] = def nodeSeqUnmarshaller(ranges: ContentTypeRange*)(implicit fm: FlowMaterializer): FromEntityUnmarshaller[NodeSeq] =
Unmarshaller.byteArrayUnmarshaller.forContentTypes(ranges: _*).mapWithCharset { (bytes, charset) Unmarshaller.byteArrayUnmarshaller.forContentTypes(ranges: _*).mapWithCharset { (bytes, charset)
if (bytes.length > 0) { if (bytes.length > 0) {
val parser = XML.parser
try parser.setProperty("http://apache.org/xml/properties/locale", java.util.Locale.ROOT)
catch { case e: org.xml.sax.SAXNotRecognizedException /* property is not needed */ }
val reader = new InputStreamReader(new ByteArrayInputStream(bytes), charset.nioCharset) val reader = new InputStreamReader(new ByteArrayInputStream(bytes), charset.nioCharset)
XML.withSAXParser(parser).load(reader): NodeSeq // blocking call! Ideally we'd have a `loadToFuture` XML.withSAXParser(createSAXParser()).load(reader): NodeSeq // blocking call! Ideally we'd have a `loadToFuture`
} else NodeSeq.Empty } else NodeSeq.Empty
} }
/**
* Provides a SAXParser for the NodeSeqUnmarshaller to use. Override to provide a custom SAXParser implementation.
* Will be called once for every request to be unmarshalled (entities with an empty body are not parsed, so no
* parser is created for them). The default implementation calls [[ScalaXmlSupport.createSaferSAXParser]].
* @return the SAXParser that is used to parse a single entity body
*/
protected def createSAXParser(): SAXParser = ScalaXmlSupport.createSaferSAXParser()
} }
object ScalaXmlSupport extends ScalaXmlSupport { object ScalaXmlSupport extends ScalaXmlSupport {
val nodeSeqContentTypes: immutable.Seq[ContentType] = List(`text/xml`, `application/xml`, `text/html`, `application/xhtml+xml`) val nodeSeqContentTypes: immutable.Seq[ContentType] = List(`text/xml`, `application/xml`, `text/html`, `application/xhtml+xml`)
val nodeSeqContentTypeRanges: immutable.Seq[ContentTypeRange] = nodeSeqContentTypes.map(ContentTypeRange(_)) val nodeSeqContentTypeRanges: immutable.Seq[ContentTypeRange] = nodeSeqContentTypes.map(ContentTypeRange(_))
/**
 * Creates a safer SAXParser, i.e. one that is configured to resist XML External Entity (XXE) style input.
 *
 * The underlying factory is set up to
 *  - not load external general entities,
 *  - not load external parameter entities,
 *  - reject documents that contain a DOCTYPE declaration, and
 *  - run with JAXP secure processing enabled.
 *
 * The feature names are written out as plain strings instead of being assembled from the JDK-internal
 * `com.sun.org.apache.xerces.internal.impl.Constants` class, which is not a supported API and is
 * encapsulated on modularized JDKs.
 *
 * Exceptions raised by `SAXParserFactory.setFeature` or `newSAXParser` are not caught here, so a factory
 * that cannot provide one of these protections fails loudly instead of yielding an unprotected parser.
 *
 * @return a newly created parser; every call builds a fresh factory and parser
 */
def createSaferSAXParser(): SAXParser = {
  import javax.xml.XMLConstants

  val factory = SAXParserFactory.newInstance()
  factory.setFeature("http://xml.org/sax/features/external-general-entities", false)
  factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false)
  factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true)
  factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true)

  val parser = factory.newSAXParser()
  // Best effort only: parsers that do not recognise this Xerces property are used as they are.
  try {
    parser.setProperty("http://apache.org/xml/properties/locale", java.util.Locale.ROOT)
  } catch {
    case _: org.xml.sax.SAXNotRecognizedException => // property is not needed
  }
  parser
}
} }

View file

@ -0,0 +1,16 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.scaladsl
import java.io.{ FileOutputStream, File }
/** Small helpers shared between test suites. */
object TestUtils {
  /**
   * Writes `text` to `file` encoded as UTF-8.
   *
   * The file is opened with a plain (non-appending) `FileOutputStream`, so existing content is
   * replaced. The stream is closed even if the write fails.
   *
   * @param text the characters to write
   * @param file the file to (over)write
   */
  def writeAllText(text: String, file: File): Unit = {
    val fos = new FileOutputStream(file)
    try {
      // Use the charset constant rather than looking the charset up by name.
      fos.write(text.getBytes(java.nio.charset.StandardCharsets.UTF_8))
    } finally fos.close()
  }
}

View file

@ -4,28 +4,102 @@
package akka.http.scaladsl.marshallers.xml package akka.http.scaladsl.marshallers.xml
import java.io.File
import akka.http.scaladsl.TestUtils
import scala.concurrent.duration._
import org.xml.sax.SAXParseException
import scala.concurrent.{ Future, Await }
import scala.xml.NodeSeq import scala.xml.NodeSeq
import org.scalatest.{ Matchers, WordSpec } import org.scalatest.{ Inside, FreeSpec, Matchers }
import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.testkit.ScalatestRouteTest
import akka.http.scaladsl.unmarshalling.{ Unmarshaller, Unmarshal } import akka.http.scaladsl.unmarshalling.{ Unmarshaller, Unmarshal }
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import HttpCharsets._ import HttpCharsets._
import MediaTypes._ import MediaTypes._
class ScalaXmlSupportSpec extends WordSpec with Matchers with ScalatestRouteTest { class ScalaXmlSupportSpec extends FreeSpec with Matchers with ScalatestRouteTest with Inside {
import ScalaXmlSupport._ import ScalaXmlSupport._
"ScalaXmlSupport" should { "NodeSeqMarshaller should" - {
"NodeSeqMarshaller should marshal xml snippets to `text/xml` content in UTF-8" in { "marshal xml snippets to `text/xml` content in UTF-8" in {
marshal(<employee><nr>Hallo</nr></employee>) shouldEqual marshal(<employee><nr>Hallo</nr></employee>) shouldEqual
HttpEntity(ContentType(`text/xml`, `UTF-8`), "<employee><nr>Ha“llo</nr></employee>") HttpEntity(ContentType(`text/xml`, `UTF-8`), "<employee><nr>Ha“llo</nr></employee>")
} }
"nodeSeqUnmarshaller should unmarshal `text/xml` content in UTF-8 to NodeSeqs" in { "unmarshal `text/xml` content in UTF-8 to NodeSeqs" in {
Unmarshal(HttpEntity(`text/xml`, "<int>Hällö</int>")).to[NodeSeq].map(_.text) should evaluateTo("Hällö") Unmarshal(HttpEntity(`text/xml`, "<int>Hällö</int>")).to[NodeSeq].map(_.text) should evaluateTo("Hällö")
} }
"nodeSeqUnmarshaller should reject `application/octet-stream`" in { "reject `application/octet-stream`" in {
Unmarshal(HttpEntity(`application/octet-stream`, "<int>Hällö</int>")).to[NodeSeq].map(_.text) should Unmarshal(HttpEntity(`application/octet-stream`, "<int>Hällö</int>")).to[NodeSeq].map(_.text) should
haveFailedWith(Unmarshaller.UnsupportedContentTypeException(nodeSeqContentTypeRanges: _*)) haveFailedWith(Unmarshaller.UnsupportedContentTypeException(nodeSeqContentTypeRanges: _*))
} }
"don't be vulnerable to XXE attacks" - {
"parse XML bodies without loading in a related schema" in {
withTempFile("I shouldn't be there!") { f
val xml = s"""<?xml version="1.0" encoding="ISO-8859-1"?>
| <!DOCTYPE foo [
| <!ELEMENT foo ANY >
| <!ENTITY xxe SYSTEM "${f.toURI}">]><foo>hello&xxe;</foo>""".stripMargin
shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq])
}
}
"parse XML bodies without loading in a related schema from a parameter" in {
withTempFile("I shouldnt be there!") { generalEntityFile
withTempFile {
s"""<!ENTITY % xge SYSTEM "${generalEntityFile.toURI}">
|<!ENTITY % pe "<!ENTITY xxe '%xge;'>">""".stripMargin
} { parameterEntityFile
val xml = s"""<?xml version="1.0" encoding="ISO-8859-1"?>
| <!DOCTYPE foo [
| <!ENTITY % xpe SYSTEM "${parameterEntityFile.toURI}">
| %xpe;
| %pe;
| ]><foo>hello&xxe;</foo>""".stripMargin
shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq])
}
}
}
"gracefully fail when there are too many nested entities" in {
val nested = for (x 1 to 30) yield "<!ENTITY laugh" + x + " \"&laugh" + (x - 1) + ";&laugh" + (x - 1) + ";\">"
val xml =
s"""<?xml version="1.0"?>
| <!DOCTYPE billion [
| <!ELEMENT billion (#PCDATA)>
| <!ENTITY laugh0 "ha">
| ${nested.mkString("\n")}
| ]>
| <billion>&laugh30;</billion>""".stripMargin
shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq])
}
"gracefully fail when an entity expands to be very large" in {
val as = "a" * 50000
val entities = "&a;" * 50000
val xml = s"""<?xml version="1.0"?>
| <!DOCTYPE kaboom [
| <!ENTITY a "$as">
| ]>
| <kaboom>$entities</kaboom>""".stripMargin
shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq])
}
}
}
/**
 * Asserts that `result` completes as a failure whose exception is a `SAXParseException`.
 *
 * Blocks the calling thread for up to one second while waiting for the failure; a future that
 * succeeds, times out, or fails with a different exception makes the assertion fail.
 */
def shouldHaveFailedWithSAXParseException(result: Future[NodeSeq]) =
  inside(Await.result(result.failed, 1.second)) {
    case _: SAXParseException =>
  }
def withTempFile[T](content: String)(f: File T): T = {
val file = File.createTempFile("xxe", ".txt")
try {
TestUtils.writeAllText(content, file)
f(file)
} finally {
file.delete()
}
} }
} }

View file

@ -5,7 +5,7 @@
package akka.http.scaladsl.server package akka.http.scaladsl.server
package directives package directives
import java.io.{ File, FileOutputStream } import java.io.File
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.concurrent.{ ExecutionContext, Future } import scala.concurrent.{ ExecutionContext, Future }
import scala.util.Properties import scala.util.Properties
@ -15,6 +15,7 @@ import akka.http.scaladsl.model.MediaTypes._
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.model.headers._
import akka.http.impl.util._ import akka.http.impl.util._
import akka.http.scaladsl.TestUtils.writeAllText
class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Inside { class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Inside {
@ -356,13 +357,6 @@ class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Ins
def prep(s: String) = s.stripMarginWithNewline("\n") def prep(s: String) = s.stripMarginWithNewline("\n")
def writeAllText(text: String, file: File): Unit = {
val fos = new FileOutputStream(file)
try {
fos.write(text.getBytes("UTF-8"))
} finally fos.close()
}
def evaluateTo[T](t: T, atMost: Duration = 100.millis)(implicit ec: ExecutionContext): Matcher[Future[T]] = def evaluateTo[T](t: T, atMost: Duration = 100.millis)(implicit ec: ExecutionContext): Matcher[Future[T]] =
be(t).compose[Future[T]] { fut be(t).compose[Future[T]] { fut
fut.awaitResult(atMost) fut.awaitResult(atMost)