diff --git a/akka-http-marshallers-scala/akka-http-xml/src/main/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupport.scala b/akka-http-marshallers-scala/akka-http-xml/src/main/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupport.scala index 096f08401d..dd9bf20235 100644 --- a/akka-http-marshallers-scala/akka-http-xml/src/main/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupport.scala +++ b/akka-http-marshallers-scala/akka-http-xml/src/main/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupport.scala @@ -5,6 +5,7 @@ package akka.http.scaladsl.marshallers.xml import java.io.{ ByteArrayInputStream, InputStreamReader } +import javax.xml.parsers.{ SAXParserFactory, SAXParser } import scala.collection.immutable import scala.xml.{ XML, NodeSeq } import akka.stream.FlowMaterializer @@ -26,15 +27,38 @@ trait ScalaXmlSupport { def nodeSeqUnmarshaller(ranges: ContentTypeRange*)(implicit fm: FlowMaterializer): FromEntityUnmarshaller[NodeSeq] = Unmarshaller.byteArrayUnmarshaller.forContentTypes(ranges: _*).mapWithCharset { (bytes, charset) ⇒ if (bytes.length > 0) { - val parser = XML.parser - try parser.setProperty("http://apache.org/xml/properties/locale", java.util.Locale.ROOT) - catch { case e: org.xml.sax.SAXNotRecognizedException ⇒ /* property is not needed */ } val reader = new InputStreamReader(new ByteArrayInputStream(bytes), charset.nioCharset) - XML.withSAXParser(parser).load(reader): NodeSeq // blocking call! Ideally we'd have a `loadToFuture` + XML.withSAXParser(createSAXParser()).load(reader): NodeSeq // blocking call! Ideally we'd have a `loadToFuture` } else NodeSeq.Empty } + + /** + * Provides a SAXParser for the NodeSeqUnmarshaller to use. Override to provide a custom SAXParser implementation. + * Will be called once for every request to be unmarshalled. The default implementation calls [[ScalaXmlSupport.createSaferSAXParser]]. 
+ * @return the SAXParser used to parse the entity being unmarshalled + */ + protected def createSAXParser(): SAXParser = ScalaXmlSupport.createSaferSAXParser() } object ScalaXmlSupport extends ScalaXmlSupport { val nodeSeqContentTypes: immutable.Seq[ContentType] = List(`text/xml`, `application/xml`, `text/html`, `application/xhtml+xml`) val nodeSeqContentTypeRanges: immutable.Seq[ContentTypeRange] = nodeSeqContentTypes.map(ContentTypeRange(_)) + + /** Creates a SAXParser with external general/parameter entities disabled, DOCTYPE declarations disallowed and secure processing enabled. */ + def createSaferSAXParser(): SAXParser = { + val factory = SAXParserFactory.newInstance() + // Feature URIs are written out in full: com.sun.org.apache.xerces.internal.impl.Constants is a JDK-internal class, not public API + import javax.xml.XMLConstants + + factory.setFeature("http://xml.org/sax/features/external-general-entities", false) + factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false) + factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true) + factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true) + val parser = factory.newSAXParser() + try { + parser.setProperty("http://apache.org/xml/properties/locale", java.util.Locale.ROOT) + } catch { + case _: org.xml.sax.SAXNotRecognizedException ⇒ // property is not needed + } + parser + } } \ No newline at end of file diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/TestUtils.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/TestUtils.scala new file mode 100644 index 0000000000..7eb2da3e9f --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/TestUtils.scala @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. 
+ */ + + package akka.http.scaladsl + +import java.io.{ FileOutputStream, File } + +object TestUtils { + /** Writes `text` to `file` as UTF-8 bytes; the stream is closed in `finally`, so it is released even if the write throws. */ def writeAllText(text: String, file: File): Unit = { + val fos = new FileOutputStream(file) + try { + fos.write(text.getBytes("UTF-8")) + } finally fos.close() + } +} diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala index 8592390e76..803706d16e 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/marshallers/xml/ScalaXmlSupportSpec.scala @@ -4,28 +4,102 @@ package akka.http.scaladsl.marshallers.xml +import java.io.File + +import akka.http.scaladsl.TestUtils +import scala.concurrent.duration._ +import org.xml.sax.SAXParseException + +import scala.concurrent.{ Future, Await } import scala.xml.NodeSeq -import org.scalatest.{ Matchers, WordSpec } +import org.scalatest.{ Inside, FreeSpec, Matchers } import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.unmarshalling.{ Unmarshaller, Unmarshal } import akka.http.scaladsl.model._ import HttpCharsets._ import MediaTypes._ -class ScalaXmlSupportSpec extends WordSpec with Matchers with ScalatestRouteTest { +class ScalaXmlSupportSpec extends FreeSpec with Matchers with ScalatestRouteTest with Inside { import ScalaXmlSupport._ - "ScalaXmlSupport" should { - "NodeSeqMarshaller should marshal xml snippets to `text/xml` content in UTF-8" in { + "NodeSeqMarshaller should" - { + "marshal xml snippets to `text/xml` content in UTF-8" in { marshal(Ha“llo) shouldEqual HttpEntity(ContentType(`text/xml`, `UTF-8`), "Ha“llo") } - "nodeSeqUnmarshaller should unmarshal `text/xml` content in UTF-8 to NodeSeqs" in { + "unmarshal `text/xml` content in UTF-8 to NodeSeqs" in { Unmarshal(HttpEntity(`text/xml`, "Hällö")).to[NodeSeq].map(_.text) should 
evaluateTo("Hällö") } - "nodeSeqUnmarshaller should reject `application/octet-stream`" in { + "reject `application/octet-stream`" in { Unmarshal(HttpEntity(`application/octet-stream`, "Hällö")).to[NodeSeq].map(_.text) should haveFailedWith(Unmarshaller.UnsupportedContentTypeException(nodeSeqContentTypeRanges: _*)) } + + "don't be vulnerable to XXE attacks" - { + "parse XML bodies without loading in a related schema" in { + withTempFile("I shouldn't be there!") { f ⇒ + val xml = s""" + | + | ]>hello&xxe;""".stripMargin + + shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq]) + } + } + "parse XML bodies without loading in a related schema from a parameter" in { + withTempFile("I shouldnt be there!") { generalEntityFile ⇒ + withTempFile { + s""" + |">""".stripMargin + } { parameterEntityFile ⇒ + val xml = s""" + | + | %xpe; + | %pe; + | ]>hello&xxe;""".stripMargin + shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq]) + } + } + } + "gracefully fail when there are too many nested entities" in { + val nested = for (x ← 1 to 30) yield "" + val xml = + s""" + | + | + | ${nested.mkString("\n")} + | ]> + | &laugh30;""".stripMargin + + shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq]) + } + "gracefully fail when an entity expands to be very large" in { + val as = "a" * 50000 + val entities = "&a;" * 50000 + val xml = s""" + | + | ]> + | $entities""".stripMargin + shouldHaveFailedWithSAXParseException(Unmarshal(HttpEntity(`text/xml`, xml)).to[NodeSeq]) + } + } + } + + /** Waits up to one second for `result` to fail and checks that the failure is a SAXParseException. */ def shouldHaveFailedWithSAXParseException(result: Future[NodeSeq]) = + inside(Await.result(result.failed, 1.second)) { + case _: SAXParseException ⇒ + } + + /** Creates a temp file, writes `content` into it, passes it to `f`, and deletes the file in `finally` whether or not `f` throws. */ def withTempFile[T](content: String)(f: File ⇒ T): T = { + val file = File.createTempFile("xxe", ".txt") + try { + TestUtils.writeAllText(content, file) + f(file) + } finally { + file.delete() + } } } diff --git 
a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala index fb745e77df..7fb637d93e 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FileAndResourceDirectivesSpec.scala @@ -5,7 +5,7 @@ package akka.http.scaladsl.server package directives -import java.io.{ File, FileOutputStream } +import java.io.File import scala.concurrent.duration._ import scala.concurrent.{ ExecutionContext, Future } import scala.util.Properties @@ -15,6 +15,7 @@ import akka.http.scaladsl.model.MediaTypes._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.impl.util._ +import akka.http.scaladsl.TestUtils.writeAllText class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Inside { @@ -356,13 +357,6 @@ class FileAndResourceDirectivesSpec extends RoutingSpec with Inspectors with Ins def prep(s: String) = s.stripMarginWithNewline("\n") - def writeAllText(text: String, file: File): Unit = { - val fos = new FileOutputStream(file) - try { - fos.write(text.getBytes("UTF-8")) - } finally fos.close() - } - def evaluateTo[T](t: T, atMost: Duration = 100.millis)(implicit ec: ExecutionContext): Matcher[Future[T]] = be(t).compose[Future[T]] { fut ⇒ fut.awaitResult(atMost)