[SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX
### What changes were proposed in this pull request?
Implement the missing function validation in ApplyInXXX
https://github.com/apache/spark/pull/46397 fixed this issue for `Cogrouped.ApplyInPandas`, this PR fix remaining methods.
### Why are the changes needed?
for better error message:
```
In [12]: df1 = spark.range(11)
In [13]: df2 = df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
In [14]: df2.show()
```
before this PR, an invalid function causes weird execution errors:
```
24/05/10 11:37:36 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 36)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1834, in main
process()
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1826, in process
serializer.dump_stream(out_iter, outfile)
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 531, in dump_stream
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 104, in dump_stream
for batch in iterator:
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 524, in init_stream_yield_batches
for series in iterator:
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1610, in mapper
return f(keys, vals)
^^^^^^^^^^^^^
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 488, in <lambda>
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
^^^^^^^^^^^^^
File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 483, in wrapped
result, return_type, _assign_cols_by_name, truncate_return_schema=False
^^^^^^
UnboundLocalError: cannot access local variable 'result' where it is not associated with a value
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:523)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:479)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:896)
...
```
After this PR, the error happens before execution, which is consistent with Spark Classic, and
much clear
```
PySparkValueError: [INVALID_PANDAS_UDF] Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data).
```
### Does this PR introduce _any_ user-facing change?
yes, error message changes
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46519 from zhengruifeng/missing_check_in_group.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index c916e8a..2a5bb59 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@
from pyspark.util import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
+from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]
from pyspark.sql.types import NumericType
from pyspark.sql.types import StructType
@@ -293,6 +294,7 @@
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
+ _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -322,6 +324,7 @@
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
+ _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
udf_obj = UserDefinedFunction(
func,
returnType=outputStructType,
@@ -360,6 +363,7 @@
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
+ _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -398,9 +402,8 @@
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
- from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]
- _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+ _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -426,6 +429,7 @@
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
+ _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 5922a5c..020105b 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -432,7 +432,7 @@
# validate the pandas udf and return the adjusted eval type
-def _validate_pandas_udf(f, returnType, evalType) -> int:
+def _validate_pandas_udf(f, evalType) -> int:
argspec = getfullargspec(f)
# pandas UDF by type hints.
@@ -533,7 +533,7 @@
def _create_pandas_udf(f, returnType, evalType):
- evalType = _validate_pandas_udf(f, returnType, evalType)
+ evalType = _validate_pandas_udf(f, evalType)
if is_remote():
from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 1e86e12..a26d6d0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -439,6 +439,26 @@
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
)
+ def test_wrong_args_in_apply_func(self):
+ df1 = self.spark.range(11)
+ df2 = self.spark.range(22)
+
+ with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+ df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+ with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+ df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+ with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+ df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
+ lambda: 1, StructType([StructField("d", DoubleType())])
+ )
+
+ with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+ df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+ lambda: 1, StructType([StructField("d", DoubleType())])
+ )
+
def test_unsupported_types(self):
with self.quiet():
self.check_unsupported_types()