/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.expressions.variant

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.types.variant.{Variant, VariantBuilder, VariantSizeLimitException, VariantUtil}
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

/**
 * A utility class for constructing variant expressions.
 */
object VariantExpressionEvalUtils {

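  /**
   * Parses a JSON string into a variant value. A hedged usage sketch (the
   * inputs below are illustrative, not taken from this file):
   * {{{
   *   // Produces a variant encoding of {"a": [1, 2]}.
   *   val v = parseJson(UTF8String.fromString("""{"a": [1, 2]}"""))
   *   // With failOnError = false, malformed input yields null (a SQL NULL)
   *   // instead of throwing.
   *   val bad = parseJson(UTF8String.fromString("not json"), failOnError = false)
   * }}}
   */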
  def parseJson(input: UTF8String, failOnError: Boolean = true): VariantVal = {
    def parseJsonFailure(exception: Throwable): VariantVal = {
      if (failOnError) {
        throw exception
      } else {
        null
      }
    }
    try {
      val v = VariantBuilder.parseJson(input.toString)
      new VariantVal(v.getValue, v.getMetadata)
    } catch {
      case _: VariantSizeLimitException =>
        parseJsonFailure(QueryExecutionErrors
          .variantSizeLimitError(VariantUtil.SIZE_LIMIT, "parse_json"))
      case NonFatal(e) =>
        parseJsonFailure(QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(
          input.toString, e))
    }
  }

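  /**
   * Returns true iff `input` is a variant NULL (the JSON `null` literal), as
   * opposed to a SQL NULL. A hedged sketch (inputs are illustrative):
   * {{{
   *   isVariantNull(null)                                      // false: SQL NULL
   *   isVariantNull(parseJson(UTF8String.fromString("null")))  // true: variant NULL
   *   isVariantNull(parseJson(UTF8String.fromString("1")))     // false
   * }}}
   */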
  def isVariantNull(input: VariantVal): Boolean = {
    if (input == null) {
      // This is a SQL NULL, not a variant NULL.
      false
    } else {
      val variantValue = input.getValue
      if (variantValue.isEmpty) {
        throw QueryExecutionErrors.malformedVariant()
      } else {
        // Variant NULL is denoted by basic_type == 0 and val_header == 0.
        variantValue(0) == 0
      }
    }
  }

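  // A hedged usage sketch for the `castToVariant` method below (the schema
  // and row here are illustrative, not taken from this file):
  //
  //   val schema = StructType(Seq(
  //     StructField("a", IntegerType),
  //     StructField("b", StringType)))
  //   val row = InternalRow(1, UTF8String.fromString("x"))
  //   val v = castToVariant(row, schema) // encodes the object {"a": 1, "b": "x"}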
  /** Casts a Spark value of type `dataType` into the variant type. */
  def castToVariant(input: Any, dataType: DataType): VariantVal = {
    val builder = new VariantBuilder
    buildVariant(builder, input, dataType)
    val v = builder.result()
    new VariantVal(v.getValue, v.getMetadata)
  }

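  // How the recursion below encodes nested data, in brief: scalar types
  // append a single variant value; for arrays and objects the builder first
  // records the current write position (`start`), writes each child value
  // while collecting its offset relative to `start`, and finally calls
  // finishWritingArray / finishWritingObject to assemble the header once all
  // child sizes are known. Object keys are registered in the metadata
  // dictionary via addKey before their values are written.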
  private def buildVariant(builder: VariantBuilder, input: Any, dataType: DataType): Unit = {
    if (input == null) {
      builder.appendNull()
      return
    }
    // Note: this match is not exhaustive; an unsupported `dataType` fails
    // with a scala.MatchError, so callers presumably validate types up front.
    dataType match {
      case BooleanType => builder.appendBoolean(input.asInstanceOf[Boolean])
      case ByteType => builder.appendLong(input.asInstanceOf[Byte])
      case ShortType => builder.appendLong(input.asInstanceOf[Short])
      case IntegerType => builder.appendLong(input.asInstanceOf[Int])
      case LongType => builder.appendLong(input.asInstanceOf[Long])
      case FloatType => builder.appendFloat(input.asInstanceOf[Float])
      case DoubleType => builder.appendDouble(input.asInstanceOf[Double])
      case _: DecimalType => builder.appendDecimal(input.asInstanceOf[Decimal].toJavaBigDecimal)
      case _: StringType => builder.appendString(input.asInstanceOf[UTF8String].toString)
      case BinaryType => builder.appendBinary(input.asInstanceOf[Array[Byte]])
      case DateType => builder.appendDate(input.asInstanceOf[Int])
      case TimestampType => builder.appendTimestamp(input.asInstanceOf[Long])
      case TimestampNTZType => builder.appendTimestampNtz(input.asInstanceOf[Long])
      case VariantType =>
        val v = input.asInstanceOf[VariantVal]
        builder.appendVariant(new Variant(v.getValue, v.getMetadata))
      case ArrayType(elementType, _) =>
        val data = input.asInstanceOf[ArrayData]
        val start = builder.getWritePos
        val offsets = new java.util.ArrayList[java.lang.Integer](data.numElements())
        for (i <- 0 until data.numElements()) {
          // Record each element's offset relative to `start` before writing it.
          offsets.add(builder.getWritePos - start)
          val element = if (data.isNullAt(i)) null else data.get(i, elementType)
          buildVariant(builder, element, elementType)
        }
        builder.finishWritingArray(start, offsets)
      case MapType(StringType, valueType, _) =>
        // Only string-keyed maps can become variant objects; each key is
        // registered in the metadata dictionary via `addKey`.
        val data = input.asInstanceOf[MapData]
        val keys = data.keyArray()
        val values = data.valueArray()
        val start = builder.getWritePos
        val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](data.numElements())
        for (i <- 0 until data.numElements()) {
          val key = keys.getUTF8String(i).toString
          val id = builder.addKey(key)
          fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
          val value = if (values.isNullAt(i)) null else values.get(i, valueType)
          buildVariant(builder, value, valueType)
        }
        builder.finishWritingObject(start, fields)
      case StructType(structFields) =>
        // Structs are encoded as variant objects with one field per struct field.
        val data = input.asInstanceOf[InternalRow]
        val start = builder.getWritePos
        val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](structFields.length)
        for (i <- 0 until structFields.length) {
          val key = structFields(i).name
          val id = builder.addKey(key)
          fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
          val value = if (data.isNullAt(i)) null else data.get(i, structFields(i).dataType)
          buildVariant(builder, value, structFields(i).dataType)
        }
        builder.finishWritingObject(start, fields)
    }
  }
}