[SPARK-48197][SQL] Avoid assert error for invalid lambda function
### What changes were proposed in this pull request?
`ExpressionBuilder` asserts all its input expressions to be resolved during lookup, which is not true as the analyzer rule `ResolveFunctions` can trigger function lookup even if the input expression contains unresolved lambda functions.
This PR updates that assert to check non-lambda inputs only, and fail earlier if the input contains lambda functions. In the future, if we use `ExpressionBuilder` to register higher-order functions, we can relax it.
### Why are the changes needed?
Better error message: the user now gets an `AnalysisException` explaining that the function is not a higher-order function, instead of an internal assertion failure.
### Does this PR introduce _any_ user-facing change?
No, it only changes the error message.
### How was this patch tested?
A new test (`select ceil(x -> x)`) added to `higher-order-functions.sql`.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46475 from cloud-fan/minor.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6565591..f37f47c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -955,7 +955,14 @@
since: Option[String] = None): (String, (ExpressionInfo, FunctionBuilder)) = {
val info = FunctionRegistryBase.expressionInfo[T](name, since)
val funcBuilder = (expressions: Seq[Expression]) => {
- assert(expressions.forall(_.resolved), "function arguments must be resolved.")
+ val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction])
+ if (lambdas.nonEmpty && !builder.supportsLambda) {
+ throw new AnalysisException(
+ errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+ messageParameters = Map(
+ "class" -> builder.getClass.getCanonicalName))
+ }
+ assert(others.forall(_.resolved), "function arguments must be resolved.")
val rearrangedExpressions = rearrangeExpressions(name, builder, expressions)
val expr = builder.build(name, rearrangedExpressions)
if (setAlias) expr.setTagValue(FUNC_ALIAS, name)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 7e04af1..0aa73f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -70,6 +70,8 @@
}
def build(funcName: String, expressions: Seq[Expression]): T
+
+ def supportsLambda: Boolean = false
}
object NamedParametersSupport {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
index 693cb2a..a772a2c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out
@@ -36,6 +36,26 @@
-- !query
+select ceil(x -> x) as v
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+ "sqlState" : "42K0D",
+ "messageParameters" : {
+ "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "ceil(x -> x)"
+ } ]
+}
+
+
+-- !query
select transform(zs, z -> z) as v from nested
-- !query analysis
Project [transform(zs#x, lambdafunction(lambda z#x, lambda z#x, false)) AS v#x]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
index ec6d727..c82ba7d 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out
@@ -36,6 +36,26 @@
-- !query
+select ceil(x -> x) as v
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+ "sqlState" : "42K0D",
+ "messageParameters" : {
+ "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "ceil(x -> x)"
+ } ]
+}
+
+
+-- !query
select transform(zs, z -> z) as v from nested
-- !query analysis
Project [transform(zs#x, lambdafunction(lambda z#x, lambda z#x, false)) AS v#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 7925a21..37081de 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -11,6 +11,8 @@
-- Only allow lambda's in higher order functions.
select upper(x -> x) as v;
+-- Also test functions registered with `ExpressionBuilder`.
+select ceil(x -> x) as v;
-- Identity transform an array
select transform(zs, z -> z) as v from nested;
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
index ee45252..7bfc35a 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
@@ -34,6 +34,28 @@
-- !query
+select ceil(x -> x) as v
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+ "sqlState" : "42K0D",
+ "messageParameters" : {
+ "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "ceil(x -> x)"
+ } ]
+}
+
+
+-- !query
select transform(zs, z -> z) as v from nested
-- !query schema
struct<v:array<array<int>>>
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index ee45252..7bfc35a 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -34,6 +34,28 @@
-- !query
+select ceil(x -> x) as v
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
+ "sqlState" : "42K0D",
+ "messageParameters" : {
+ "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "ceil(x -> x)"
+ } ]
+}
+
+
+-- !query
select transform(zs, z -> z) as v from nested
-- !query schema
struct<v:array<array<int>>>