/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.xml
import{BufferedReader, CharConversionException, FileNotFoundException, InputStream, InputStreamReader, IOException, StringReader}
import java.nio.charset.{Charset, MalformedInputException}
import java.text.NumberFormat
import java.util.Locale
import{XMLEventReader, XMLStreamException}
import javax.xml.validation.Schema
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.control.NonFatal
import scala.xml.SAXException
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.{SparkIllegalArgumentException, SparkUpgradeException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
 * Converts XML records into Catalyst [[InternalRow]]s laid out according to `schema`.
 * `options` is exposed as a public member and is read throughout the class.
 */
class StaxXmlParser(
schema: StructType,
val options: XmlOptions) extends Logging {
// Formatter for timestamp values, created lazily on first use.
// NOTE(review): only `legacyFormat` and `isParsing` are visible in this call; confirm
// whether pattern/zone/locale arguments are expected here as well.
private lazy val timestampFormatter = TimestampFormatter(
legacyFormat = FAST_DATE_FORMAT,
isParsing = true)
// Timestamp-without-time-zone variant; `castTo` uses it for TimestampNTZType values.
private lazy val timestampNTZFormatter = TimestampFormatter(
legacyFormat = FAST_DATE_FORMAT,
isParsing = true,
forTimestampNTZ = true)
// Formatter for date values, created lazily on first use.
private lazy val dateFormatter = DateFormatter(
legacyFormat = FAST_DATE_FORMAT,
isParsing = true)
// Decimal parser built from `options.locale`; `castTo` uses it for DecimalType values.
private val decimalParser = ExprUtils.getDecimalParser(options.locale)
// Case-sensitivity setting read from SQLConf when the parser is constructed;
// consulted by `getFieldIndex`.
private val caseSensitive = SQLConf.get.caseSensitiveAnalysis
/**
 * Parses a single XML string and turns it into either one resulting row or no row (if
 * the record is malformed).
 */
val parse: String => Option[InternalRow] =
  // Deliberately a `val`: the parsing function is built once and reused for every record.
  if (schema.nonEmpty) {
    // Resolve the optional row-validation XSD up front rather than per record.
    val validationSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
    (record: String) => doParseColumn(record, options.parseMode, validationSchema)
  } else {
    // No fields are requested, so every input maps to the same empty row.
    (_: String) => Some(InternalRow.empty)
  }
// Returns the position of `fieldName` within `schema`, if present, choosing the lookup
// strategy from the `caseSensitive` setting.
// NOTE(review): both branch bodies are empty in this view; presumably one performs an
// exact-name lookup and the other a case-insensitive one -- confirm against the full file.
private def getFieldIndex(schema: StructType, fieldName: String): Option[Int] = {
if (caseSensitive) {
} else {
// Turns an XML input stream into an iterator of rows: an XmlTokenizer splits the stream
// into record strings and each record string is handed to `doParseColumn`.
// NOTE(review): the remaining FailureSafeParser arguments and the body of the
// `convertStream` callback are not visible here, so how `safeParser` and the `schema`
// parameter are used cannot be confirmed from this view.
def parseStream(
inputStream: InputStream,
schema: StructType): Iterator[InternalRow] = {
// Optional XSD for per-record validation; None when no path is configured (null).
val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
val safeParser = new FailureSafeParser[String](
input => doParseColumn(input, options.parseMode, xsdSchema),
val xmlTokenizer = new XmlTokenizer(inputStream, options)
convertStream(xmlTokenizer) { tokens =>
// Parses one XML string against a caller-supplied schema and returns the resulting row,
// or null when `doParseColumn` yields no row.
def parseColumn(xml: String, schema: StructType): InternalRow = {
// The user-specified schema from from_xml, etc. will typically not include a
// "corrupted record" column. In PERMISSIVE mode, which puts bad records in
// such a column, this would cause an error. In this mode, if such a column
// is not manually specified, then fall back to DROPMALFORMED, which will return
// null column values where parsing fails.
// NOTE(review): the field predicate and both branch values below are missing in this
// view; per the comment above they should select DROPMALFORMED or the configured mode.
val parseMode =
if (options.parseMode == PermissiveMode &&
!schema.fields.exists( == options.columnNameOfCorruptRecord)) {
} else {
// Optional XSD for validation; None when no path is configured (null).
val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
// `orNull` converts "no row" (None) into null for the caller.
doParseColumn(xml, parseMode, xsdSchema).orNull
// Parses a single XML record string into at most one row.
// When an XSD is supplied the raw record is validated first. Failures are rethrown as
// BadRecordException carrying the raw record, except SparkUpgradeException, which is
// propagated unchanged.
// NOTE(review): `parseMode` is not referenced in the lines visible here, and several
// statements (closing of the validation block, returning `result`, the end of the
// CharConversionException handler) are missing from this view.
def doParseColumn(xml: String,
parseMode: ParseMode,
xsdSchema: Option[Schema]): Option[InternalRow] = {
// Lazy: the UTF8String copy of the record is only needed on the error paths below.
lazy val xmlRecord = UTF8String.fromString(xml)
try {
// Validate the raw record against the XSD, when one was provided.
xsdSchema.foreach { schema =>
schema.newValidator().validate(new StreamSource(new StringReader(xml)))
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
// NOTE(review): `schema` here is meant to be the row schema passed to `convertObject`,
// not the XSD bound by the `foreach` above -- the closing brace of that block is not
// visible in this view.
val result = Some(convertObject(parser, schema, rootAttributes))
} catch {
// Upgrade-related failures must reach the user as-is.
case e: SparkUpgradeException => throw e
case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException
| _: SAXException) =>
// XML parser currently doesn't support partial results for corrupted records.
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => xmlRecord, cause = () => e)
// No charset was configured: wrap the failure with a hint to specify the encoding.
case e: CharConversionException if options.charset.isEmpty =>
throw BadRecordException(() => xmlRecord, cause = () => {
val msg =
"""XML parser cannot handle a character in its input.
|Specifying encoding as an input option explicitly might help to resolve the issue.
|""".stripMargin + e.getMessage
val wrappedCharException = new CharConversionException(msg)
// A single partially converted row is forwarded together with the original cause.
case PartialResultException(row, cause) =>
throw BadRecordException(
record = () => xmlRecord,
partialResults = () => Array(row),
() => cause)
// Same as above, for an array of partially converted rows.
case PartialResultArrayException(rows, cause) =>
throw BadRecordException(
record = () => xmlRecord,
partialResults = () => rows,
() => cause)
/**
 * Parse the current token (and related children) according to a desired schema.
 */
// Converts the content of the element named `startElementName` into a value of
// `dataType`, dispatching on the next pending parser event (`parser.peek`) and the type.
// `attributes` are only forwarded to map conversion and to recursive calls.
// NOTE(review): several expressions (the string-typed branches, some result values) are
// cut off in this view; comments below describe only what is visible.
private[xml] def convertField(
parser: XMLEventReader,
dataType: DataType,
startElementName: String,
attributes: Array[Attribute] = Array.empty): Any = {
// Handles the case where the element's content starts with a nested element.
def convertComplicatedType(
dt: DataType,
startElementName: String,
attributes: Array[Attribute]): Any = dt match {
case st: StructType => convertObject(parser, st)
case MapType(StringType, vt, _) => convertMap(parser, vt, attributes)
// For arrays, convert a single element; the caller merges elements.
case ArrayType(st, _) => convertField(parser, st, startElementName)
// NOTE(review): the call that the following argument list belongs to is not visible.
case _: StringType =>
parser, startElementName, options),
(parser.peek, dataType) match {
// Nested element: delegate to the complex-type conversion above.
case (_: StartElement, dt: DataType) =>
convertComplicatedType(dt, startElementName, attributes)
case (_: EndElement, _: StringType) =>
StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
// Empty. It's null if "" is the null value
// NOTE(review): both branch values are missing in this view.
if (options.nullValue == "") {
} else {
// Empty element of a non-string type: consume the end element.
case (_: EndElement, _: DataType) =>
StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
case (c: Characters, ArrayType(st, _)) =>
// For `ArrayType`, it needs to return the type of element. The values are merged later.
val value = convertTo(c.getData, st)
StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
// Text in front of a struct: let the object conversion consume the events.
case (_: Characters, st: StructType) =>
convertObject(parser, st)
// NOTE(review): the call that the following argument list belongs to is not visible.
case (_: Characters, _: StringType) =>
parser, startElementName, options),
case (c: Characters, _: DataType) if c.isWhiteSpace =>
// When `Characters` is found, we need to look further to decide
// if this is really data or space between other elements.
// NOTE(review): no statement consuming the whitespace event is visible before this
// recursive call -- confirm against the full file.
convertField(parser, dataType, startElementName, attributes)
// Plain text for a primitive type: convert it, then consume the end element.
case (c: Characters, dt: DataType) =>
val value = convertTo(c.getData, dt)
StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
// Any other event/type combination is rejected.
case (e: XMLEvent, dt: DataType) =>
throw new SparkIllegalArgumentException(
errorClass = "_LEGACY_ERROR_TEMP_3240",
messageParameters = Map(
"dt" -> dt.toString,
"e" -> e.toString))
/**
 * Parse an object as a map.
 */
// Collects the attributes, child elements and non-whitespace text of the current element
// as key/value pairs whose values are converted to `valueType`.
// NOTE(review): the expression that builds the returned MapData from `kvPairs` is not
// visible in this view.
private def convertMap(
parser: XMLEventReader,
valueType: DataType,
attributes: Array[Attribute]): MapData = {
val kvPairs = ArrayBuffer.empty[(UTF8String, Any)]
// Attributes become entries keyed by the attribute prefix plus the local name.
attributes.foreach { attr =>
kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart)
-> convertTo(attr.getValue, valueType))
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
// Child element: its name is the key, its converted content the value.
case e: StartElement =>
val key = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
kvPairs +=
(UTF8String.fromString(key) -> convertField(parser, valueType, key))
case c: Characters if !c.isWhiteSpace =>
// Create a value tag field for it
kvPairs +=
// TODO: We don't support an array value tags in map yet.
(UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType))
// The end of the element (or of the document) terminates the map.
case _: EndElement | _: EndDocument =>
shouldStop = true
case _ => // do nothing
/**
 * Convert XML attributes to a map with the given schema types.
 */
// Converts each attribute that has a matching field in `schema` to that field's data
// type; attributes without a matching field are ignored.
// NOTE(review): the expression returning the collected map is not visible in this view.
private def convertAttributes(
attributes: Array[Attribute],
schema: StructType): Map[String, Any] = {
val convertedValuesMap = collection.mutable.Map.empty[String, Any]
// Attribute names mapped to their raw string values.
val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
valuesMap.foreach { case (f, v) =>
// Only attributes that resolve to a schema field are converted.
val indexOpt = getFieldIndex(schema, f)
indexOpt.foreach { i =>
convertedValuesMap(f) = convertTo(v, schema(i).dataType)
/**
 * [[convertObject()]] calls this in order to convert the nested object to a row.
 * [[convertObject()]] contains some logic to find out which events are the start
 * and end of a nested row and this function converts the events to a row.
 */
// Builds the row for a nested element by combining its converted attributes with the
// values converted from its content.
// NOTE(review): the expressions at the `Map( _*)` line and the first `valuesMap`
// definition are garbled in this view; the comments describe the visible intent only.
private def convertObjectWithAttributes(
parser: XMLEventReader,
schema: StructType,
startElementName: String,
attributes: Array[Attribute] = Array.empty): InternalRow = {
// TODO: This method might have to be removed. Some logics duplicate `convertObject()`
val row = new Array[Any](schema.length)
// Read attributes first.
val attributesMap = convertAttributes(attributes, schema)
// Then, we read elements here.
val fieldsMap = convertField(parser, schema, startElementName) match {
// The content converted to a full row.
case internalRow: InternalRow =>
Map( _*)
case v if schema.fieldNames.contains(options.valueTag) =>
// If this is the element having no children, then it wraps attributes
// with a row. So, we first need to find the field name that has the real
// value and then push the value.
val valuesMap =, null)).toMap
valuesMap + (options.valueTag -> v)
// A bare value with no value-tag field in the schema contributes nothing.
case _ => Map.empty
// Here we merge both to a row.
// Attribute values win on key collisions because they are the right-hand operand of `++`.
val valuesMap = fieldsMap ++ attributesMap
valuesMap.foreach { case (f, v) =>
val indexOpt = getFieldIndex(schema, f)
indexOpt.foreach { row(_) = v }
if (valuesMap.isEmpty) {
// Return an empty row with all nested elements by the schema set to null.
new GenericInternalRow(Array.fill[Any](schema.length)(null))
} else {
new GenericInternalRow(row)
/**
 * Parse an object from the event stream into a new InternalRow representing the schema.
 * Fields in the xml that are not defined in the requested schema will be dropped.
 */
// Reads events until the enclosing element (or the document) ends and fills one slot per
// schema field. The first non-fatal failure inside a child element is remembered and
// reported after the loop together with the row built so far.
// NOTE(review): some expressions (the wildcard predicate, the array-buffer handling after
// `Option(row(...))`, the second PartialResultException argument) are cut off in this view.
private def convertObject(
parser: XMLEventReader,
schema: StructType,
rootAttributes: Array[Attribute] = Array.empty): InternalRow = {
val row = new Array[Any](schema.length)
// If there are attributes, then we process them first.
convertAttributes(rootAttributes, schema).toSeq.foreach {
case (f, v) =>
getFieldIndex(schema, f).foreach { row(_) = v }
val wildcardColName = options.wildcardColName
// NOTE(review): the left-hand side of this predicate is missing in this view.
val hasWildcard = schema.exists( == wildcardColName)
// Holds the first failure seen while converting child elements.
var badRecordException: Option[Throwable] = None
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement => try {
val attributes = e.getAttributes.asScala.toArray
val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
getFieldIndex(schema, field) match {
// The element corresponds to a schema field: convert it by the field's type.
case Some(index) => schema(index).dataType match {
case st: StructType =>
row(index) = convertObjectWithAttributes(parser, st, field, attributes)
// Repeated elements are appended to the values already collected for the field.
case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
val newValue = dt match {
case st: StructType =>
convertObjectWithAttributes(parser, st, field, attributes)
case dt: DataType =>
convertField(parser, dt, field)
row(index) = values :+ newValue
case dt: DataType =>
row(index) = convertField(parser, dt, field, attributes)
// The element has no matching schema field.
case None =>
if (hasWildcard) {
// Special case: there's an 'any' wildcard element that matches anything else
// as a string (or array of strings, to parse multiple ones)
val newValue = convertField(parser, StringType, field)
val anyIndex = schema.fieldIndex(wildcardColName)
schema(wildcardColName).dataType match {
case StringType =>
row(anyIndex) = newValue
case ArrayType(StringType, _) =>
val values = Option(row(anyIndex))
row(anyIndex) = values :+ newValue
} else {
// Unknown element and no wildcard column: skip it.
StaxXmlParserUtils.skipNextEndElement(parser, field, options)
} catch {
case e: SparkUpgradeException => throw e
case NonFatal(e) =>
// TODO: we don't support partial results now
// Keep only the first failure; later ones are dropped by `orElse`.
badRecordException = badRecordException.orElse(Some(e))
// Non-whitespace text directly inside the element goes to the value-tag field.
case c: Characters if !c.isWhiteSpace =>
addOrUpdate(row, schema, options.valueTag, c.getData)
case _: EndElement | _: EndDocument =>
shouldStop = true
case _ => // do nothing
// TODO: find a more efficient way to convert ArrayBuffer to GenericArrayData
// Copy the row, wrapping every ArrayBuffer slot as GenericArrayData.
val newRow = new Array[Any](schema.length)
var i = 0
while (i < schema.length) {
if (row(i).isInstanceOf[ArrayBuffer[_]]) {
newRow(i) = new GenericArrayData(row(i).asInstanceOf[ArrayBuffer[Any]])
} else {
newRow(i) = row(i)
i += 1;
if (badRecordException.isEmpty) {
new GenericInternalRow(newRow)
} else {
// A child failed: surface the row built so far via PartialResultException.
throw PartialResultException(new GenericInternalRow(newRow),
/**
 * Casts the given string datum to the specified type.
 *
 * For string types, this is simply the datum.
 * For other nullable types, returns null if it is null or equal to the value specified
 * in the `nullValue` option.
 *
 * @param datum string value
 * @param castType SparkSQL type
 */
private def castTo(
    datum: String,
    castType: DataType): Any = {
  if (datum == options.nullValue || datum == null) {
    // A missing value or the configured null marker is NULL for every target type.
    null
  } else {
    castType match {
      case _: ByteType => datum.toByte
      case _: ShortType => datum.toShort
      case _: IntegerType => datum.toInt
      case _: LongType => datum.toLong
      // Floating-point values: try the plain parse first and fall back to the JVM
      // default locale's number format. The branch must produce the number itself,
      // not the `Try` wrapper, because callers cast the result to Float/Double.
      case _: FloatType => Try(datum.toFloat)
        .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
      case _: DoubleType => Try(datum.toDouble)
        .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
      case _: BooleanType => parseXmlBoolean(datum)
      case dt: DecimalType =>
        Decimal(decimalParser(datum), dt.precision, dt.scale)
      case _: TimestampType => parseXmlTimestamp(datum, options)
      case _: TimestampNTZType => timestampNTZFormatter.parseWithoutTimeZone(datum, false)
      case _: DateType => parseXmlDate(datum, options)
      case _: StringType => UTF8String.fromString(datum)
      case _ => throw new SparkIllegalArgumentException(
        errorClass = "_LEGACY_ERROR_TEMP_3244",
        // Report the actual type name; this used to be the literal text
        // "castType.typeName".
        messageParameters = Map("castType" -> castType.typeName))
    }
  }
}
/**
 * Interprets an XML boolean literal, ignoring case.
 *
 * @param s "true"/"1" or "false"/"0" in any letter case
 * @return the boolean the literal stands for
 * @throws SparkIllegalArgumentException for any other text
 */
private def parseXmlBoolean(s: String): Boolean = {
  // Locale.ROOT keeps the case folding independent of the JVM's default locale.
  val normalized = s.toLowerCase(Locale.ROOT)
  if (normalized == "true" || normalized == "1") {
    true
  } else if (normalized == "false" || normalized == "0") {
    false
  } else {
    throw new SparkIllegalArgumentException(
      errorClass = "_LEGACY_ERROR_TEMP_3245",
      messageParameters = Map("s" -> s))
  }
}
// Converts a date string to an Int; `castTo` calls this for DateType values.
// NOTE(review): the body is not visible in this view -- presumably the result is a day
// count as used by Catalyst's DateType; confirm against the full file.
private def parseXmlDate(value: String, options: XmlOptions): Int = {
// Converts a timestamp string to a Long; `castTo` calls this for TimestampType values.
// NOTE(review): the body is not visible in this view -- presumably the result is in
// microseconds as used by Catalyst's TimestampType; confirm against the full file.
private def parseXmlTimestamp(value: String, options: XmlOptions): Long = {
// TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`.
/**
 * Converts raw XML text to the Catalyst value for `dataType`.
 *
 * The text is trimmed first when `options.ignoreSurroundingSpaces` is set. A null datum,
 * or one equal to `options.nullValue`, yields null. Long, Int, Float and Double go
 * through the sign-aware helpers; every other supported type is delegated to `castTo`.
 *
 * @param datum raw text, may be null
 * @param dataType target type
 * @return the converted value, or null
 * @throws SparkIllegalArgumentException if `dataType` is not supported
 */
private def convertTo(
    datum: String,
    dataType: DataType): Any = {
  // Both branches must yield the text to work on: trimmed, or exactly as received.
  val value = if (datum != null && options.ignoreSurroundingSpaces) {
    datum.trim()
  } else {
    datum
  }
  if (value == options.nullValue || value == null) {
    null
  } else {
    dataType match {
      // A NullType field carries no type information, so the text is kept as a string.
      case NullType => castTo(value, StringType)
      case LongType => signSafeToLong(value)
      case DoubleType => signSafeToDouble(value)
      case BooleanType => castTo(value, BooleanType)
      case StringType => castTo(value, StringType)
      case DateType => castTo(value, DateType)
      case TimestampType => castTo(value, TimestampType)
      case TimestampNTZType => castTo(value, TimestampNTZType)
      case FloatType => signSafeToFloat(value)
      case ByteType => castTo(value, ByteType)
      case ShortType => castTo(value, ShortType)
      case IntegerType => signSafeToInt(value)
      case dt: DecimalType => castTo(value, dt)
      case _ => throw new SparkIllegalArgumentException(
        errorClass = "_LEGACY_ERROR_TEMP_3246",
        messageParameters = Map("dataType" -> dataType.toString))
    }
  }
}
/**
 * Parses a long that may carry an explicit leading `+` or `-` sign.
 *
 * @param value non-null text of the number
 * @return the parsed value, including `Long.MinValue`
 */
private def signSafeToLong(value: String): Long = {
  if (value.startsWith("+")) {
    castTo(value.substring(1), LongType).asInstanceOf[Long]
  } else if (value.startsWith("-")) {
    val magnitude = value.substring(1)
    if (magnitude.nonEmpty && Character.isDigit(magnitude.charAt(0))) {
      // Parse with the sign attached. The magnitude of Long.MinValue does not fit in a
      // Long, so parsing the digits alone and negating afterwards would fail for it.
      castTo(value, LongType).asInstanceOf[Long]
    } else {
      // Anything else keeps the strip-parse-negate path, so inputs that were accepted
      // before (or rejected by `castTo`) behave as they did.
      -castTo(magnitude, LongType).asInstanceOf[Long]
    }
  } else {
    castTo(value, LongType).asInstanceOf[Long]
  }
}
/**
 * Parses a double that may carry an explicit leading `+` or `-` sign.
 *
 * @param value non-null text of the number
 * @return the parsed value, negated when the text started with `-`
 */
private def signSafeToDouble(value: String): Double = {
  val negative = value.startsWith("-")
  val signed = negative || value.startsWith("+")
  // Drop the sign character, parse the remainder, then re-apply the sign.
  val digits = if (signed) value.substring(1) else value
  val magnitude = castTo(digits, DoubleType).asInstanceOf[Double]
  if (negative) -magnitude else magnitude
}
/**
 * Parses an int that may carry an explicit leading `+` or `-` sign.
 *
 * @param value non-null text of the number
 * @return the parsed value, including `Int.MinValue`
 */
private def signSafeToInt(value: String): Int = {
  if (value.startsWith("+")) {
    castTo(value.substring(1), IntegerType).asInstanceOf[Int]
  } else if (value.startsWith("-")) {
    val magnitude = value.substring(1)
    if (magnitude.nonEmpty && Character.isDigit(magnitude.charAt(0))) {
      // Parse with the sign attached. The magnitude of Int.MinValue does not fit in an
      // Int, so parsing the digits alone and negating afterwards would fail for it.
      castTo(value, IntegerType).asInstanceOf[Int]
    } else {
      // Anything else keeps the strip-parse-negate path, so inputs that were accepted
      // before (or rejected by `castTo`) behave as they did.
      -castTo(magnitude, IntegerType).asInstanceOf[Int]
    }
  } else {
    castTo(value, IntegerType).asInstanceOf[Int]
  }
}
/**
 * Parses a float that may carry an explicit leading `+` or `-` sign.
 *
 * @param value non-null text of the number
 * @return the parsed value, negated when the text started with `-`
 */
private def signSafeToFloat(value: String): Float = {
  // Shared parse step for the three cases below.
  def parse(text: String): Float = castTo(text, FloatType).asInstanceOf[Float]

  value match {
    case plus if plus.startsWith("+") => parse(plus.substring(1))
    case minus if minus.startsWith("-") => -parse(minus.substring(1))
    case unsigned => parse(unsigned)
  }
}
// Writes the converted `data` into the slot of the field called `name`, mutating `row`
// in place, and returns a GenericInternalRow wrapping that same array. For array-typed
// fields the value is added to the existing values at the tail (default) or the head.
// Unknown field names leave `row` untouched.
// NOTE(review): the field is looked up with `schema.getFieldIndex`, i.e. without the
// case-sensitivity handling of this class's `getFieldIndex` -- confirm this is intended.
// NOTE(review): the expression after `Option(row(index))` is cut off in this view.
private def addOrUpdate(
row: Array[Any],
schema: StructType,
name: String,
data: String,
addToTail: Boolean = true): InternalRow = {
schema.getFieldIndex(name) match {
case Some(index) =>
schema(index).dataType match {
// Array field: convert to the element type and add to the collected values.
case ArrayType(elementType, _) =>
val value = convertTo(data, elementType)
val values = Option(row(index))
row(index) = if (addToTail) {
values :+ value
} else {
value +: values
// Scalar field: the converted value replaces whatever was stored before.
case dataType =>
row(index) = convertTo(data, dataType)
case None => // do nothing
new GenericInternalRow(row)
/**
 * XMLRecordReader class to read through a given xml document to output xml blocks as records
 * as specified by the start tag and end tag.
 *
 * This implementation is ultimately loosely based on LineRecordReader in Hadoop.
 */
class XmlTokenizer(
    inputStream: InputStream,
    options: XmlOptions) extends Logging {
  // Buffered character view of the input, decoded with the charset named in the options.
  private var reader: BufferedReader = {
    val decoder = new InputStreamReader(inputStream, Charset.forName(options.charset))
    new BufferedReader(decoder)
  }
  // Start-tag text recorded for the record being read; set by `readUntilStartElement`.
  private var currentStartTag: String = _
  // Holds the text of the record being assembled; replaced after each record.
  private var buffer: StringBuilder = new StringBuilder()
  // Record delimiters derived from the configured row tag.
  private val startTag: String = "<" + options.rowTag + ">"
  private val endTag: String = "</" + options.rowTag + ">"
  // Delimiters of XML comments and CDATA sections.
  private val commentStart: String = "<!--"
  private val commentEnd: String = "-->"
  private val cdataStart: String = "<![CDATA["
  private val cdataEnd: String = "]]>"
/**
 * Finds the start of the next record.
 * It treats data from `startTag` and `endTag` as a record.
 *
 * @return the text of the next record, or None if no further record could be read
 */
// Returns the next record's text, or None when no start tag is found or the read fails
// in a way the options allow to be ignored.
// NOTE(review): several statements are missing from this view (what is appended to
// `buffer`, the logging calls the message strings below belong to, and any `close()`
// before `reader` is nulled) -- confirm against the full file.
def next(): Option[String] = {
var nextString: Option[String] = None
try {
if (readUntilStartElement()) {
// Don't check whether the end element was found. Even if not, return everything
// that was read, which will invariably cause a parse error later
nextString = Some(buffer.toString())
// Start a fresh buffer for the following record.
buffer = new StringBuilder()
} catch {
// A missing file is tolerated only when `ignoreMissingFiles` is set.
case e: FileNotFoundException if options.ignoreMissingFiles =>
"Skipping the rest of" +
" the content in the missing file during schema inference",
case NonFatal(e) =>
// Decide based on the root cause rather than on the wrapping exception.
ExceptionUtils.getRootCause(e) match {
case _: RuntimeException | _: IOException if options.ignoreCorruptFiles =>
"Skipping the rest of" +
" the content in the corrupted file during schema inference",
// Anything else is fatal for this tokenizer: drop the reader and rethrow.
case e: Throwable =>
reader = null
throw e
} finally {
// No record was produced, so the stream is exhausted or failed: release the reader.
if (nextString.isEmpty && reader != null) {
reader = null
/**
 * Consumes characters from `reader` until the text `end` has just been read.
 *
 * @param end non-empty terminator to look for
 * @return true once `end` has been read in full, false if the stream ended first
 */
private def readUntilMatch(end: String): Boolean = {
  // Number of leading characters of `end` matched by the most recently read characters.
  var matched = 0
  while (true) {
    val cOrEOF = reader.read()
    if (cOrEOF == -1) {
      // End of file.
      return false
    }
    val c = cOrEOF.toChar
    if (c == end(matched)) {
      matched += 1
      if (matched >= end.length) {
        // Found the end string.
        return true
      }
    } else if (matched > 0) {
      // A partial match broke off. Restarting from zero would lose terminators that
      // overlap the failed attempt (for example "-->" at the end of "--->", or "]]>"
      // at the end of "]]]>"), so resume from the longest prefix of `end` that the
      // characters read so far still end with.
      val seen = end.substring(0, matched) + c
      matched = (matched to 1 by -1)
        .find(len => seen.endsWith(end.substring(0, len)))
        .getOrElse(0)
    }
  }
  // Unreachable.
  false
}
// Scans forward for the next occurrence of the row start tag, either closed by '>' or
// followed by whitespace (a tag with attributes), tracking comment and CDATA openers
// along the way. Returns false when the input ends first.
// NOTE(review): the read call feeding `cOrEOF` and the statements that skip to the end
// of a comment/CDATA section are not visible in this view.
private def readUntilStartElement(): Boolean = {
currentStartTag = startTag
// Progress of the match against `startTag`, `commentStart` and `cdataStart`.
var i = 0
var commentIdx = 0
var cdataIdx = 0
while (true) {
val cOrEOF =
if (cOrEOF == -1) { // || (i == 0 && getFilePosition() > end)) {
// End of file or end of split.
return false
val c = cOrEOF.toChar
if (c == commentStart(commentIdx)) {
if (commentIdx >= commentStart.length - 1) {
// If a comment begins we must ignore all characters until its end
commentIdx = 0
} else {
commentIdx += 1
} else {
commentIdx = 0
if (c == cdataStart(cdataIdx)) {
if (cdataIdx >= cdataStart.length - 1) {
// If a CDATA section begins we must ignore all characters until its end
cdataIdx = 0
} else {
cdataIdx += 1
} else {
cdataIdx = 0
if (c == startTag(i)) {
if (i >= startTag.length - 1) {
// Found start tag.
return true
// else in start tag
i += 1
} else {
// if doesn't match the closing angle bracket, check if followed by attributes
if (i == (startTag.length - 1) && Character.isWhitespace(c)) {
// Found start tag with attributes. Remember to write with following whitespace
// char, not angle bracket
currentStartTag = startTag.dropRight(1) + c
return true
// else not in start tag
i = 0
// Unreachable.
// Reads forward until the end of the record opened by the start tag already found:
// either the matching end tag at nesting depth 0, or a self-closing "/>" while the start
// tag is still open. Returns false when the input ends first.
// `startTagClosed` tells whether the start tag found earlier already ended with '>'.
// NOTE(review): the read call feeding `cOrEOF`, the statement appending to `buffer`, and
// the statements that skip to the end of a comment are not visible in this view;
// `cdataIdx` is not referenced in the visible lines.
private def readUntilEndElement(startTagClosed: Boolean): Boolean = {
// Index into the start or end tag that has matched so far
var si = 0
var ei = 0
// Index into the start of a comment tag that matched so far
var commentIdx = 0
// Index into the start of a CDATA tag that matched so far
var cdataIdx = 0
// How many other start tags enclose the one that's started already?
var depth = 0
// Previously read character
var prevC = '\u0000'
// The current start tag already found may or may not have terminated with
// a '>' as it may have attributes we read here. If not, we search for
// a self-close tag, but only until a non-self-closing end to the start
// tag is found
var canSelfClose = !startTagClosed
while (true) {
val cOrEOF =
if (cOrEOF == -1) {
// End of file (ignore end of split).
return false
val c = cOrEOF.toChar
if (c == commentStart(commentIdx)) {
if (commentIdx >= commentStart.length - 1) {
// If a comment begins we must ignore everything until its end
// Remove the comment opener from the record text collected so far.
buffer.setLength(buffer.length - commentStart.length)
commentIdx = 0
} else {
commentIdx += 1
} else {
commentIdx = 0
// A '>' that is not part of "/>" ends the start tag without self-closing it.
if (c == '>' && prevC != '/') {
canSelfClose = false
// Still matching a start tag?
if (c == startTag(si)) {
// Still also matching an end tag?
if (c == endTag(ei)) {
// In start tag or end tag.
si += 1
ei += 1
} else {
if (si >= startTag.length - 1) {
// Found start tag.
si = 0
ei = 0
depth += 1
} else {
// In start tag.
si += 1
ei = 0
} else if (c == endTag(ei)) {
if (ei >= endTag.length - 1) {
if (depth == 0) {
// Found closing end tag.
return true
// else found nested end tag.
si = 0
ei = 0
depth -= 1
} else {
// In end tag.
si = 0
ei += 1
} else if (c == '>' && prevC == '/' && canSelfClose) {
if (depth == 0) {
// found a self-closing tag (end tag)
return true
// else found self-closing nested tag (end tag)
si = 0
ei = 0
depth -= 1
} else if (si == (startTag.length - 1) && Character.isWhitespace(c)) {
// found a start tag with attributes
si = 0
ei = 0
depth += 1
} else {
// Not in start tag or end tag.
si = 0
ei = 0
prevC = c
// Unreachable.
// Companion object holding the stream-level helpers (`tokenizeStream`, `convertStream`).
object StaxXmlParser {
/**
 * Parses a stream that contains XML strings and turns it into an iterator of tokens.
 */
def tokenizeStream(inputStream: InputStream, options: XmlOptions): Iterator[String] = {
  // Each record string produced by the tokenizer is passed through unchanged.
  val tokenizer = new XmlTokenizer(inputStream, options)
  convertStream[String](tokenizer)(record => record)
}
// Adapts an XmlTokenizer to an Iterator[T]: `hasNext` reports whether a record is
// pending and `next()` applies `convert` to the pending record's text.
// NOTE(review): the initializers of `nextRecord` (at construction and after each
// element) and the value returned from `next()` are not visible in this view;
// presumably the next record is fetched from `xmlTokenizer` and the converted record is
// returned -- confirm against the full file.
private def convertStream[T](
xmlTokenizer: XmlTokenizer)(
convert: String => T) = new Iterator[T] {
private var nextRecord =
override def hasNext: Boolean = nextRecord.nonEmpty
override def next(): T = {
// Calling next() on an exhausted iterator is reported as an end-of-stream error.
if (!hasNext) {
throw QueryExecutionErrors.endOfStreamError()
val curRecord = convert(nextRecord.get)
nextRecord =