[SPARK-48845][SQL] GenericUDF catch exceptions from children
### What changes were proposed in this pull request?
This PR fixes a behavior issue with GenericUDF that has existed since 3.5.0. The problem is that `DeferredObject` is currently handed an already-evaluated value rather than a function, so exceptions thrown by child expressions surface before `GenericUDF.evaluate` runs. This prevents users from catching those exceptions inside the UDF, changing the semantics.
Here is an example case we encountered. Originally, the semantics were that `udf_exception` would throw an exception while `udf_catch_exception` caught it and returned null. Currently, however, any exception thrown by `udf_exception` fails the whole query.
```sql
select udf_catch_exception(udf_exception(col1)) from table
```
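For illustration, here is a minimal Scala sketch of the kind of catching UDF this example assumes (the class name is hypothetical; the PR's Java test class `UDFCatchException` in the diff below implements the same pattern):
```scala
import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical Scala counterpart of the PR's UDFCatchException test class.
class CatchExceptionUDF extends GenericUDF {
  override def initialize(args: Array[ObjectInspector]): ObjectInspector = {
    if (args.length != 1) throw new UDFArgumentException("Exactly one argument is expected.")
    PrimitiveObjectInspectorFactory.javaStringObjectInspector
  }

  // With this fix, a failing child throws from args(0).get(), so the UDF can
  // swallow the error and return null instead of failing the query.
  override def evaluate(args: Array[GenericUDF.DeferredObject]): AnyRef =
    try args(0).get() catch { case _: Exception => null }

  override def getDisplayString(children: Array[String]): String = "udf_catch_exception"
}
```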
### Why are the changes needed?
Before Spark 3.5, the GenericUDF's `DeferredObject` was lazy, and the children were evaluated inside `function.evaluate(deferredObjects)`, so the UDF could catch their exceptions.
With this change, the children are evaluated first; if one throws, the exception is deferred into the GenericUDF's `DeferredObject` and re-thrown only when `get()` is called, restoring the old catchable semantics.
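In short, the `DeferredObjectAdapter` now stores a thunk instead of a value, so a failed child evaluation only surfaces when the UDF calls `get()`. A condensed, self-contained sketch of the idea (simplified from the diff below; `evalChild` is a stand-in for a child expression):
```scala
object DeferredExceptionSketch extends App {
  // Simplified stand-in for Spark's DeferredObjectAdapter after this change.
  class DeferredObjectAdapter {
    private var func: () => Any = _
    def set(f: () => Any): Unit = { this.func = f }
    // A deferred exception is re-thrown here, i.e. inside GenericUDF.evaluate.
    def get(): Any = func()
  }

  def evalChild(): Any = "9-1".toInt  // stand-in child; throws NumberFormatException

  val adapter = new DeferredObjectAdapter
  try {
    val v = evalChild()                              // eager child evaluation
    adapter.set(() => v)
  } catch {
    case t: Throwable => adapter.set(() => throw t)  // defer the failure
  }

  // The outer UDF can now catch the child's exception itself:
  println(try adapter.get() catch { case _: Exception => null })  // prints null
}
```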
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Newly added UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47268 from jackylee-ch/generic_udf_catch_exception_from_child_func.
Lead-authored-by: jackylee-ch <lijunqing@baidu.com>
Co-authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
(cherry picked from commit 236d95738b6e50bc9ec54955e86d01b6dcf11c0e)
Signed-off-by: Kent Yao <yao@apache.org>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
index 094f8ba..fc1c795 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -129,7 +129,11 @@
override def returnType: DataType = inspectorToDataType(returnInspector)
def setArg(index: Int, arg: Any): Unit =
- deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+ deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg)
+
+ def setException(index: Int, exp: Throwable): Unit = {
+ deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw exp)
+ }
override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
}
@@ -139,10 +143,10 @@
extends DeferredObject with HiveInspectors {
private val wrapper = wrapperFor(oi, dataType)
- private var func: Any = _
- def set(func: Any): Unit = {
+ private var func: () => Any = _
+ def set(func: () => Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
- override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
+ override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 01684f5..0c8305b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -136,7 +136,13 @@
override def eval(input: InternalRow): Any = {
children.zipWithIndex.foreach {
- case (child, idx) => evaluator.setArg(idx, child.eval(input))
+ case (child, idx) =>
+ try {
+ evaluator.setArg(idx, child.eval(input))
+ } catch {
+ case t: Throwable =>
+ evaluator.setException(idx, t)
+ }
}
evaluator.evaluate()
}
@@ -157,10 +163,15 @@
val setValues = evals.zipWithIndex.map {
case (eval, i) =>
s"""
- |if (${eval.isNull}) {
- | $refEvaluator.setArg($i, null);
- |} else {
- | $refEvaluator.setArg($i, ${eval.value});
+ |try {
+ | ${eval.code}
+ | if (${eval.isNull}) {
+ | $refEvaluator.setArg($i, null);
+ | } else {
+ | $refEvaluator.setArg($i, ${eval.value});
+ | }
+ |} catch (Throwable t) {
+ | $refEvaluator.setException($i, t);
|}
|""".stripMargin
}
@@ -169,7 +180,6 @@
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
- |${evals.map(_.code).mkString("\n")}
|${setValues.mkString("\n")}
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
|boolean ${ev.isNull} = $resultTerm == null;
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
new file mode 100644
index 0000000..242dbea
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+public class UDFCatchException extends GenericUDF {
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
+ if (args.length != 1) {
+ throw new UDFArgumentException("Exactly one argument is expected.");
+ }
+ return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] args) {
+ if (args == null) {
+ return null;
+ }
+ try {
+ return args[0].get();
+ } catch (Exception e) {
+ return null;
+ }
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return null;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
new file mode 100644
index 0000000..5d6ff6c
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFThrowException extends UDF {
+ public String evaluate(String data) {
+ return Integer.valueOf(data).toString();
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d12ebae..f3be79f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -35,6 +35,7 @@
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.{call_function, max}
@@ -791,6 +792,28 @@
}
}
}
+
+ test("SPARK-48845: GenericUDF catch exceptions from child UDFs") {
+ withTable("test_catch_exception") {
+ withUserDefinedFunction("udf_throw_exception" -> true, "udf_catch_exception" -> true) {
+ Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception")
+ sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " +
+ s"'${classOf[UDFThrowException].getName}'")
+ sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " +
+ s"'${classOf[UDFCatchException].getName}'")
+ Seq(
+ CodegenObjectFactoryMode.FALLBACK.toString,
+ CodegenObjectFactoryMode.NO_CODEGEN.toString
+ ).foreach { codegenMode =>
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+ val df = sql(
+ "SELECT udf_catch_exception(udf_throw_exception(a)) FROM test_catch_exception")
+ checkAnswer(df, Seq(Row("9"), Row(null)))
+ }
+ }
+ }
+ }
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {