blob: 237d740e043626f0b65c128bf28a9de0c887eab9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.io.CharArrayWriter
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, DropMalformedMode, FailFastMode, FailureSafeParser, GenericArrayData, PermissiveMode}
import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
* Converts an XML input string to a [[StructType]] with the specified schema.
* It is assumed that the XML input string constitutes a single record; so the
* [[rowTag]] option will be not applicable.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(xmlStr, schema[, options]) - Returns a struct value with the given `xmlStr` and `schema`.",
examples = """
Examples:
> SELECT _FUNC_('<p><a>1</a><b>0.8</b></p>', 'a INT, b DOUBLE');
{"a":1,"b":0.8}
> SELECT _FUNC_('<p><time>26/08/2015</time></p>', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy'));
{"time":2015-08-26 00:00:00}
> SELECT _FUNC_('<p><teacher>Alice</teacher><student><name>Bob</name><rank>1</rank></student><student><name>Charlie</name><rank>2</rank></student></p>', 'STRUCT<teacher: STRING, student: ARRAY<STRUCT<name: STRING, rank: INT>>>');
{"teacher":"Alice","student":[{"name":"Bob","rank":1},{"name":"Charlie","rank":2}]}
""",
group = "xml_funcs",
since = "4.0.0")
// scalastyle:on line.size.limit
case class XmlToStructs(
schema: DataType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression
with TimeZoneAwareExpression
with CodegenFallback
with ExpectsInputTypes
with NullIntolerant
with QueryErrorsBase {
def this(child: Expression, schema: Expression, options: Map[String, String]) =
this(
schema = ExprUtils.evalSchemaExpr(schema),
options = options,
child = child,
timeZoneId = None)
override def nullable: Boolean = true
// The XML input data might be missing certain fields. We force the nullability
// of the user-provided schema to avoid data corruptions.
val nullableSchema = schema.asNullable
def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])
def this(child: Expression, schema: Expression, options: Expression) =
this(
schema = ExprUtils.evalSchemaExpr(schema),
options = ExprUtils.convertToMapData(options),
child = child,
timeZoneId = None)
// This converts parsed rows to the desired output by the given schema.
@transient
lazy val converter = nullableSchema match {
case _: StructType =>
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
case _: ArrayType =>
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
case _: MapType =>
(rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
}
val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
@transient lazy val parser = {
val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord)
val mode = parsedOptions.parseMode
if (mode != PermissiveMode && mode != FailFastMode) {
throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode)
}
val (parserSchema, actualSchema) = nullableSchema match {
case s: StructType =>
ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
(s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
case other =>
(StructType(Array(StructField("value", other))), other)
}
val rowSchema: StructType = schema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
}
val rawParser = new StaxXmlParser(rowSchema, parsedOptions)
val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
new FailureSafeParser[String](
input => rawParser.doParseColumn(input, mode, xsdSchema),
mode,
parserSchema,
parsedOptions.columnNameOfCorruptRecord)
}
override def dataType: DataType = nullableSchema
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}
override def nullSafeEval(xml: Any): Any = xml match {
case arr: GenericArrayData =>
new GenericArrayData(arr.array.map(s => converter(parser.parse(s.toString))))
case arr: ArrayData =>
new GenericArrayData(arr.array.map(s => converter(parser.parse(s.toString))))
case _ =>
val str = xml.asInstanceOf[UTF8String].toString
converter(parser.parse(str))
}
override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
override def sql: String = schema match {
case _: MapType => "entries"
case _ => super.sql
}
override def prettyName: String = "from_xml"
protected def withNewChildInternal(newChild: Expression): XmlToStructs =
copy(child = newChild)
}
/**
* A function infers schema of XML string.
*/
@ExpressionDescription(
usage = "_FUNC_(xml[, options]) - Returns schema in the DDL format of XML string.",
examples = """
Examples:
> SELECT _FUNC_('<p><a>1</a></p>');
STRUCT<a: BIGINT>
> SELECT _FUNC_('<p><a attr="2">1</a><a>3</a></p>', map('excludeAttribute', 'true'));
STRUCT<a: ARRAY<BIGINT>>
""",
since = "4.0.0",
group = "xml_funcs")
case class SchemaOfXml(
child: Expression,
options: Map[String, String])
extends UnaryExpression with CodegenFallback with QueryErrorsBase {
def this(child: Expression) = this(child, Map.empty[String, String])
def this(child: Expression, options: Expression) = this(
child = child,
options = ExprUtils.convertToMapData(options))
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = false
@transient
private lazy val xmlOptions = new XmlOptions(options, "UTC")
@transient
private lazy val xmlFactory = xmlOptions.buildXmlFactory()
@transient
private lazy val xmlInferSchema = {
if (xmlOptions.parseMode == DropMalformedMode) {
throw QueryCompilationErrors.parseModeUnsupportedError("schema_of_xml", xmlOptions.parseMode)
}
new XmlInferSchema(xmlOptions, caseSensitive = SQLConf.get.caseSensitiveAnalysis)
}
@transient
private lazy val xml = child.eval().asInstanceOf[UTF8String]
override def checkInputDataTypes(): TypeCheckResult = {
if (child.foldable && xml != null) {
super.checkInputDataTypes()
} else if (!child.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("xml"),
"inputType" -> toSQLType(child.dataType),
"inputExpr" -> toSQLExpr(child)))
} else {
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "xml"))
}
}
override def eval(v: InternalRow): Any = {
val dataType = xmlInferSchema.infer(xml.toString).get match {
case st: StructType =>
xmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil))
case at: ArrayType if at.elementType.isInstanceOf[StructType] =>
xmlInferSchema
.canonicalizeType(at.elementType)
.map(ArrayType(_, containsNull = at.containsNull))
.getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
case other: DataType =>
xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType)
}
UTF8String.fromString(dataType.sql)
}
override def prettyName: String = "schema_of_xml"
override protected def withNewChildInternal(newChild: Expression): SchemaOfXml =
copy(child = newChild)
}
/**
* Converts a [[StructType]] to a XML output string.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr[, options]) - Returns a XML string with a given struct value",
examples = """
Examples:
> SELECT _FUNC_(named_struct('a', 1, 'b', 2));
<ROW>
<a>1</a>
<b>2</b>
</ROW>
> SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
<ROW>
<time>26/08/2015</time>
</ROW>
""",
since = "4.0.0",
group = "xml_funcs")
// scalastyle:on line.size.limit
case class StructsToXml(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression
with TimeZoneAwareExpression
with CodegenFallback
with ExpectsInputTypes
with NullIntolerant {
override def nullable: Boolean = true
def this(options: Map[String, String], child: Expression) = this(options, child, None)
// Used in `FunctionRegistry`
def this(child: Expression) = this(Map.empty, child, None)
def this(child: Expression, options: Expression) =
this(
options = ExprUtils.convertToMapData(options),
child = child,
timeZoneId = None)
override def checkInputDataTypes(): TypeCheckResult = {
child.dataType match {
case _: StructType => TypeCheckSuccess
case _ => DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> toSQLType(StructType),
"inputSql" -> toSQLExpr(child),
"inputType" -> toSQLType(child.dataType)
)
)
}
}
@transient
lazy val writer = new CharArrayWriter()
@transient
lazy val inputSchema: StructType = child.dataType.asInstanceOf[StructType]
@transient
lazy val gen = new StaxXmlGenerator(
inputSchema, writer, new XmlOptions(options, timeZoneId.get), false)
// This converts rows to the XML output according to the given schema.
@transient
lazy val converter: Any => UTF8String = {
def getAndReset(): UTF8String = {
gen.flush()
val xmlString = writer.toString
writer.reset()
UTF8String.fromString(xmlString)
}
(row: Any) =>
gen.write(row.asInstanceOf[InternalRow])
getAndReset()
}
override def dataType: DataType = SQLConf.get.defaultStringType
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
override def nullSafeEval(value: Any): Any = converter(value)
override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
override def prettyName: String = "to_xml"
override protected def withNewChildInternal(newChild: Expression): StructsToXml =
copy(child = newChild)
}