/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.serde
import java.util.Locale
import scala.collection.JavaConverters._
import scala.math.min
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
import org.apache.comet.expressions._
import org.apache.comet.objectstore.NativeConfig
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType._
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
import org.apache.comet.shims.CometExprShim
/**
* A utility object for query plan and expression serialization.
*/
object QueryPlanSerde extends Logging with CometExprShim {
def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
}
def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
_: DecimalType | _: DateType | _: BooleanType | _: NullType =>
true
case s: StructType if allowComplex =>
s.fields.map(_.dataType).forall(supportedDataType(_, allowComplex))
case a: ArrayType if allowComplex =>
supportedDataType(a.elementType, allowComplex)
case m: MapType if allowComplex =>
supportedDataType(m.keyType, allowComplex) && supportedDataType(m.valueType, allowComplex)
case dt =>
emitWarning(s"unsupported Spark data type: $dt")
false
}
/**
* Serializes a Spark data type to protobuf. Note that the fact that a data type can be
* serialized by this method does not mean it is supported by Comet native execution, i.e.,
* `supportedDataType` may still return false for it.
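*
* For example (an illustrative sketch):
* {{{
* // Decimal types carry precision/scale info in the serialized protobuf type.
* val protoType: Option[ExprOuterClass.DataType] = serializeDataType(DecimalType(10, 2))
* }}}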
*/
def serializeDataType(dt: DataType): Option[ExprOuterClass.DataType] = {
val typeId = dt match {
case _: BooleanType => 0
case _: ByteType => 1
case _: ShortType => 2
case _: IntegerType => 3
case _: LongType => 4
case _: FloatType => 5
case _: DoubleType => 6
case _: StringType => 7
case _: BinaryType => 8
case _: TimestampType => 9
case _: DecimalType => 10
case _: TimestampNTZType => 11
case _: DateType => 12
case _: NullType => 13
case _: ArrayType => 14
case _: MapType => 15
case _: StructType => 16
case dt =>
emitWarning(s"Cannot serialize Spark data type: $dt")
return None
}
val builder = ProtoDataType.newBuilder()
builder.setTypeIdValue(typeId)
// Decimal
val dataType = dt match {
case t: DecimalType =>
val info = DataTypeInfo.newBuilder()
val decimal = DecimalInfo.newBuilder()
decimal.setPrecision(t.precision)
decimal.setScale(t.scale)
info.setDecimal(decimal)
builder.setTypeInfo(info.build()).build()
case a: ArrayType =>
val elementType = serializeDataType(a.elementType)
if (elementType.isEmpty) {
return None
}
val info = DataTypeInfo.newBuilder()
val list = ListInfo.newBuilder()
list.setElementType(elementType.get)
list.setContainsNull(a.containsNull)
info.setList(list)
builder.setTypeInfo(info.build()).build()
case m: MapType =>
val keyType = serializeDataType(m.keyType)
if (keyType.isEmpty) {
return None
}
val valueType = serializeDataType(m.valueType)
if (valueType.isEmpty) {
return None
}
val info = DataTypeInfo.newBuilder()
val map = MapInfo.newBuilder()
map.setKeyType(keyType.get)
map.setValueType(valueType.get)
map.setValueContainsNull(m.valueContainsNull)
info.setMap(map)
builder.setTypeInfo(info.build()).build()
case s: StructType =>
val info = DataTypeInfo.newBuilder()
val struct = StructInfo.newBuilder()
val fieldNames = s.fields.map(_.name).toIterable.asJava
val fieldDatatypes = s.fields.map(f => serializeDataType(f.dataType)).toSeq
val fieldNullable = s.fields.map(f => Boolean.box(f.nullable)).toIterable.asJava
if (fieldDatatypes.exists(_.isEmpty)) {
return None
}
struct.addAllFieldNames(fieldNames)
struct.addAllFieldDatatypes(fieldDatatypes.map(_.get).asJava)
struct.addAllFieldNullable(fieldNullable)
info.setStruct(struct)
builder.setTypeInfo(info.build()).build()
case _ => builder.build()
}
Some(dataType)
}
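/**
* Converts a Spark window expression to its protobuf representation. Returns None when the
* window function, aggregate, or frame specification is not supported; where possible the
* reason is attached to the expression via `withInfo`.
*/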
def windowExprToProto(
windowExpr: WindowExpression,
output: Seq[Attribute],
conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
expr match {
case agg: AggregateExpression =>
agg.aggregateFunction match {
case _: Count =>
Some(agg)
case min: Min =>
if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${min.dataType} is not supported", expr)
None
}
case max: Max =>
if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${max.dataType} is not supported", expr)
None
}
case s: Sum =>
if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
.isInstanceOf[DecimalType]) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
None
}
case _ =>
withInfo(
windowExpr,
s"aggregate ${agg.aggregateFunction}" +
" is not supported for window function",
expr)
None
}
case _ =>
None
}
}.toArray
val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
val modes = aggregateExpressions.map(_.mode).distinct
assert(modes.size == 1 && modes.head == Complete)
(aggExprToProto(aggregateExpressions.head, output, true, conf), None)
} else {
(None, exprToProto(windowExpr.windowFunction, output))
}
if (aggExpr.isEmpty && builtinFunc.isEmpty) {
return None
}
val f = windowExpr.windowSpec.frameSpecification
val (frameType, lowerBound, upperBound) = f match {
case SpecifiedWindowFrame(frameType, lBound, uBound) =>
val frameProto = frameType match {
case RowFrame => OperatorOuterClass.WindowFrameType.Rows
case RangeFrame => OperatorOuterClass.WindowFrameType.Range
}
val lBoundProto = lBound match {
case UnboundedPreceding =>
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
.build()
case CurrentRow =>
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e if frameType == RowFrame =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setPreceding(
OperatorOuterClass.Preceding
.newBuilder()
.setOffset(offset)
.build())
.build()
case _ =>
// TODO add support for numeric and temporal RANGE BETWEEN expressions
// see https://github.com/apache/datafusion-comet/issues/1246
return None
}
val uBoundProto = uBound match {
case UnboundedFollowing =>
OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
.build()
case CurrentRow =>
OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
.build()
case e if frameType == RowFrame =>
val offset = e.eval() match {
case i: Integer => i.toLong
case l: Long => l
case _ => return None
}
OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setFollowing(
OperatorOuterClass.Following
.newBuilder()
.setOffset(offset)
.build())
.build()
case _ =>
// TODO add support for numeric and temporal RANGE BETWEEN expressions
// see https://github.com/apache/datafusion-comet/issues/1246
return None
}
(frameProto, lBoundProto, uBoundProto)
case _ =>
(
OperatorOuterClass.WindowFrameType.Rows,
OperatorOuterClass.LowerWindowFrameBound
.newBuilder()
.setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
.build(),
OperatorOuterClass.UpperWindowFrameBound
.newBuilder()
.setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
.build())
}
val frame = OperatorOuterClass.WindowFrame
.newBuilder()
.setFrameType(frameType)
.setLowerBound(lowerBound)
.setUpperBound(upperBound)
.build()
val spec =
OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
if (builtinFunc.isDefined) {
Some(
OperatorOuterClass.WindowExpr
.newBuilder()
.setBuiltInWindowFunction(builtinFunc.get)
.setSpec(spec)
.build())
} else if (aggExpr.isDefined) {
Some(
OperatorOuterClass.WindowExpr
.newBuilder()
.setAggFunc(aggExpr.get)
.setSpec(spec)
.build())
} else {
None
}
}
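/**
* Converts a Spark aggregate expression to its protobuf representation. Distinct aggregates
* and aggregate functions without a matching `CometAggregateExpressionSerde` return None,
* with the reason attached via `withInfo`; otherwise conversion is delegated to that serde.
*/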
def aggExprToProto(
aggExpr: AggregateExpression,
inputs: Seq[Attribute],
binding: Boolean,
conf: SQLConf): Option[AggExpr] = {
if (aggExpr.isDistinct) {
// https://github.com/apache/datafusion-comet/issues/1260
withInfo(aggExpr, "distinct aggregates are not supported")
return None
}
val cometExpr: CometAggregateExpressionSerde = aggExpr.aggregateFunction match {
case _: Sum => CometSum
case _: Average => CometAverage
case _: Count => CometCount
case _: Min => CometMin
case _: Max => CometMax
case _: First => CometFirst
case _: Last => CometLast
case _: BitAndAgg => CometBitAndAgg
case _: BitOrAgg => CometBitOrAgg
case _: BitXorAgg => CometBitXOrAgg
case _: CovSample => CometCovSample
case _: CovPopulation => CometCovPopulation
case _: VarianceSamp => CometVarianceSamp
case _: VariancePop => CometVariancePop
case _: StddevSamp => CometStddevSamp
case _: StddevPop => CometStddevPop
case _: Corr => CometCorr
case _: BloomFilterAggregate => CometBloomFilterAggregate
case fn =>
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
emitWarning(msg)
withInfo(aggExpr, msg, fn.children: _*)
return None
}
cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, conf)
}
def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {
evalMode match {
case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY
case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode")
}
}
/**
* Wrap an expression in a cast.
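*
* For example, the integral divide case below casts its result to long with:
* {{{
* castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
* }}}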
*/
def castToProto(
expr: Expression,
timeZoneId: Option[String],
dt: DataType,
childExpr: Expr,
evalMode: CometEvalMode.Value): Option[Expr] = {
serializeDataType(dt) match {
case Some(dataType) =>
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr)
castBuilder.setDatatype(dataType)
castBuilder.setEvalMode(evalModeToProto(evalMode))
castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get())
castBuilder.setTimezone(timeZoneId.getOrElse("UTC"))
Some(
ExprOuterClass.Expr
.newBuilder()
.setCast(castBuilder)
.build())
case _ =>
withInfo(expr, s"Unsupported datatype in castToProto: $dt")
None
}
}
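/**
* Serializes a cast, consulting `CometCast.isSupported` to determine compatibility with
* Spark. Incompatible casts are only converted when `COMET_CAST_ALLOW_INCOMPATIBLE` is
* enabled; unsupported casts return None with the reason attached via `withInfo`.
*/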
def handleCast(
expr: Expression,
child: Expression,
inputs: Seq[Attribute],
binding: Boolean,
dt: DataType,
timeZoneId: Option[String],
evalMode: CometEvalMode.Value): Option[Expr] = {
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val castSupport =
CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
def getIncompatMessage(reason: Option[String]): String =
"Comet does not guarantee correct results for cast " +
s"from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalMode" +
reason.map(str => s" ($str)").getOrElse("")
castSupport match {
case Compatible(_) =>
castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
case Incompatible(reason) =>
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
logWarning(getIncompatMessage(reason))
castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
} else {
withInfo(
expr,
s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
None
}
case Unsupported =>
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalMode")
None
}
} else {
withInfo(expr, child)
None
}
}
/**
* Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion
* expression.
*
* This method performs a transformation on the plan to handle decimal promotion and then calls
* into the recursive method [[exprToProtoInternal]].
*
* @param expr
* The input expression
* @param inputs
* The input attributes
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
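*
* For example (an illustrative sketch; the attribute and literal are hypothetical):
* {{{
* val col = AttributeReference("a", IntegerType)()
* // Serialize `a + 1` bound against a single-column input schema.
* val proto: Option[Expr] = QueryPlanSerde.exprToProto(Add(col, Literal(1)), Seq(col))
* }}}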
*/
def exprToProto(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean = true): Option[Expr] = {
val conf = SQLConf.get
val newExpr =
DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled)
exprToProtoInternal(newExpr, inputs, binding)
}
/**
* Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion
* expression.
*
* @param expr
* The input expression
* @param inputs
* The input attributes
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
*/
def exprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
SQLConf.get
def convert(handler: CometExpressionSerde): Option[Expr] = {
handler match {
case _: IncompatExpr if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
withInfo(
expr,
s"$expr is not fully compatible with Spark. To enable it anyway, set " +
s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true. ${CometConf.COMPAT_GUIDE}.")
None
case _ =>
handler.convert(expr, inputs, binding)
}
}
expr match {
case a @ Alias(_, _) =>
val r = exprToProtoInternal(a.child, inputs, binding)
if (r.isEmpty) {
withInfo(expr, a.child)
}
r
case cast @ Cast(_: Literal, dataType, _, _) =>
// This can happen after promoting decimal precisions
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs, binding)
case UnaryExpression(child) if expr.prettyName == "trycast" =>
val timeZoneId = SQLConf.get.sessionLocalTimeZone
handleCast(
expr,
child,
inputs,
binding,
expr.dataType,
Some(timeZoneId),
CometEvalMode.TRY)
case c @ Cast(child, dt, timeZoneId, _) =>
handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
createMathExpression(
expr,
left,
right,
inputs,
binding,
add.dataType,
add.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setAdd(mathExpr))
case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
withInfo(add, s"Unsupported datatype ${left.dataType}")
None
case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
createMathExpression(
expr,
left,
right,
inputs,
binding,
sub.dataType,
sub.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setSubtract(mathExpr))
case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
withInfo(sub, s"Unsupported datatype ${left.dataType}")
None
case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) =>
createMathExpression(
expr,
left,
right,
inputs,
binding,
mul.dataType,
mul.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setMultiply(mathExpr))
case mul @ Multiply(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(mul, s"Unsupported datatype ${left.dataType}")
}
None
case div @ Divide(left, right, _) if supportedDataType(left.dataType) =>
// DataFusion now throws an exception for dividing by zero
// See https://github.com/apache/arrow-datafusion/pull/6792
// For now, use NullIf to swap zeros with nulls.
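// As a sketch, `a / b` is effectively serialized as `a / IF(b = 0, NULL, b)`, so a zero
// divisor yields NULL rather than a native error (Spark returns NULL for division by zero
// in non-ANSI mode).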
val rightExpr = nullIfWhenPrimitive(right)
createMathExpression(
expr,
left,
rightExpr,
inputs,
binding,
div.dataType,
div.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setDivide(mathExpr))
case div @ Divide(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(div, s"Unsupported datatype ${left.dataType}")
}
None
case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) =>
val rightExpr = nullIfWhenPrimitive(right)
val dataType = (left.dataType, right.dataType) match {
case (l: DecimalType, r: DecimalType) =>
// copy from IntegralDivide.resultDecimalType
val intDig = l.precision - l.scale + r.scale
DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
case _ => left.dataType
}
val divideExpr = createMathExpression(
expr,
left,
rightExpr,
inputs,
binding,
dataType,
div.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setIntegralDivide(mathExpr))
if (divideExpr.isDefined) {
val childExpr = if (dataType.isInstanceOf[DecimalType]) {
// check overflow for decimal type
val builder = ExprOuterClass.CheckOverflow.newBuilder()
builder.setChild(divideExpr.get)
builder.setFailOnError(div.evalMode == EvalMode.ANSI)
builder.setDatatype(serializeDataType(dataType).get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setCheckOverflow(builder)
.build())
} else {
divideExpr
}
// cast result to long
castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
} else {
None
}
case div @ IntegralDivide(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(div, s"Unsupported datatype ${left.dataType}")
}
None
case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) =>
val rightExpr = nullIfWhenPrimitive(right)
createMathExpression(
expr,
left,
rightExpr,
inputs,
binding,
rem.dataType,
rem.evalMode == EvalMode.ANSI,
(builder, mathExpr) => builder.setRemainder(mathExpr))
case rem @ Remainder(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(rem, s"Unsupported datatype ${left.dataType}")
}
None
case EqualTo(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setEq(binaryExpr))
case Not(EqualTo(left, right)) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setNeq(binaryExpr))
case EqualNullSafe(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setEqNullSafe(binaryExpr))
case Not(EqualNullSafe(left, right)) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))
case GreaterThan(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setGt(binaryExpr))
case GreaterThanOrEqual(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setGtEq(binaryExpr))
case LessThan(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setLt(binaryExpr))
case LessThanOrEqual(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setLtEq(binaryExpr))
case Literal(value, dataType)
if supportedDataType(dataType, allowComplex = value == null) =>
val exprBuilder = ExprOuterClass.Literal.newBuilder()
if (value == null) {
exprBuilder.setIsNull(true)
} else {
exprBuilder.setIsNull(false)
dataType match {
case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int])
case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long])
case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
case _: StringType =>
exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString)
case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long])
case _: TimestampNTZType => exprBuilder.setLongVal(value.asInstanceOf[Long])
case _: DecimalType =>
// Pass decimal literal as bytes.
val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
exprBuilder.setDecimalVal(
com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
case _: BinaryType =>
val byteStr =
com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
exprBuilder.setBytesVal(byteStr)
case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
case dt =>
logWarning(s"Unexpected datatype '$dt' for literal value '$value'")
}
}
val dt = serializeDataType(dataType)
if (dt.isDefined) {
exprBuilder.setDatatype(dt.get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setLiteral(exprBuilder)
.build())
} else {
withInfo(expr, s"Unsupported datatype $dataType")
None
}
case Literal(_, dataType) if !supportedDataType(dataType) =>
withInfo(expr, s"Unsupported datatype $dataType")
None
case Substring(str, Literal(pos, _), Literal(len, _)) =>
val strExpr = exprToProtoInternal(str, inputs, binding)
if (strExpr.isDefined) {
val builder = ExprOuterClass.Substring.newBuilder()
builder.setChild(strExpr.get)
builder.setStart(pos.asInstanceOf[Int])
builder.setLen(len.asInstanceOf[Int])
Some(
ExprOuterClass.Expr
.newBuilder()
.setSubstring(builder)
.build())
} else {
withInfo(expr, str)
None
}
case StructsToJson(options, child, timezoneId) =>
if (options.nonEmpty) {
withInfo(expr, "StructsToJson with options is not supported")
None
} else {
def isSupportedType(dt: DataType): Boolean = {
dt match {
case StructType(fields) =>
fields.forall(f => isSupportedType(f.dataType))
case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
DataTypes.DoubleType | DataTypes.StringType =>
true
case DataTypes.DateType | DataTypes.TimestampType =>
// TODO implement these types with tests for formatting options and timezone
false
case _: MapType | _: ArrayType =>
// Spark supports map and array in StructsToJson but this is not yet
// implemented in Comet
false
case _ => false
}
}
val isSupported = child.dataType match {
case s: StructType =>
s.fields.forall(f => isSupportedType(f.dataType))
case _: MapType | _: ArrayType =>
// Spark supports map and array in StructsToJson but this is not yet
// implemented in Comet
false
case _ =>
false
}
if (isSupported) {
exprToProtoInternal(child, inputs, binding) match {
case Some(p) =>
val toJson = ExprOuterClass.ToJson
.newBuilder()
.setChild(p)
.setTimezone(timezoneId.getOrElse("UTC"))
.setIgnoreNullFields(true)
.build()
Some(
ExprOuterClass.Expr
.newBuilder()
.setToJson(toJson)
.build())
case _ =>
withInfo(expr, child)
None
}
} else {
withInfo(expr, "Unsupported data type", child)
None
}
}
case Like(left, right, escapeChar) =>
if (escapeChar == '\\') {
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setLike(binaryExpr))
} else {
// TODO custom escape char
withInfo(expr, s"custom escape character $escapeChar not supported in LIKE")
None
}
case RLike(left, right) =>
// we currently only support scalar regex patterns
right match {
case Literal(pattern, DataTypes.StringType) =>
if (!RegExp.isSupportedPattern(pattern.toString) &&
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
withInfo(
expr,
s"Regexp pattern $pattern is not compatible with Spark. " +
s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
"to allow it anyway.")
return None
}
case _ =>
withInfo(expr, "Only scalar regexp patterns are supported")
return None
}
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setRlike(binaryExpr))
case StartsWith(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setStartsWith(binaryExpr))
case EndsWith(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setEndsWith(binaryExpr))
case Contains(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setContains(binaryExpr))
case StringSpace(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setStringSpace(unaryExpr))
case Hour(child, timeZoneId) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val builder = ExprOuterClass.Hour.newBuilder()
builder.setChild(childExpr.get)
val timeZone = timeZoneId.getOrElse("UTC")
builder.setTimezone(timeZone)
Some(
ExprOuterClass.Expr
.newBuilder()
.setHour(builder)
.build())
} else {
withInfo(expr, child)
None
}
case Minute(child, timeZoneId) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val builder = ExprOuterClass.Minute.newBuilder()
builder.setChild(childExpr.get)
val timeZone = timeZoneId.getOrElse("UTC")
builder.setTimezone(timeZone)
Some(
ExprOuterClass.Expr
.newBuilder()
.setMinute(builder)
.build())
} else {
withInfo(expr, child)
None
}
case DateAdd(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)
case DateSub(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)
case TruncDate(child, format) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val formatExpr = exprToProtoInternal(format, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr)
optExprWithInfo(optExpr, expr, child, format)
case TruncTimestamp(format, child, timeZoneId) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val formatExpr = exprToProtoInternal(format, inputs, binding)
if (childExpr.isDefined && formatExpr.isDefined) {
val builder = ExprOuterClass.TruncTimestamp.newBuilder()
builder.setChild(childExpr.get)
builder.setFormat(formatExpr.get)
val timeZone = timeZoneId.getOrElse("UTC")
builder.setTimezone(timeZone)
Some(
ExprOuterClass.Expr
.newBuilder()
.setTruncTimestamp(builder)
.build())
} else {
withInfo(expr, child, format)
None
}
case Second(child, timeZoneId) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val builder = ExprOuterClass.Second.newBuilder()
builder.setChild(childExpr.get)
val timeZone = timeZoneId.getOrElse("UTC")
builder.setTimezone(timeZone)
Some(
ExprOuterClass.Expr
.newBuilder()
.setSecond(builder)
.build())
} else {
withInfo(expr, child)
None
}
case Year(child) =>
val periodType = exprToProtoInternal(Literal("year"), inputs, binding)
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("datepart", Seq(periodType, childExpr): _*)
.map(e => {
Expr
.newBuilder()
.setCast(
ExprOuterClass.Cast
.newBuilder()
.setChild(e)
.setDatatype(serializeDataType(IntegerType).get)
.setEvalMode(ExprOuterClass.EvalMode.LEGACY)
.setAllowIncompat(false)
.build())
.build()
})
optExprWithInfo(optExpr, expr, child)
case IsNull(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNull(unaryExpr))
case IsNotNull(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
case IsNaN(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr)
optExprWithInfo(optExpr, expr, child)
case SortOrder(child, direction, nullOrdering, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
sortOrderBuilder.setChild(childExpr.get)
direction match {
case Ascending => sortOrderBuilder.setDirectionValue(0)
case Descending => sortOrderBuilder.setDirectionValue(1)
}
nullOrdering match {
case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
}
Some(
ExprOuterClass.Expr
.newBuilder()
.setSortOrder(sortOrderBuilder)
.build())
} else {
withInfo(expr, child)
None
}
case And(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setAnd(binaryExpr))
case Or(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setOr(binaryExpr))
case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
// `UnaryExpression` includes `PromotePrecision` for Spark 3.3
// `PromotePrecision` is just a wrapper, so we don't need to serialize it.
exprToProtoInternal(child, inputs, binding)
case CheckOverflow(child, dt, nullOnOverflow) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val builder = ExprOuterClass.CheckOverflow.newBuilder()
builder.setChild(childExpr.get)
builder.setFailOnError(!nullOnOverflow)
// `dataType` must be a decimal type
val dataType = serializeDataType(dt)
builder.setDatatype(dataType.get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setCheckOverflow(builder)
.build())
} else {
withInfo(expr, child)
None
}
case attr: AttributeReference =>
val dataType = serializeDataType(attr.dataType)
if (dataType.isDefined) {
if (binding) {
// Spark may produce unresolvable attributes in some cases,
// for example https://github.com/apache/datafusion-comet/issues/925.
// So, we allow the binding to fail.
val boundRef: Any = BindReferences
.bindReference(attr, inputs, allowFailures = true)
if (boundRef.isInstanceOf[AttributeReference]) {
withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}")
return None
}
val boundExpr = ExprOuterClass.BoundReference
.newBuilder()
.setIndex(boundRef.asInstanceOf[BoundReference].ordinal)
.setDatatype(dataType.get)
.build()
Some(
ExprOuterClass.Expr
.newBuilder()
.setBound(boundExpr)
.build())
} else {
val unboundRef = ExprOuterClass.UnboundReference
.newBuilder()
.setName(attr.name)
.setDatatype(dataType.get)
.build()
Some(
ExprOuterClass.Expr
.newBuilder()
.setUnbound(unboundRef)
.build())
}
} else {
withInfo(attr, s"unsupported datatype: ${attr.dataType}")
None
}
// abs implementation is not correct
// https://github.com/apache/datafusion-comet/issues/666
// case Abs(child, failOnErr) =>
// val childExpr = exprToProtoInternal(child, inputs)
// if (childExpr.isDefined) {
// val evalModeStr =
// if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY
// val absBuilder = ExprOuterClass.Abs.newBuilder()
// absBuilder.setChild(childExpr.get)
// absBuilder.setEvalMode(evalModeStr)
// Some(Expr.newBuilder().setAbs(absBuilder).build())
// } else {
// withInfo(expr, child)
// None
// }
case Acos(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("acos", childExpr)
optExprWithInfo(optExpr, expr, child)
case Asin(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("asin", childExpr)
optExprWithInfo(optExpr, expr, child)
case Atan(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("atan", childExpr)
optExprWithInfo(optExpr, expr, child)
case Atan2(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)
case Hex(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("hex", StringType, childExpr)
optExprWithInfo(optExpr, expr, child)
case e: Unhex =>
val unHex = unhexSerde(e)
val childExpr = exprToProtoInternal(unHex._1, inputs, binding)
val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
optExprWithInfo(optExpr, expr, unHex._1)
case e @ Ceil(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
child.dataType match {
case t: DecimalType if t.scale == 0 => // zero scale is no-op
childExpr
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
withInfo(e, s"Decimal type $t has negative scale")
None
case _ =>
val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", e.dataType, childExpr)
optExprWithInfo(optExpr, expr, child)
}
case Cos(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("cos", childExpr)
optExprWithInfo(optExpr, expr, child)
case Exp(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("exp", childExpr)
optExprWithInfo(optExpr, expr, child)
case e @ Floor(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
child.dataType match {
case t: DecimalType if t.scale == 0 => // zero scale is no-op
childExpr
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
withInfo(e, s"Decimal type $t has negative scale")
None
case _ =>
val optExpr = scalarFunctionExprToProtoWithReturnType("floor", e.dataType, childExpr)
optExprWithInfo(optExpr, expr, child)
}
// The expression for `log` functions is defined as null on numbers less than or equal
// to 0. This matches Spark and Hive behavior, where non-positive values evaluate to null
// instead of NaN or -Infinity.
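// For example, `Log(child)` below is serialized as `ln(IF(child <= 0, NULL, child))` via the
// `nullIfNegative` helper defined further down in this object.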
case Log(child) =>
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
val optExpr = scalarFunctionExprToProto("ln", childExpr)
optExprWithInfo(optExpr, expr, child)
case Log10(child) =>
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
val optExpr = scalarFunctionExprToProto("log10", childExpr)
optExprWithInfo(optExpr, expr, child)
case Log2(child) =>
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
val optExpr = scalarFunctionExprToProto("log2", childExpr)
optExprWithInfo(optExpr, expr, child)
case Pow(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
val optExpr = scalarFunctionExprToProto("pow", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)
case r: Round =>
// _scale is a constant, copied from Spark's RoundBase because it is a protected val
val scaleV: Any = r.scale.eval(EmptyRow)
val _scale: Int = scaleV.asInstanceOf[Int]
lazy val childExpr = exprToProtoInternal(r.child, inputs, binding)
r.child.dataType match {
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
withInfo(r, "Decimal type has negative scale")
None
case _ if scaleV == null =>
exprToProtoInternal(Literal(null), inputs, binding)
case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
case _: FloatType | DoubleType =>
// We cannot properly match the Spark behavior for floating-point numbers.
// Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
// double to a string internally in order to create its own internal representation.
// The problem is that BigDecimal uses java.lang.Double.toString(), which has a
// complicated rounding algorithm. E.g. -5.81855622136895E8 is actually
// -581855622.13689494132995605468750. Note that the 5th fractional digit is 4 instead of
// 5. Java (Scala)'s toString() rounds it up to -581855622.136895. This makes a
// difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5) should be
// -5.818556221369E8 instead of -5.8185562213689E8. There is also an example where
// toString() does NOT round up: 6.1317116247283497E18 is 6131711624728349696. It can
// be rounded up to 6.13171162472835E18, which still represents the same double number,
// i.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not round,
// which results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
// of 6.1317116247283999E18.
withInfo(r, "Comet does not support Spark's BigDecimal rounding")
None
case _ =>
// `scale` must be Int64 type in DataFusion
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
optExprWithInfo(optExpr, expr, r.child)
}
case Signum(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("signum", childExpr)
optExprWithInfo(optExpr, expr, child)
case Sin(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("sin", childExpr)
optExprWithInfo(optExpr, expr, child)
case Sqrt(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("sqrt", childExpr)
optExprWithInfo(optExpr, expr, child)
case Tan(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("tan", childExpr)
optExprWithInfo(optExpr, expr, child)
case _: Ascii =>
CometAscii.convert(expr, inputs, binding)
case Expm1(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("expm1", childExpr)
optExprWithInfo(optExpr, expr, child)
case s: StringDecode =>
// Right child is the encoding expression.
s.right match {
case Literal(str, DataTypes.StringType)
if str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
// decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls
// for invalid strings.
// Left child is the binary expression.
castToProto(
expr,
None,
DataTypes.StringType,
exprToProtoInternal(s.left, inputs, binding).get,
CometEvalMode.TRY)
case _ =>
withInfo(expr, "Comet only supports decoding with 'utf-8'.")
None
}
case RegExpReplace(subject, pattern, replacement, startPosition) =>
if (!RegExp.isSupportedPattern(pattern.toString) &&
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
withInfo(
expr,
s"Regexp pattern $pattern is not compatible with Spark. " +
s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
"to allow it anyway.")
return None
}
startPosition match {
case Literal(value, DataTypes.IntegerType) if value == 1 =>
val subjectExpr = exprToProtoInternal(subject, inputs, binding)
val patternExpr = exprToProtoInternal(pattern, inputs, binding)
val replacementExpr = exprToProtoInternal(replacement, inputs, binding)
// DataFusion's regexp_replace stops at the first match. We need to add the 'g' flag
// to apply the regex globally to match Spark behavior.
val flagsExpr = exprToProtoInternal(Literal("g"), inputs, binding)
val optExpr = scalarFunctionExprToProto(
"regexp_replace",
subjectExpr,
patternExpr,
replacementExpr,
flagsExpr)
optExprWithInfo(optExpr, expr, subject, pattern, replacement, startPosition)
case _ =>
withInfo(expr, "Comet only supports regexp_replace with an offset of 1 (no offset).")
None
}
case _: BitLength =>
CometBitLength.convert(expr, inputs, binding)
case If(predicate, trueValue, falseValue) =>
val predicateExpr = exprToProtoInternal(predicate, inputs, binding)
val trueExpr = exprToProtoInternal(trueValue, inputs, binding)
val falseExpr = exprToProtoInternal(falseValue, inputs, binding)
if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) {
val builder = ExprOuterClass.IfExpr.newBuilder()
builder.setIfExpr(predicateExpr.get)
builder.setTrueExpr(trueExpr.get)
builder.setFalseExpr(falseExpr.get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setIf(builder)
.build())
} else {
withInfo(expr, predicate, trueValue, falseValue)
None
}
case CaseWhen(branches, elseValue) =>
var allBranches: Seq[Expression] = Seq()
val whenSeq = branches.map(elements => {
allBranches = allBranches :+ elements._1
exprToProtoInternal(elements._1, inputs, binding)
})
val thenSeq = branches.map(elements => {
allBranches = allBranches :+ elements._2
exprToProtoInternal(elements._2, inputs, binding)
})
assert(whenSeq.length == thenSeq.length)
if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
val builder = ExprOuterClass.CaseWhen.newBuilder()
builder.addAllWhen(whenSeq.map(_.get).asJava)
builder.addAllThen(thenSeq.map(_.get).asJava)
if (elseValue.isDefined) {
val elseValueExpr =
exprToProtoInternal(elseValue.get, inputs, binding)
if (elseValueExpr.isDefined) {
builder.setElseExpr(elseValueExpr.get)
} else {
withInfo(expr, elseValue.get)
return None
}
}
Some(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(builder)
.build())
} else {
withInfo(expr, allBranches: _*)
None
}
case _: ConcatWs =>
CometConcatWs.convert(expr, inputs, binding)
case _: Chr =>
CometChr.convert(expr, inputs, binding)
case _: InitCap =>
CometInitCap.convert(expr, inputs, binding)
case _: Length =>
CometLength.convert(expr, inputs, binding)
case Md5(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProto("md5", childExpr)
optExprWithInfo(optExpr, expr, child)
case OctetLength(child) =>
val castExpr = Cast(child, StringType)
val childExpr = exprToProtoInternal(castExpr, inputs, binding)
val optExpr = scalarFunctionExprToProto("octet_length", childExpr)
optExprWithInfo(optExpr, expr, castExpr)
case Reverse(child) =>
val castExpr = Cast(child, StringType)
val childExpr = exprToProtoInternal(castExpr, inputs, binding)
val optExpr = scalarFunctionExprToProto("reverse", childExpr)
optExprWithInfo(optExpr, expr, castExpr)
case _: StringInstr =>
CometStringInstr.convert(expr, inputs, binding)
case _: StringRepeat =>
CometStringRepeat.convert(expr, inputs, binding)
case _: StringReplace =>
CometStringReplace.convert(expr, inputs, binding)
case _: StringTranslate =>
CometStringTranslate.convert(expr, inputs, binding)
case _: StringTrim =>
CometTrim.convert(expr, inputs, binding)
case _: StringTrimLeft =>
CometStringTrimLeft.convert(expr, inputs, binding)
case _: StringTrimRight =>
CometStringTrimRight.convert(expr, inputs, binding)
case _: StringTrimBoth =>
CometStringTrimBoth.convert(expr, inputs, binding)
case _: Upper =>
CometUpper.convert(expr, inputs, binding)
case _: Lower =>
CometLower.convert(expr, inputs, binding)
case BitwiseAnd(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
case BitwiseNot(child) =>
val childProto = exprToProto(child, inputs, binding)
val bitNotScalarExpr =
scalarFunctionExprToProto("bit_not", childProto)
optExprWithInfo(bitNotScalarExpr, expr, expr.children: _*)
case BitwiseOr(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
case BitwiseXor(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
case BitwiseCount(child) =>
val childProto = exprToProto(child, inputs, binding)
val bitCountScalarExpr =
scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
case ShiftRight(left, right) =>
// DataFusion's bitwise shift right expression requires the
// same data type on the left and right side
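// e.g. for ShiftRight(longCol, intExpr), the right side is serialized as
// Cast(intExpr, LongType) so both sides are int64 (illustrative names).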
val rightExpression = if (left.dataType == LongType) {
Cast(right, LongType)
} else {
right
}
createBinaryExpr(
expr,
left,
rightExpression,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
case ShiftLeft(left, right) =>
// DataFusion's bitwise shift left expression requires the
// same data type on the left and right side
val rightExpression = if (left.dataType == LongType) {
Cast(right, LongType)
} else {
right
}
createBinaryExpr(
expr,
left,
rightExpression,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
case In(value, list) =>
in(expr, value, list, inputs, binding, negate = false)
case InSet(value, hset) =>
val valueDataType = value.dataType
val list = hset.map { setVal =>
Literal(setVal, valueDataType)
}.toSeq
// Change the `InSet` expression to an `In` expression.
// The Spark `InSet` optimization is done on the native (DataFusion) side.
in(expr, value, list, inputs, binding, negate = false)
case Not(In(value, list)) =>
in(expr, value, list, inputs, binding, negate = true)
case Not(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setNot(unaryExpr))
case UnaryMinus(child, failOnError) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val builder = ExprOuterClass.UnaryMinus.newBuilder()
builder.setChild(childExpr.get)
builder.setFailOnError(failOnError)
Some(
ExprOuterClass.Expr
.newBuilder()
.setUnaryMinus(builder)
.build())
} else {
withInfo(expr, child)
None
}
case a @ Coalesce(_) =>
val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding))
scalarFunctionExprToProto("coalesce", exprChildren: _*)
// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
// char types.
// See https://github.com/apache/spark/pull/38151
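// The StaticInvoke call below is mapped to the native `read_side_padding` scalar function.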
case s: StaticInvoke
// classOf gets the runtime class of T, which lets us compare directly
// Otherwise isInstanceOf[Class[T]] will always evaluate to true for Class[_]
if s.staticObject == classOf[CharVarcharCodegenUtils] &&
s.dataType.isInstanceOf[StringType] &&
s.functionName == "readSidePadding" &&
s.arguments.size == 2 &&
s.propagateNull &&
!s.returnNullable &&
s.isDeterministic =>
val argsExpr = Seq(
exprToProtoInternal(Cast(s.arguments(0), StringType), inputs, binding),
exprToProtoInternal(s.arguments(1), inputs, binding))
if (argsExpr.forall(_.isDefined)) {
scalarFunctionExprToProto("read_side_padding", argsExpr: _*)
} else {
withInfo(expr, s.arguments: _*)
None
}
// read-side padding in Spark 3.5.2+ is represented by the rpad function
case StringRPad(srcStr, size, chars) =>
chars match {
case Literal(str, DataTypes.StringType) if str.toString == " " =>
val arg0 = exprToProtoInternal(srcStr, inputs, binding)
val arg1 = exprToProtoInternal(size, inputs, binding)
if (arg0.isDefined && arg1.isDefined) {
scalarFunctionExprToProto("rpad", arg0, arg1)
} else {
withInfo(expr, "rpad unsupported arguments", srcStr, size)
None
}
case _ =>
withInfo(expr, "rpad only supports padding with spaces")
None
}
case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
val dataType = serializeDataType(expr.dataType)
if (dataType.isEmpty) {
withInfo(expr, s"Unsupported datatype ${expr.dataType}")
return None
}
val ex = exprToProtoInternal(expr, inputs, binding)
ex.map { child =>
val builder = ExprOuterClass.NormalizeNaNAndZero
.newBuilder()
.setChild(child)
.setDatatype(dataType.get)
ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
}
case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) =>
val dataType = serializeDataType(s.dataType)
if (dataType.isEmpty) {
withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}")
return None
}
val builder = ExprOuterClass.Subquery
.newBuilder()
.setId(s.exprId.id)
.setDatatype(dataType.get)
Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build())
case UnscaledValue(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("unscaled_value", LongType, childExpr)
optExprWithInfo(optExpr, expr, child)
case MakeDecimal(child, precision, scale, true) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr = scalarFunctionExprToProtoWithReturnType(
"make_decimal",
DecimalType(precision, scale),
childExpr)
optExprWithInfo(optExpr, expr, child)
case b @ BloomFilterMightContain(_, _) =>
val bloomFilter = b.left
val value = b.right
val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs, binding)
val valueExpr = exprToProtoInternal(value, inputs, binding)
if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
builder.setBloomFilter(bloomFilterExpr.get)
builder.setValue(valueExpr.get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setBloomFilterMightContain(builder)
.build())
} else {
withInfo(expr, bloomFilter, value)
None
}
case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)
case _: XxHash64 => CometXxHash64.convert(expr, inputs, binding)
case Sha2(left, numBits) =>
if (!numBits.foldable) {
withInfo(expr, "non literal numBits is not supported")
return None
}
// It's possible for Spark to dynamically compute the number of bits from the input
// expression, but DataFusion does not support that yet.
val childExpr = exprToProtoInternal(left, inputs, binding)
val bits = numBits.eval().asInstanceOf[Int]
val algorithm = bits match {
case 224 => "sha224"
case 256 | 0 => "sha256"
case 384 => "sha384"
case 512 => "sha512"
case _ =>
null
}
if (algorithm == null) {
exprToProtoInternal(Literal(null, StringType), inputs, binding)
} else {
scalarFunctionExprToProtoWithReturnType(algorithm, StringType, childExpr)
}
case struct @ CreateNamedStruct(_) =>
if (struct.names.length != struct.names.distinct.length) {
withInfo(expr, "CreateNamedStruct with duplicate field names are not supported")
return None
}
val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))
if (valExprs.forall(_.isDefined)) {
val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
structBuilder.addAllValues(valExprs.map(_.get).asJava)
structBuilder.addAllNames(struct.names.map(_.toString).asJava)
Some(
ExprOuterClass.Expr
.newBuilder()
.setCreateNamedStruct(structBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*)
None
}
case GetStructField(child, ordinal, _) =>
exprToProtoInternal(child, inputs, binding).map { childExpr =>
val getStructFieldBuilder = ExprOuterClass.GetStructField
.newBuilder()
.setChild(childExpr)
.setOrdinal(ordinal)
ExprOuterClass.Expr
.newBuilder()
.setGetStructField(getStructFieldBuilder)
.build()
}
case CreateArray(children, _) =>
val childExprs = children.map(exprToProtoInternal(_, inputs, binding))
if (childExprs.forall(_.isDefined)) {
scalarFunctionExprToProto("make_array", childExprs: _*)
} else {
withInfo(expr, "unsupported arguments for CreateArray", children: _*)
None
}
case GetArrayItem(child, ordinal, failOnError) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
if (childExpr.isDefined && ordinalExpr.isDefined) {
val listExtractBuilder = ExprOuterClass.ListExtract
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinalExpr.get)
.setOneBased(false)
.setFailOnError(failOnError)
Some(
ExprOuterClass.Expr
.newBuilder()
.setListExtract(listExtractBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
None
}
case expr if expr.prettyName == "array_insert" => convert(CometArrayInsert)
case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
if (childExpr.isDefined && ordinalExpr.isDefined &&
defaultExpr.isDefined == defaultValue.isDefined) {
val arrayExtractBuilder = ExprOuterClass.ListExtract
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinalExpr.get)
.setOneBased(true)
.setFailOnError(failOnError)
defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
Some(
ExprOuterClass.Expr
.newBuilder()
.setListExtract(arrayExtractBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
None
}
case GetArrayStructFields(child, _, ordinal, _, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
if (childExpr.isDefined) {
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinal)
Some(
ExprOuterClass.Expr
.newBuilder()
.setGetArrayStructFields(arrayStructFieldsBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}
case _: ArrayRemove => convert(CometArrayRemove)
case _: ArrayContains => convert(CometArrayContains)
case _: ArrayAppend => convert(CometArrayAppend)
case _: ArrayIntersect => convert(CometArrayIntersect)
case _: ArrayJoin => convert(CometArrayJoin)
case _: ArraysOverlap => convert(CometArraysOverlap)
case _: ArrayRepeat => convert(CometArrayRepeat)
case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
convert(CometArrayCompact)
case _: ArrayExcept =>
convert(CometArrayExcept)
case mk: MapKeys =>
val childExpr = exprToProtoInternal(mk.child, inputs, binding)
scalarFunctionExprToProto("map_keys", childExpr)
case mv: MapValues =>
val childExpr = exprToProtoInternal(mv.child, inputs, binding)
scalarFunctionExprToProto("map_values", childExpr)
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
}
}
/**
* Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then
* invokes the supplied function to wrap this UnaryExpr in a top-level Expr.
*
* @param child
* Spark expression
* @param inputs
* Inputs to the expression
* @param f
* Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level
* Expr
* @return
* Some(Expr) or None if not supported
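*
* For example, the `IsNull` case above is handled with:
* {{{
* createUnaryExpr(expr, child, inputs, binding,
*   (builder, unaryExpr) => builder.setIsNull(unaryExpr))
* }}}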
*/
def createUnaryExpr(
expr: Expression,
child: Expression,
inputs: Seq[Attribute],
binding: Boolean,
f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
: Option[ExprOuterClass.Expr] = {
val childExpr = exprToProtoInternal(child, inputs, binding) // TODO review
if (childExpr.isDefined) {
// create the generic UnaryExpr message
val inner = ExprOuterClass.UnaryExpr
.newBuilder()
.setChild(childExpr.get)
.build()
// call the user-supplied function to wrap UnaryExpr in a top-level Expr
// such as Expr.IsNull or Expr.IsNotNull
Some(
f(
ExprOuterClass.Expr
.newBuilder(),
inner).build())
} else {
withInfo(expr, child)
None
}
}
def createBinaryExpr(
expr: Expression,
left: Expression,
right: Expression,
inputs: Seq[Attribute],
binding: Boolean,
f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
: Option[ExprOuterClass.Expr] = {
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
if (leftExpr.isDefined && rightExpr.isDefined) {
// create the generic BinaryExpr message
val inner = ExprOuterClass.BinaryExpr
.newBuilder()
.setLeft(leftExpr.get)
.setRight(rightExpr.get)
.build()
// call the user-supplied function to wrap BinaryExpr in a top-level Expr
// such as Expr.And or Expr.Or
Some(
f(
ExprOuterClass.Expr
.newBuilder(),
inner).build())
} else {
withInfo(expr, left, right)
None
}
}
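// Illustrative sketch: createBinaryExpr follows the same pattern for two-child expressions.
// Assuming the generated builder exposes setAnd (the comment above mentions Expr.And and
// Expr.Or), a conjunction could be serialized as:
//
//   case And(left, right) =>
//     createBinaryExpr(
//       expr,
//       left,
//       right,
//       inputs,
//       binding,
//       (builder, binaryExpr) => builder.setAnd(binaryExpr))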
private def createMathExpression(
expr: Expression,
left: Expression,
right: Expression,
inputs: Seq[Attribute],
binding: Boolean,
dataType: DataType,
failOnError: Boolean,
f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
: Option[ExprOuterClass.Expr] = {
val leftExpr = exprToProtoInternal(left, inputs, binding)
val rightExpr = exprToProtoInternal(right, inputs, binding)
if (leftExpr.isDefined && rightExpr.isDefined) {
// create the generic MathExpr message
val builder = ExprOuterClass.MathExpr.newBuilder()
builder.setLeft(leftExpr.get)
builder.setRight(rightExpr.get)
builder.setFailOnError(failOnError)
serializeDataType(dataType).foreach { t =>
builder.setReturnType(t)
}
val inner = builder.build()
// call the user-supplied function to wrap MathExpr in a top-level Expr
// such as Expr.Add or Expr.Divide
Some(
f(
ExprOuterClass.Expr
.newBuilder(),
inner).build())
} else {
withInfo(expr, left, right)
None
}
}
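// Illustrative sketch: createMathExpression additionally serializes the result data type and
// the ANSI failOnError flag so that the native side can match Spark's overflow and
// divide-by-zero behavior. Assuming a setDivide setter exists on the generated Expr builder
// (the comment above mentions Expr.Add and Expr.Divide), a Divide expression `div` could be
// serialized as:
//
//   createMathExpression(
//     div,
//     div.left,
//     div.right,
//     inputs,
//     binding,
//     div.dataType,
//     failOnError = true, // e.g. under ANSI mode
//     (builder, mathExpr) => builder.setDivide(mathExpr))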
def in(
expr: Expression,
value: Expression,
list: Seq[Expression],
inputs: Seq[Attribute],
binding: Boolean,
negate: Boolean): Option[Expr] = {
val valueExpr = exprToProtoInternal(value, inputs, binding)
val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
val builder = ExprOuterClass.In.newBuilder()
builder.setInValue(valueExpr.get)
builder.addAllLists(listExprs.map(_.get).asJava)
builder.setNegated(negate)
Some(
ExprOuterClass.Expr
.newBuilder()
.setIn(builder)
.build())
} else {
val allExprs = list ++ Seq(value)
withInfo(expr, allExprs: _*)
None
}
}
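// Illustrative sketch: Spark's In and Not(In(...)) map onto the same protobuf message, with
// only the negated flag differing. For a predicate such as `col IN (1, 2, 3)` the two forms
// would be serialized roughly as:
//
//   in(expr, value, list, inputs, binding, negate = false) // col IN (1, 2, 3)
//   in(expr, value, list, inputs, binding, negate = true)  // col NOT IN (1, 2, 3)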
def scalarFunctionExprToProtoWithReturnType(
funcName: String,
returnType: DataType,
args: Option[Expr]*): Option[Expr] = {
val builder = ExprOuterClass.ScalarFunc.newBuilder()
builder.setFunc(funcName)
serializeDataType(returnType).flatMap { t =>
builder.setReturnType(t)
scalarFunctionExprToProto0(builder, args: _*)
}
}
def scalarFunctionExprToProto(funcName: String, args: Option[Expr]*): Option[Expr] = {
val builder = ExprOuterClass.ScalarFunc.newBuilder()
builder.setFunc(funcName)
scalarFunctionExprToProto0(builder, args: _*)
}
private def scalarFunctionExprToProto0(
builder: ScalarFunc.Builder,
args: Option[Expr]*): Option[Expr] = {
args.foreach {
case Some(a) => builder.addArgs(a)
case _ =>
return None
}
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
}
private def isPrimitive(expression: Expression): Boolean = expression.dataType match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: TimestampType | _: DateType | _: BooleanType | _: DecimalType =>
true
case _ => false
}
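// Rewrites a primitive-typed expression so that a zero value evaluates to null instead
// (e.g. to reproduce Spark's divide-by-zero semantics on the native side). Non-primitive
// expressions and non-zero literals are returned unchanged.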
private def nullIfWhenPrimitive(expression: Expression): Expression =
if (isPrimitive(expression)) {
val zero = Literal.default(expression.dataType)
expression match {
case _: Literal if expression != zero => expression
case _ =>
If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression)
}
} else {
expression
}
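// Wraps the expression in an If so that values less than or equal to zero evaluate to null.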
private def nullIfNegative(expression: Expression): Expression = {
val zero = Literal.default(expression.dataType)
If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
}
/**
* Returns true if the given data type is supported as a key in a DataFusion sort-merge join.
*/
private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
true
case TimestampNTZType => true
case _ => false
}
/**
* Convert a Spark plan operator to a protobuf Comet operator.
*
* @param op
* Spark plan operator
* @param childOp
* previously converted protobuf Comet operators, which will be consumed by the Spark plan
* operator as its children
* @return
* The converted Comet native operator for the input `op`, or `None` if the `op` cannot be
* converted to a native operator.
*/
def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
val conf = op.conf
val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
childOp.foreach(result.addChildren)
op match {
// Fully native scan for V1
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
nativeScanBuilder.setSource(op.simpleStringWithNodeId())
val scanTypes = op.output.flatten { attr =>
serializeDataType(attr.dataType)
}
if (scanTypes.length == op.output.length) {
nativeScanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
result.clearChildren()
if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED)) {
// TODO remove flatMap and add error handling for unsupported data filters
val dataFilters = scan.dataFilters.flatMap(exprToProto(_, scan.output))
nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
}
val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema)
if (possibleDefaultValues.exists(_ != null)) {
// Our schema has default values. Serialize two lists, one with the default values
// and another with the indexes in the schema so the native side can map missing
// columns to these default values.
val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex
.filter { case (expr, _) => expr != null }
.map { case (expr, index) =>
// ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these
// expressions and they should now just be literals.
(Literal(expr), index.toLong.asInstanceOf[java.lang.Long])
}
.unzip
nativeScanBuilder.addAllDefaultValues(
defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava)
nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
}
// TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
var firstPartition: Option[PartitionedFile] = None
scan.inputRDD match {
case rdd: DataSourceRDD =>
val partitions = rdd.partitions
partitions.foreach(p => {
val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
inputPartitions.foreach(partition => {
if (firstPartition.isEmpty) {
firstPartition = partition.asInstanceOf[FilePartition].files.headOption
}
partition2Proto(
partition.asInstanceOf[FilePartition],
nativeScanBuilder,
scan.relation.partitionSchema)
})
})
case rdd: FileScanRDD =>
rdd.filePartitions.foreach(partition => {
if (firstPartition.isEmpty) {
firstPartition = partition.files.headOption
}
partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
})
case _ =>
}
val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
val requiredSchema = schema2Proto(scan.requiredSchema.fields)
val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
val dataSchemaIndexes = scan.requiredSchema.fields.map(field => {
scan.relation.dataSchema.fieldIndex(field.name)
})
val partitionSchemaIndexes = Array
.range(
scan.relation.dataSchema.fields.length,
scan.relation.dataSchema.fields.length + scan.relation.partitionSchema.fields.length)
val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx =>
idx.toLong.asInstanceOf[java.lang.Long])
nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
// In `CometScanRule`, we ensure partitionSchema is supported.
assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone"))
nativeScanBuilder.setCaseSensitive(conf.getConf[Boolean](SQLConf.CASE_SENSITIVE))
// Collect S3/cloud storage configurations
val hadoopConf = scan.relation.sparkSession.sessionState
.newHadoopConfWithOptions(scan.relation.options)
firstPartition.foreach { partitionFile =>
val objectStoreOptions =
NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri)
objectStoreOptions.foreach { case (key, value) =>
nativeScanBuilder.putObjectStoreOptions(key, value)
}
}
Some(result.setNativeScan(nativeScanBuilder).build())
} else {
// There are unsupported scan types
val msg =
s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
emitWarning(msg)
withInfo(op, msg)
None
}
case ProjectExec(projectList, child) if CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) =>
val exprs = projectList.map(exprToProto(_, child.output))
if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
val projectBuilder = OperatorOuterClass.Projection
.newBuilder()
.addAllProjectList(exprs.map(_.get).asJava)
Some(result.setProjection(projectBuilder).build())
} else {
withInfo(op, projectList: _*)
None
}
case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
val cond = exprToProto(condition, child.output)
if (cond.isDefined && childOp.nonEmpty) {
// We need to determine whether to use DataFusion's FilterExec or Comet's
// FilterExec. The difference is that DataFusion's implementation will sometimes pass
// batches through whereas the Comet implementation guarantees that a copy is always
// made, which is critical when using `native_comet` scans due to buffer re-use
// TODO this could be optimized more to stop walking the tree on hitting
// certain operators such as join or aggregate which will copy batches
def containsNativeCometScan(plan: SparkPlan): Boolean = {
plan match {
case w: CometScanWrapper => containsNativeCometScan(w.originalPlan)
case scan: CometScanExec => scan.scanImpl == CometConf.SCAN_NATIVE_COMET
case _: CometNativeScanExec => false
case _ => plan.children.exists(containsNativeCometScan)
}
}
val filterBuilder = OperatorOuterClass.Filter
.newBuilder()
.setPredicate(cond.get)
.setUseDatafusionFilter(!containsNativeCometScan(op))
Some(result.setFilter(filterBuilder).build())
} else {
withInfo(op, condition, child)
None
}
case SortExec(sortOrder, _, child, _) if CometConf.COMET_EXEC_SORT_ENABLED.get(conf) =>
if (!supportedSortType(op, sortOrder)) {
return None
}
val sortOrders = sortOrder.map(exprToProto(_, child.output))
if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
val sortBuilder = OperatorOuterClass.Sort
.newBuilder()
.addAllSortOrders(sortOrders.map(_.get).asJava)
Some(result.setSort(sortBuilder).build())
} else {
withInfo(op, "sort order not supported", sortOrder: _*)
None
}
case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
if (childOp.nonEmpty) {
// LocalLimit doesn't use offset, but it shares the same operator serde class.
// Just set the offset to zero.
val limitBuilder = OperatorOuterClass.Limit
.newBuilder()
.setLimit(limit)
.setOffset(0)
Some(result.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
}
case globalLimitExec: GlobalLimitExec
if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) =>
// TODO: We don't support negative limit for now.
if (childOp.nonEmpty && globalLimitExec.limit >= 0) {
val limitBuilder = OperatorOuterClass.Limit.newBuilder()
// TODO: Spark 3.3 might have negative limit (-1) for Offset usage.
// When we upgrade to Spark 3.3, we need to address it here.
limitBuilder.setLimit(globalLimitExec.limit)
Some(result.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
}
case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) =>
var allProjExprs: Seq[Expression] = Seq()
val projExprs = projections.flatMap(_.map(e => {
allProjExprs = allProjExprs :+ e
exprToProto(e, child.output)
}))
if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
val expandBuilder = OperatorOuterClass.Expand
.newBuilder()
.addAllProjectList(projExprs.map(_.get).asJava)
.setNumExprPerProject(projections.head.size)
Some(result.setExpand(expandBuilder).build())
} else {
withInfo(op, allProjExprs: _*)
None
}
case WindowExec(windowExpression, partitionSpec, orderSpec, child)
if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
val output = child.output
val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
expr match {
case alias: Alias =>
alias.child match {
case winExpr: WindowExpression =>
Some(winExpr)
case _ =>
None
}
case _ =>
None
}
}.toArray
if (winExprs.length != windowExpression.length) {
withInfo(op, "Unsupported window expression(s)")
return None
}
if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}
val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))
val sortOrders = orderSpec.map(exprToProto(_, child.output))
if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined)
&& sortOrders.forall(_.isDefined)) {
val windowBuilder = OperatorOuterClass.Window.newBuilder()
windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
Some(result.setWindow(windowBuilder).build())
} else {
None
}
case aggregate: BaseAggregateExec
if (aggregate.isInstanceOf[HashAggregateExec] ||
aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
val groupingExpressions = aggregate.groupingExpressions
val aggregateExpressions = aggregate.aggregateExpressions
val aggregateAttributes = aggregate.aggregateAttributes
val resultExpressions = aggregate.resultExpressions
val child = aggregate.child
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
withInfo(op, "No group by or aggregation")
return None
}
// Aggregate expressions with filter are not supported yet.
if (aggregateExpressions.exists(_.filter.isDefined)) {
withInfo(op, "Aggregate expression with filter is not supported")
return None
}
if (groupingExpressions.exists(expr =>
expr.dataType match {
case _: MapType => true
case _ => false
})) {
withInfo(op, "Grouping on map types is not supported")
return None
}
val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))
if (groupingExprs.exists(_.isEmpty)) {
withInfo(op, "Not all grouping expressions are supported")
return None
}
// In some cases the aggregateExpressions can be empty, for example when the query only
// performs a group by, or when it only uses distinct aggregate functions:
//
// SELECT COUNT(distinct col2), col1 FROM test group by col1
// +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
// +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
// +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
// If the aggregateExpressions is empty, we only want to build groupingExpressions,
// and skip processing of aggregateExpressions.
if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
val msg = s"Unsupported result expressions found in: ${resultExpressions}"
emitWarning(msg)
withInfo(op, msg, resultExpressions: _*)
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
val modes = aggregateExpressions.map(_.mode).distinct
if (modes.size != 1) {
// This shouldn't happen as all aggregate expressions should share the same mode.
// Fall back to Spark nevertheless.
withInfo(op, "Aggregate expressions do not all share the same mode")
return None
}
val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ =>
withInfo(op, s"Unsupported aggregation mode ${modes.head}")
return None
}
// In final mode, Spark binds the aggregate expressions to the output of the child plus
// the aggregate buffer attributes produced by the partial aggregation. This is done
// internally in Spark's `HashAggregateExec`. In Comet, we don't have to do this because
// we don't use the merging expression.
val binding = mode != CometAggregateMode.Final
// `output` is only used when `binding` is true (i.e., non-Final)
val output = child.output
val aggExprs =
aggregateExpressions.map(aggExprToProto(_, output, binding, op.conf))
if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
val msg = s"Unsupported result expressions found in: ${resultExpressions}"
emitWarning(msg)
withInfo(op, msg, resultExpressions: _*)
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
val allChildren: Seq[Expression] =
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
withInfo(op, allChildren: _*)
None
}
}
case join: HashJoin =>
// `HashJoin` has only two implementations in Spark, but we check the type of the join to
// make sure we are handling the correct join type.
if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
join.isInstanceOf[ShuffledHashJoinExec]) &&
!(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
join.isInstanceOf[BroadcastHashJoinExec])) {
withInfo(join, s"Invalid hash join type ${join.nodeName}")
return None
}
if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
withInfo(join, "BuildRight with LeftAnti is not supported")
return None
}
val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
withInfo(join, cond)
return None
}
condProto.get
}
val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(join, s"Unsupported join type ${join.joinType}")
return None
}
val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
if (leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.HashJoin
.newBuilder()
.setJoinType(joinType)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
.setBuildSide(
if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight)
condition.foreach(joinBuilder.setCondition)
Some(result.setHashJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
None
}
case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
keys.map(SortOrder(_, Ascending))
}
def getKeyOrdering(
keys: Seq[Expression],
childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
val requiredOrdering = requiredOrders(keys)
if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
}
} else {
requiredOrdering
}
}
if (join.condition.isDefined &&
!CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED
.get(conf)) {
withInfo(
join,
s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled",
join.condition.get)
return None
}
val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
withInfo(join, cond)
return None
}
condProto.get
}
val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(op, s"Unsupported join type ${join.joinType}")
return None
}
// Checks if the join keys are supported by DataFusion SortMergeJoin.
val errorMsgs = join.leftKeys.flatMap { key =>
if (!supportedSortMergeJoinEqualType(key.dataType)) {
Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
} else {
None
}
}
if (errorMsgs.nonEmpty) {
withInfo(op, errorMsgs.mkString("\n"))
return None
}
val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering)
.map(exprToProto(_, join.left.output))
if (sortOptions.forall(_.isDefined) &&
leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.SortMergeJoin
.newBuilder()
.setJoinType(joinType)
.addAllSortOptions(sortOptions.map(_.get).asJava)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
condition.foreach(joinBuilder.setCondition)
Some(result.setSortMergeJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
None
}
case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
withInfo(join, "SortMergeJoin is not enabled")
None
case op if isCometSink(op) =>
val supportedTypes =
op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
if (!supportedTypes) {
return None
}
// These operators are the source of a Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
val source = op.simpleStringWithNodeId()
if (source.isEmpty) {
scanBuilder.setSource(op.getClass.getSimpleName)
} else {
scanBuilder.setSource(source)
}
val scanTypes = op.output.flatten { attr =>
serializeDataType(attr.dataType)
}
if (scanTypes.length == op.output.length) {
scanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
result.clearChildren()
Some(result.setScan(scanBuilder).build())
} else {
// There are unsupported scan types
val msg =
s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
emitWarning(msg)
withInfo(op, msg)
None
}
case op =>
// Emit warning if:
// 1. it is not a Spark shuffle operator, which is handled separately
// 2. it is not a Comet operator
if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) {
val msg = s"unsupported Spark operator: ${op.nodeName}"
emitWarning(msg)
withInfo(op, msg)
}
None
}
}
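// Illustrative sketch of the calling convention (scanExec and filterExec are hypothetical plan
// nodes used only for illustration): operator2Proto is driven bottom-up, with the
// already-converted children passed in as varargs so that they are attached as children of the
// new native operator:
//
//   val scanProto = QueryPlanSerde.operator2Proto(scanExec)
//   val filterProto = scanProto.flatMap(childProto =>
//     QueryPlanSerde.operator2Proto(filterExec, childProto))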
/**
* Whether the input Spark operator `op` can be considered a Comet sink, i.e., the start of
* native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or
* `CometSinkPlaceHolder` later in `CometSparkSessionExtensions` after `operator2Proto` is
* called.
*/
private def isCometSink(op: SparkPlan): Boolean = {
op match {
case s if isCometScan(s) => true
case _: CometSparkToColumnarExec => true
case _: CometSinkPlaceHolder => true
case _: CoalesceExec => true
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
true
case _: BroadcastExchangeExec => true
case _: WindowExec => true
case _ => false
}
}
/**
* Checks whether the data types of the shuffle input are supported. This is used for columnar
* shuffle, which supports struct and array types.
*/
def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
val inputs = s.child.output
val partitioning = s.outputPartitioning
var msg = ""
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
// columnar shuffle supports the same data types (including complex types) both for
// partition keys and for other columns
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $expressions"
}
supported
case SinglePartition =>
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
case RoundRobinPartitioning(_) =>
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
case RangePartitioning(orderings, _) =>
val supported =
orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $orderings"
}
supported
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}
if (!supported) {
emitWarning(msg)
(false, msg)
} else {
(true, null)
}
}
/**
* Whether the given Spark partitioning is supported by Comet native shuffle.
*/
def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
/**
* Determines which data types are supported as hash-partition keys in native shuffle.
*
* The hash partition key determines how data is co-located for operations like `groupByKey`,
* `reduceByKey`, or `join`.
*/
def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
_: TimestampNTZType | _: DecimalType | _: DateType =>
true
case _ =>
false
}
val inputs = s.child.output
val partitioning = s.outputPartitioning
val conf = SQLConf.get
var msg = ""
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
// native shuffle currently does not support complex types as partition keys
// due to lack of hashing support for those types
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
if (!supported) {
msg = s"unsupported Spark partitioning: $expressions"
}
supported
case SinglePartition =>
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
case RangePartitioning(ordering, _) =>
val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
if (!supported) {
msg = s"unsupported Spark partitioning: $ordering"
}
supported
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}
if (!supported) {
emitWarning(msg)
(false, msg)
} else {
(true, null)
}
}
/**
* Determine which data types are supported in a shuffle.
*/
def supportedShuffleDataType(dt: DataType): Boolean = dt match {
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
_: TimestampNTZType | _: DecimalType | _: DateType =>
true
case StructType(fields) =>
fields.forall(f => supportedShuffleDataType(f.dataType)) &&
// The Java Arrow stream reader cannot handle duplicate field names
fields.map(f => f.name).distinct.length == fields.length
case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
case ArrayType(elementType, _) =>
supportedShuffleDataType(elementType)
case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
case MapType(_, MapType(_, _, _), _) => false
case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
case MapType(_, StructType(_), _) => false
case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
case MapType(_, ArrayType(_, _), _) => false
case MapType(keyType, valueType, _) =>
supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
case _ =>
false
}
// Utility method. Adds explain info if the result of calling exprToProto is None
def optExprWithInfo(
optExpr: Option[Expr],
expr: Expression,
childExpr: Expression*): Option[Expr] = {
optExpr match {
case None =>
withInfo(expr, childExpr: _*)
None
case o => o
}
}
// TODO: Remove this constraint when we upgrade to new arrow-rs including
// https://github.com/apache/arrow-rs/pull/6225
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
def canRank(dt: DataType): Boolean = {
dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: TimestampType | _: DecimalType | _: DateType =>
true
case _: BinaryType | _: StringType => true
case _ => false
}
}
if (sortOrder.length == 1) {
val canSort = sortOrder.head.dataType match {
case _: BooleanType => true
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: TimestampType | _: TimestampNTZType | _: DecimalType |
_: DateType =>
true
case _: BinaryType | _: StringType => true
case ArrayType(elementType, _) => canRank(elementType)
case _ => false
}
if (!canSort) {
withInfo(op, s"Sort on single column of type ${sortOrder.head.dataType} is not supported")
false
} else {
true
}
} else {
true
}
}
private def validatePartitionAndSortSpecsForWindowFunc(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
op: SparkPlan): Boolean = {
if (partitionSpec.length != orderSpec.length) {
return false
}
val partitionColumnNames = partitionSpec.collect {
case a: AttributeReference => a.name
case other =>
withInfo(op, s"Unsupported partition expression: ${other.getClass.getSimpleName}")
return false
}
val orderColumnNames = orderSpec.collect { case s: SortOrder =>
s.child match {
case a: AttributeReference => a.name
case other =>
withInfo(op, s"Unsupported sort expression: ${other.getClass.getSimpleName}")
return false
}
}
if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) =>
partCol != orderCol
}) {
withInfo(op, "Partitioning and sorting specifications must be the same.")
return false
}
true
}
private def schema2Proto(
fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
fields.map(field => {
fieldBuilder.setName(field.name)
fieldBuilder.setDataType(serializeDataType(field.dataType).get)
fieldBuilder.setNullable(field.nullable)
fieldBuilder.build()
})
}
private def partition2Proto(
partition: FilePartition,
nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
partitionSchema: StructType): Unit = {
val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
partition.files.foreach(file => {
// Process the partition values
val partitionValues = file.partitionValues
assert(partitionValues.numFields == partitionSchema.length)
val partitionVals =
partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) =>
val attr = partitionSchema(i)
val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty)
// In `CometScanRule`, we have already checked that all partition values are
// supported. So, we can safely use `get` here.
assert(
valueProto.isDefined,
s"Unsupported partition value: $value, type: ${attr.dataType}")
valueProto.get
}
val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
partitionVals.foreach(fileBuilder.addPartitionValues)
fileBuilder
.setFilePath(file.filePath.toString)
.setStart(file.start)
.setLength(file.length)
.setFileSize(file.fileSize)
partitionBuilder.addPartitionedFile(fileBuilder.build())
})
nativeScanBuilder.addFilePartitions(partitionBuilder.build())
}
}
/**
* Trait for providing serialization logic for expressions.
*/
trait CometExpressionSerde {
/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
*
* @param expr
* The Spark expression.
* @param inputs
* The input attributes.
* @param binding
* Whether the attributes are bound (this is only relevant in aggregate expressions).
* @return
* Protocol buffer representation, or None if the expression could not be converted. In this
* case it is expected that the input expression will have been tagged with reasons why it
* could not be converted.
*/
def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr]
}
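// Illustrative sketch of a CometExpressionSerde implementation. The expression type and the
// native function name ("reverse") are assumptions used only to show the expected shape:
// convert the children, tag the expression via withInfo and return None when a child is not
// supported, otherwise emit the protobuf Expr (helpers such as exprToProtoInternal and
// scalarFunctionExprToProto are assumed to be accessible from the implementing object).
//
//   object CometReverseExample extends CometExpressionSerde {
//     override def convert(
//         expr: Expression,
//         inputs: Seq[Attribute],
//         binding: Boolean): Option[ExprOuterClass.Expr] = {
//       val childExpr = QueryPlanSerde.exprToProtoInternal(expr.children.head, inputs, binding)
//       if (childExpr.isEmpty) {
//         withInfo(expr, expr.children.head)
//         return None
//       }
//       QueryPlanSerde.scalarFunctionExprToProto("reverse", childExpr)
//     }
//   }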
/**
* Trait for providing serialization logic for aggregate expressions.
*/
trait CometAggregateExpressionSerde {
/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
*
* @param aggExpr
* The aggregate expression.
* @param expr
* The aggregate function.
* @param inputs
* The input attributes.
* @param binding
* Whether the attributes are bound (this is only relevant in aggregate expressions).
* @param conf
* SQLConf
* @return
* Protocol buffer representation, or None if the expression could not be converted. In this
* case it is expected that the input expression will have been tagged with reasons why it
* could not be converted.
*/
def convert(
aggExpr: AggregateExpression,
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean,
conf: SQLConf): Option[ExprOuterClass.AggExpr]
}
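// Illustrative sketch of a CometAggregateExpressionSerde implementation, following the same
// fallback convention. The aggregate function and the protobuf builder calls (a Count message
// attached via setCount) are assumptions used only to show the expected shape:
//
//   object CometCountExample extends CometAggregateExpressionSerde {
//     override def convert(
//         aggExpr: AggregateExpression,
//         expr: Expression,
//         inputs: Seq[Attribute],
//         binding: Boolean,
//         conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
//       val childExprs = expr.children.map(QueryPlanSerde.exprToProtoInternal(_, inputs, binding))
//       if (!childExprs.forall(_.isDefined)) {
//         withInfo(aggExpr, expr.children: _*)
//         return None
//       }
//       val countBuilder = ExprOuterClass.Count.newBuilder()
//       countBuilder.addAllChildren(childExprs.map(_.get).asJava)
//       Some(ExprOuterClass.AggExpr.newBuilder().setCount(countBuilder).build())
//     }
//   }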
/** Marker trait for an expression that is not guaranteed to be 100% compatible with Spark */
trait IncompatExpr {}