[GLUTEN-7359][VL] feat: Support columnar partial project for UDF (#7360)
Closes #7359
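The rule splits a ProjectExec that contains non-offload-able UDFs into an offloaded ProjectExecTransformer on top of a new ColumnarPartialProjectExec that evaluates only the UDFs with vanilla Spark. As an illustrative sketch (not part of this patch; plus_one is the UDF registered in the new UDFPartialProjectSuite), the expected plan shape is:

    // Registered as in UDFPartialProjectSuite: spark.udf.register("plus_one", udf((x: Long) => x + 1))
    val df = spark.sql(
      "SELECT sum(plus_one(cast(l_orderkey AS long)) + hash(l_partkey)) FROM lineitem")
    // Expected executed-plan fragment after PartialProjectRule:
    //   ProjectExecTransformer [(_SparkPartialProject1 + hash(l_partkey)) ...]
    //     +- ColumnarPartialProjectExec [l_orderkey, l_partkey, ..., plus_one(...) AS _SparkPartialProject1]
    //        +- <scan lineitem>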
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index d0106b4..65bf784 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -66,6 +66,7 @@
injector.injectTransform(_ => RewriteSparkPlanRulesManager())
injector.injectTransform(_ => AddFallbackTagRule())
injector.injectTransform(_ => TransformPreOverrides())
+ injector.injectTransform(c => PartialProjectRule.apply(c.session))
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c => RewriteTransformer.apply(c.session))
injector.injectTransform(_ => PushDownFilterToScan)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala
index 2778710..c4684e5 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowConvertorRule.scala
@@ -32,7 +32,7 @@
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.csv.CSVTable
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.utils.SparkSchemaUtil
+import org.apache.spark.sql.utils.SparkArrowUtil
import java.nio.charset.StandardCharsets
@@ -87,7 +87,7 @@
options,
columnPruning = session.sessionState.conf.csvColumnPruning,
session.sessionState.conf.sessionLocalTimeZone)
- checkSchema(dataSchema) &&
+ SparkArrowUtil.checkSchema(dataSchema) &&
checkCsvOptions(csvOptions, session.sessionState.conf.sessionLocalTimeZone) &&
dataSchema.nonEmpty
}
@@ -106,13 +106,4 @@
SparkShimLoader.getSparkShims.dateTimestampFormatInReadIsDefaultValue(csvOptions, timeZone)
}
- private def checkSchema(schema: StructType): Boolean = {
- try {
- SparkSchemaUtil.toArrowSchema(schema)
- true
- } catch {
- case _: Exception =>
- false
- }
- }
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
new file mode 100644
index 0000000..d42b7ee
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.GlutenConfig
+import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
+import org.apache.gluten.expression.ExpressionUtils
+import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
+import org.apache.gluten.iterator.Iterators
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
+import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.vectorized.ArrowWritableColumnVector
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, MutableProjection, NamedExpression, NaNvl, ScalaUDF, UnsafeProjection}
+import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, WritableColumnVector}
+import org.apache.spark.sql.hive.HiveUdfUtil
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, TimestampType, YearMonthIntervalType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+import scala.collection.mutable.ListBuffer
+
+/**
+ * By rule <PartialProjectRule>, a project that is not offload-able is split into
+ * ProjectExecTransformer + ColumnarPartialProjectExec. E.g. for sum(myudf(a) + b + hash(c)) with
+ * child output (a, b, c), ColumnarPartialProjectExec outputs (a, b, c, myudf(a) as
+ * _SparkPartialProject1) and ProjectExecTransformer computes (_SparkPartialProject1 + b + hash(c))
+ *
+ * @param original
+ * the original ProjectExec; the ScalaUDFs extracted from its project list appear as Alias in
+ * the UnsafeProjection and as AttributeReference in the ColumnarPartialProjectExec output
+ * @param child
+ * child plan
+ */
+case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
+ replacedAliasUdf: Seq[Alias])
+ extends UnaryExecNode
+ with GlutenPlan {
+
+ private val projectAttributes: ListBuffer[Attribute] = ListBuffer()
+ private val projectIndexInChild: ListBuffer[Int] = ListBuffer()
+ private var UDFAttrNotExists = false
+ private var hasUnsupportedDataType = replacedAliasUdf.exists(a => !validateDataType(a.dataType))
+ if (!hasUnsupportedDataType) {
+ getProjectIndexInChildOutput(replacedAliasUdf)
+ }
+
+ @transient override lazy val metrics = Map(
+ "time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of partial project"),
+ "column_to_row_time" -> SQLMetrics.createTimingMetric(
+ sparkContext,
+ "time of velox to Arrow ColumnarBatch or UnsafeRow"),
+ "row_to_column_time" -> SQLMetrics.createTimingMetric(
+ sparkContext,
+ "time of Arrow ColumnarBatch or UnsafeRow to velox")
+ )
+
+ override def output: Seq[Attribute] = child.output ++ replacedAliasUdf.map(_.toAttribute)
+
+ final override def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} doesn't support doExecute")
+ }
+
+ final override protected def otherCopyArgs: Seq[AnyRef] = {
+ replacedAliasUdf :: Nil
+ }
+
+ final override val supportsColumnar: Boolean = true
+
+ private def validateExpression(expr: Expression): Boolean = {
+ expr.deterministic && !expr.isInstanceOf[LambdaFunction] && expr.children
+ .forall(validateExpression)
+ }
+
+ private def validateDataType(dataType: DataType): Boolean = {
+ dataType match {
+ case _: BooleanType => true
+ case _: ByteType => true
+ case _: ShortType => true
+ case _: IntegerType => true
+ case _: LongType => true
+ case _: FloatType => true
+ case _: DoubleType => true
+ case _: StringType => true
+ case _: TimestampType => true
+ case _: DateType => true
+ case _: BinaryType => true
+ case _: DecimalType => true
+ case YearMonthIntervalType.DEFAULT => true
+ case _: NullType => true
+ case _ => false
+ }
+ }
+
+ private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
+ exprs.foreach {
+ case a: AttributeReference =>
+ val index = child.output.indexWhere(s => s.exprId.equals(a.exprId))
+ // Some child operators, e.g. HashAggregateTransformer, may not output the UDF's child column
+ if (index < 0) {
+ UDFAttrNotExists = true
+ log.debug(s"Expression $a should exist in child output ${child.output}")
+ return
+ } else if (!validateDataType(a.dataType)) {
+ hasUnsupportedDataType = true
+ log.debug(s"Expression $a contains unsupported data type ${a.dataType}")
+ } else if (!projectIndexInChild.contains(index)) {
+ projectAttributes.append(a.toAttribute)
+ projectIndexInChild.append(index)
+ }
+ case p => getProjectIndexInChildOutput(p.children)
+ }
+ }
+
+ override protected def doValidateInternal(): ValidationResult = {
+ if (!GlutenConfig.getConf.enableColumnarPartialProject) {
+ return ValidationResult.failed("Config disable this feature")
+ }
+ if (UDFAttrNotExists) {
+ return ValidationResult.failed("Attribute in the UDF does not exists in its child")
+ }
+ if (hasUnsupportedDataType) {
+ return ValidationResult.failed("Attribute in the UDF contains unsupported type")
+ }
+ if (projectAttributes.size == child.output.size) {
+ return ValidationResult.failed("UDF need all the columns in child output")
+ }
+ if (original.output.isEmpty) {
+ return ValidationResult.failed("Project fallback because output is empty")
+ }
+ if (replacedAliasUdf.isEmpty) {
+ return ValidationResult.failed("No UDF")
+ }
+ if (replacedAliasUdf.size > original.output.size) {
+ // e.g. udf1(col) + udf2(col) introduces 2 columns for r2c
+ return ValidationResult.failed("Number of RowToColumn columns is more than ProjectExec")
+ }
+ if (!original.projectList.forall(validateExpression(_))) {
+ return ValidationResult.failed("Contains expression not supported")
+ }
+ if (
+ ExpressionUtils.isComplexExpression(
+ original,
+ GlutenConfig.getConf.fallbackExpressionsThreshold)
+ ) {
+ return ValidationResult.failed("Fallback by complex expression")
+ }
+ ValidationResult.succeeded
+ }
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val totalTime = longMetric("time")
+ val c2r = longMetric("column_to_row_time")
+ val r2c = longMetric("row_to_column_time")
+ val isMutable = canUseMutableProjection()
+ child.executeColumnar().mapPartitions {
+ batches =>
+ val res: Iterator[Iterator[ColumnarBatch]] = new Iterator[Iterator[ColumnarBatch]] {
+ override def hasNext: Boolean = batches.hasNext
+
+ override def next(): Iterator[ColumnarBatch] = {
+ val batch = batches.next()
+ if (batch.numRows == 0) {
+ Iterator.empty
+ } else {
+ val start = System.currentTimeMillis()
+ val childData = ColumnarBatches.select(batch, projectIndexInChild.toArray)
+ val projectedBatch = if (isMutable) {
+ getProjectedBatchArrow(childData, c2r, r2c)
+ } else getProjectedBatch(childData, c2r, r2c)
+ val batchIterator = projectedBatch.map {
+ b =>
+ if (b.numCols() != 0) {
+ val compositeBatch = VeloxColumnarBatches.compose(batch, b)
+ b.close()
+ compositeBatch
+ } else {
+ b.close()
+ ColumnarBatches.retain(batch)
+ batch
+ }
+ }
+ childData.close()
+ totalTime += System.currentTimeMillis() - start
+ batchIterator
+ }
+ }
+ }
+ Iterators
+ .wrap(res.flatten)
+ .protectInvocationFlow() // Spark may call `hasNext()` again after a false output which
+ // is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator
+ .recyclePayload(_.close())
+ .create()
+
+ }
+ }
+
+ // scalastyle:off line.size.limit
+ // String type cannot use MutableProjection
+ // Otherwise will throw java.lang.UnsupportedOperationException: Datatype not supported StringType
+ // at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.update(MutableColumnarRow.java:224)
+ // at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
+ // scalastyle:on line.size.limit
+ private def canUseMutableProjection(): Boolean = {
+ replacedAliasUdf.forall(
+ r =>
+ r.dataType match {
+ case StringType | BinaryType => false
+ case _ => true
+ })
+ }
+
+ /**
+ * Add c2r and r2c around the unsupported expressions: c2r converts the child data into
+ * Iterator[InternalRow], the Spark projection is applied, then r2c converts the result back
+ */
+ private def getProjectedBatch(
+ childData: ColumnarBatch,
+ c2r: SQLMetric,
+ r2c: SQLMetric): Iterator[ColumnarBatch] = {
+ // select part of child output and child data
+ val proj = UnsafeProjection.create(replacedAliasUdf, projectAttributes.toSeq)
+ val numOutputRows = new SQLMetric("numOutputRows")
+ val numInputBatches = new SQLMetric("numInputBatches")
+ val rows = VeloxColumnarToRowExec
+ .toRowIterator(
+ Iterator.single[ColumnarBatch](childData),
+ projectAttributes.toSeq,
+ numOutputRows,
+ numInputBatches,
+ c2r)
+ .map(proj)
+
+ val schema =
+ SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
+ RowToVeloxColumnarExec.toColumnarBatchIterator(
+ rows,
+ schema,
+ numOutputRows,
+ numInputBatches,
+ r2c,
+ childData.numRows())
+ // TODO: should check that the size is <= 1, but a current bug turns the iterator into an empty one
+ }
+
+ private def getProjectedBatchArrow(
+ childData: ColumnarBatch,
+ c2a: SQLMetric,
+ a2c: SQLMetric): Iterator[ColumnarBatch] = {
+ // select part of child output and child data
+ val proj = MutableProjection.create(replacedAliasUdf, projectAttributes.toSeq)
+ val numRows = childData.numRows()
+ val start = System.currentTimeMillis()
+ val arrowBatch = if (childData.numCols() == 0 || ColumnarBatches.isHeavyBatch(childData)) {
+ childData
+ } else ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), childData)
+ c2a += System.currentTimeMillis() - start
+
+ val schema =
+ SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
+ val vectors: Array[WritableColumnVector] = ArrowWritableColumnVector
+ .allocateColumns(numRows, schema)
+ .map {
+ vector =>
+ vector.setValueCount(numRows)
+ vector
+ }
+ val targetRow = new MutableColumnarRow(vectors)
+ for (i <- 0 until numRows) {
+ targetRow.rowId = i
+ proj.target(targetRow).apply(arrowBatch.getRow(i))
+ }
+ val targetBatch = new ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), numRows)
+ val start2 = System.currentTimeMillis()
+ val veloxBatch = VeloxColumnarBatches.toVeloxBatch(
+ ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), targetBatch))
+ a2c += System.currentTimeMillis() - start2
+ Iterators
+ .wrap(Iterator.single(veloxBatch))
+ .recycleIterator({
+ arrowBatch.close()
+ targetBatch.close()
+ })
+ .create()
+ // TODO: should check that the size is <= 1, but a current bug turns the iterator into an empty one
+ }
+
+ override def verboseStringWithOperatorId(): String = {
+ s"""
+ |$formattedNodeName
+ |${ExplainUtils.generateFieldString("Output", output)}
+ |${ExplainUtils.generateFieldString("Input", child.output)}
+ |${ExplainUtils.generateFieldString("UDF", replacedAliasUdf)}
+ |${ExplainUtils.generateFieldString("ProjectOutput", projectAttributes)}
+ |${ExplainUtils.generateFieldString("ProjectInputIndex", projectIndexInChild)}
+ |""".stripMargin
+ }
+
+ override def simpleString(maxFields: Int): String =
+ super.simpleString(maxFields) + " PartialProject " + replacedAliasUdf
+
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarPartialProjectExec = {
+ copy(child = newChild)(replacedAliasUdf)
+ }
+}
+
+object ColumnarPartialProjectExec {
+
+ val projectPrefix = "_SparkPartialProject"
+
+ private def containsUDF(expr: Expression): Boolean = {
+ if (expr == null) return false
+ expr match {
+ case _: ScalaUDF => true
+ case h if HiveUdfUtil.isHiveUdf(h) => true
+ case p => p.children.exists(c => containsUDF(c))
+ }
+ }
+
+ private def replaceByAlias(expr: Expression, replacedAliasUdf: ListBuffer[Alias]): Expression = {
+ val replaceIndex = replacedAliasUdf.indexWhere(r => r.child.equals(expr))
+ if (replaceIndex == -1) {
+ val replace = Alias(expr, s"$projectPrefix${replacedAliasUdf.size}")()
+ replacedAliasUdf.append(replace)
+ replace.toAttribute
+ } else {
+ replacedAliasUdf(replaceIndex).toAttribute
+ }
+ }
+
+ private def isConditionalExpression(expr: Expression): Boolean = expr match {
+ case _: If => true
+ case _: CaseWhen => true
+ case _: NaNvl => true
+ case _: Coalesce => true
+ case _ => false
+ }
+
+ private def replaceExpressionUDF(
+ expr: Expression,
+ replacedAliasUdf: ListBuffer[Alias]): Expression = {
+ if (expr == null) return null
+ expr match {
+ case u: ScalaUDF =>
+ replaceByAlias(u, replacedAliasUdf)
+ case h if HiveUdfUtil.isHiveUdf(h) =>
+ replaceByAlias(h, replacedAliasUdf)
+ case au @ Alias(_: ScalaUDF, _) =>
+ val replaceIndex = replacedAliasUdf.indexWhere(r => r.exprId == au.exprId)
+ if (replaceIndex == -1) {
+ replacedAliasUdf.append(au)
+ au.toAttribute
+ } else {
+ replacedAliasUdf(replaceIndex).toAttribute
+ }
+ // Alias(HiveSimpleUDF) does not occur on its own, only as Alias(ToPrettyString(HiveSimpleUDF)),
+ // so that case is not handled here
+ case x if isConditionalExpression(x) =>
+ // For example:
+ // myudf is udf((x: Int) => x + 1)
+ // if (isnull(cast(l_extendedprice#9 as bigint))) null
+ // else myudf(knownnotnull(cast(l_extendedprice#9 as bigint)))
+ // If we extracted only the else branch and used the child column l_extendedprice directly,
+ // the result would be incorrect for null values
+ if (containsUDF(expr)) {
+ replaceByAlias(expr, replacedAliasUdf)
+ } else expr
+ case p => p.withNewChildren(p.children.map(c => replaceExpressionUDF(c, replacedAliasUdf)))
+ }
+ }
+
+ def create(original: ProjectExec): ProjectExecTransformer = {
+ val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
+ val newProjectList = original.projectList.map {
+ p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+ }
+ val partialProject =
+ ColumnarPartialProjectExec(original, original.child)(replacedAliasUdf.toSeq)
+ ProjectExecTransformer(newProjectList, partialProject)
+ }
+}
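For reference, a minimal sketch (assuming a SparkSession with the Gluten plugin enabled and the plus_one UDF registered as in the UDFPartialProjectSuite added below) of checking that the new operator shows up in an executed plan, mirroring the check used in the suite's nondeterministic-function test:

    import org.apache.gluten.execution.ColumnarPartialProjectExec

    val df = spark.sql("SELECT plus_one(cast(l_orderkey AS long)), hash(l_partkey) FROM lineitem")
    df.collect()
    // Present only when the UDF was extracted and the partial-project validation passed.
    val hasPartialProject = df.queryExecution.executedPlan
      .find(_.isInstanceOf[ColumnarPartialProjectExec])
      .isDefined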
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
index 9118e0d..639113e 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
@@ -112,7 +112,6 @@
convertTime
)
}
-
def toRowIterator(
batches: Iterator[ColumnarBatch],
output: Seq[Attribute],
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
new file mode 100644
index 0000000..066f09d
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution.ColumnarPartialProjectExec
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+
+case class PartialProjectRule(spark: SparkSession) extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ plan.transformUp {
+ case plan: ProjectExec =>
+ val transformer = ColumnarPartialProjectExec.create(plan)
+ if (transformer.doValidate().ok()) {
+ if (transformer.child.asInstanceOf[ColumnarPartialProjectExec].doValidate().ok()) {
+ transformer
+ } else plan
+ } else plan
+ case p => p
+ }
+ }
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/hive/HiveUdfUtil.scala b/backends-velox/src/main/scala/org/apache/spark/sql/hive/HiveUdfUtil.scala
new file mode 100644
index 0000000..7686455
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/hive/HiveUdfUtil.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+object HiveUdfUtil {
+ def isHiveUdf(expr: Expression): Boolean = expr match {
+ case _: HiveSimpleUDF => true
+ case _: HiveGenericUDF => true
+ case _: HiveUDAFFunction => true
+ case _: HiveGenericUDTF => true
+ case _ => false
+ }
+
+}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
new file mode 100644
index 0000000..cd3d0c5
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.execution.{ColumnarPartialProjectExec, WholeStageTransformerSuite}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation}
+import org.apache.spark.sql.functions.udf
+
+import java.io.File
+
+class UDFPartialProjectSuite extends WholeStageTransformerSuite {
+ disableFallbackCheck
+ override protected val resourcePath: String = "/tpch-data-parquet-velox"
+ override protected val fileFormat: String = "parquet"
+
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+ .set("spark.sql.files.maxPartitionBytes", "1g")
+ .set("spark.sql.shuffle.partitions", "1")
+ .set("spark.memory.offHeap.size", "2g")
+ .set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ .set("spark.sql.autoBroadcastJoinThreshold", "-1")
+ .set("spark.sql.sources.useV1SourceList", "avro")
+ .set(
+ "spark.sql.optimizer.excludedRules",
+ ConstantFolding.ruleName + "," +
+ NullPropagation.ruleName)
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val table = "lineitem"
+ val tableDir = getClass.getResource(resourcePath).getFile
+ val tablePath = new File(tableDir, table).getAbsolutePath
+ val tableDF = spark.read.format(fileFormat).load(tablePath)
+ tableDF.createOrReplaceTempView(table)
+
+ val plusOne = udf((x: Long) => x + 1)
+ spark.udf.register("plus_one", plusOne)
+ val noArgument = udf(() => 15)
+ spark.udf.register("no_argument", noArgument)
+
+ }
+
+ ignore("test plus_one") {
+ runQueryAndCompare("SELECT sum(plus_one(cast(l_orderkey as long))) from lineitem") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ ignore("test plus_one with column used twice") {
+ runQueryAndCompare(
+ "SELECT sum(plus_one(cast(l_orderkey as long)) + hash(l_orderkey)) from lineitem") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ ignore("test plus_one without cast") {
+ runQueryAndCompare("SELECT sum(plus_one(l_orderkey) + hash(l_orderkey)) from lineitem") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("test plus_one with many columns") {
+ runQueryAndCompare(
+ "SELECT sum(plus_one(cast(l_orderkey as long)) + hash(l_partkey))" +
+ "from lineitem " +
+ "where l_orderkey < 3") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("test plus_one with many columns in project") {
+ runQueryAndCompare("SELECT plus_one(cast(l_orderkey as long)), hash(l_partkey) from lineitem") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("test function no argument") {
+ runQueryAndCompare("""SELECT no_argument(), l_orderkey
+ | from lineitem limit 100""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("test nondeterministic function input_file_name") {
+ val df = spark.sql("""SELECT input_file_name(), l_orderkey
+ | from lineitem limit 100""".stripMargin)
+ df.collect()
+ assert(
+ df.queryExecution.executedPlan
+ .find(p => p.isInstanceOf[ColumnarPartialProjectExec])
+ .isEmpty)
+ }
+
+ test("udf in agg simple") {
+ runQueryAndCompare("""select sum(hash(plus_one(l_extendedprice)) + hash(l_orderkey) ) as revenue
+ | from lineitem""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("udf in agg") {
+ runQueryAndCompare("""select sum(hash(plus_one(l_extendedprice)) * l_discount
+ | + hash(l_orderkey) + hash(l_comment)) as revenue
+ | from lineitem""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+}
diff --git a/gluten-arrow/src/main/scala/org/apache/spark/sql/utils/SparkArrowUtil.scala b/gluten-arrow/src/main/scala/org/apache/spark/sql/utils/SparkArrowUtil.scala
index 652fbeb..da3f5c0 100644
--- a/gluten-arrow/src/main/scala/org/apache/spark/sql/utils/SparkArrowUtil.scala
+++ b/gluten-arrow/src/main/scala/org/apache/spark/sql/utils/SparkArrowUtil.scala
@@ -156,6 +156,17 @@
}.asJava)
}
+ // TimestampNTZ is not supported
+ def checkSchema(schema: StructType): Boolean = {
+ try {
+ SparkSchemaUtil.toArrowSchema(schema)
+ true
+ } catch {
+ case _: Exception =>
+ false
+ }
+ }
+
def fromArrowSchema(schema: Schema): StructType = {
StructType(schema.getFields.asScala.toSeq.map {
field =>
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionUtils.scala
index c0920cd..db129c7 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionUtils.scala
@@ -17,10 +17,11 @@
package org.apache.gluten.expression
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
+import org.apache.spark.sql.execution.SparkPlan
object ExpressionUtils {
- def getExpressionTreeDepth(expr: Expression): Integer = {
+ private def getExpressionTreeDepth(expr: Expression): Integer = {
if (expr.isInstanceOf[LeafExpression]) {
return 0
}
@@ -31,4 +32,8 @@
1 + childrenDepth.max
}
}
+
+ def isComplexExpression(plan: SparkPlan, threshold: Int): Boolean = {
+ plan.expressions.exists(e => ExpressionUtils.getExpressionTreeDepth(e) > threshold)
+ }
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
index f1cb479..404df5e 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
@@ -118,7 +118,7 @@
private class FallbackComplexExpressions(threshold: Int) extends Validator {
override def validate(plan: SparkPlan): Validator.OutCome = {
- if (plan.expressions.exists(e => ExpressionUtils.getExpressionTreeDepth(e) > threshold)) {
+ if (ExpressionUtils.isComplexExpression(plan, threshold)) {
return fail(
s"Disabled because at least one present expression exceeded depth threshold: " +
s"${plan.nodeName}")
diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index 1e06aea..95391a2 100644
--- a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -316,7 +316,6 @@
}
}
checkDataFrame(noFallBack, customCheck, df)
- df.explain(true)
df
}
diff --git a/gluten-ut/pom.xml b/gluten-ut/pom.xml
index af46f75..89d1ff9 100644
--- a/gluten-ut/pom.xml
+++ b/gluten-ut/pom.xml
@@ -81,6 +81,11 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <type>test-jar</type>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
diff --git a/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java b/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
new file mode 100644
index 0000000..257bd07
--- /dev/null
+++ b/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+/**
+ * UDF that generates the link id (MD5 hash) of a URL. Used to join with the link table.
+ *
+ * <p>Usage example:
+ *
+ * <p>CREATE TEMPORARY FUNCTION linkid AS 'com.pinterest.hadoop.hive.LinkIdUDF';
+ */
+@Description(
+ name = "linkid",
+ value = "_FUNC_(String) - Returns linkid as String, it's the MD5 hash of url.")
+public class CustomerUDF extends UDF {
+ public String evaluate(String url) {
+ if (url == null || url == "") {
+ return "";
+ }
+ return "extendedudf" + url;
+ }
+}
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
new file mode 100644
index 0000000..cc9f6f1
--- /dev/null
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive.execution
+
+import org.apache.gluten.execution.CustomerUDF
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.internal.config
+import org.apache.spark.internal.config.UI.UI_ENABLED
+import org.apache.spark.sql.{GlutenTestsBaseTrait, QueryTest, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
+import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.hive.test.TestHiveContext
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
+import org.apache.spark.sql.test.SQLTestUtils
+
+import org.scalatest.BeforeAndAfterAll
+
+import java.io.File
+
+trait GlutenTestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
+ override protected val enableAutoThreadAudit = false
+
+}
+
+object GlutenTestHive
+ extends TestHiveContext(
+ new SparkContext(
+ System.getProperty("spark.sql.test.master", "local[1]"),
+ "TestSQLContext",
+ new SparkConf()
+ .set("spark.sql.test", "")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+ .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+ .set(
+ HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+ "org.apache.spark.sql.hive.execution.PairSerDe")
+ .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
+ // SPARK-8910
+ .set(UI_ENABLED, false)
+ .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+ // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
+ // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
+ .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
+ .set("spark.driver.memory", "1G")
+ .set("spark.sql.adaptive.enabled", "true")
+ .set("spark.sql.shuffle.partitions", "1")
+ .set("spark.sql.files.maxPartitionBytes", "134217728")
+ .set("spark.memory.offHeap.enabled", "true")
+ .set("spark.memory.offHeap.size", "1024MB")
+ .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+ // Disable ConvertToLocalRelation for better test coverage. Test cases built on
+ // LocalRelation will exercise the optimization rules better by disabling it as
+ // this rule may potentially block testing of other optimization rules such as
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+ ),
+ false
+ ) {}
+
+class GlutenHiveUDFSuite
+ extends QueryTest
+ with GlutenTestHiveSingleton
+ with SQLTestUtils
+ with GlutenTestsBaseTrait {
+ override protected val spark: SparkSession = GlutenTestHive.sparkSession
+ protected val hiveContext: TestHiveContext = GlutenTestHive
+ protected val hiveClient: HiveClient =
+ spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client
+
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+ val table = "lineitem"
+ val tableDir =
+ getClass.getResource("").getPath + "/../../../../../../../../../../../" +
+ "/backends-velox/src/test/resources/tpch-data-parquet-velox/"
+ val tablePath = new File(tableDir, table).getAbsolutePath
+ val tableDF = spark.read.format("parquet").load(tablePath)
+ tableDF.createOrReplaceTempView(table)
+ }
+
+ override protected def afterAll(): Unit = {
+ try {
+ hiveContext.reset()
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ override protected def shouldRun(testName: String): Boolean = {
+ false
+ }
+
+ test("customer udf") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+ val df = spark.sql("""select testUDF(l_comment)
+ | from lineitem""".stripMargin)
+ df.show()
+ print(df.queryExecution.executedPlan)
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
+ hiveContext.reset()
+ }
+
+ test("customer udf wrapped in function") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+ val df = spark.sql("""select hash(testUDF(l_comment))
+ | from lineitem""".stripMargin)
+ df.show()
+ print(df.queryExecution.executedPlan)
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
+ hiveContext.reset()
+ }
+
+ test("example") {
+ spark.sql("CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
+ spark.sql("select testUDF('l_commen', 1, 5)").show()
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
+ hiveContext.reset()
+ }
+
+}
diff --git a/pom.xml b/pom.xml
index a5f39b9..a3ce1c6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -772,6 +772,13 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.version}</version>
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 5ca6a15..12b6128 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -445,6 +445,8 @@
def enableColumnarProjectCollapse: Boolean = conf.getConf(ENABLE_COLUMNAR_PROJECT_COLLAPSE)
+ def enableColumnarPartialProject: Boolean = conf.getConf(ENABLE_COLUMNAR_PARTIAL_PROJECT)
+
def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL)
def awsS3RetryMode: String = conf.getConf(AWS_S3_RETRY_MODE)
@@ -1878,6 +1880,17 @@
.booleanConf
.createWithDefault(true)
+ val ENABLE_COLUMNAR_PARTIAL_PROJECT =
+ buildConf("spark.gluten.sql.columnar.partial.project")
+ .doc(
+ "Break up one project node into 2 phases when some of the expressions are non " +
+ "offload-able. Phase one is a regular offloaded project transformer that " +
+ "evaluates the offload-able expressions in native, " +
+ "phase two preserves the output from phase one and evaluates the remaining " +
+ "non-offload-able expressions using vanilla Spark projections")
+ .booleanConf
+ .createWithDefault(true)
+
val ENABLE_COMMON_SUBEXPRESSION_ELIMINATE =
buildConf("spark.gluten.sql.commonSubexpressionEliminate")
.internal()
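The new config defaults to true. A minimal usage sketch (assuming a running SparkSession with Gluten) for disabling the split so that UDF-containing projects fall back as before:

    // Turn the partial-project split off for the current session.
    spark.conf.set("spark.gluten.sql.columnar.partial.project", "false")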