[SPARK-49477][PYTHON] Improve pandas udf invalid return type error message
### What changes were proposed in this pull request?
This PR improves the error message when the specified return type of a pandas udf mismatch the actual return type.
### Why are the changes needed?
To improve the error message.
Before this PR:
`pyspark.errors.exceptions.base.PySparkValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>`
After this PR:
`pyspark.errors.exceptions.base.PySparkValueError: Invalid return type. Please make sure that the UDF returns a pandas.DataFrame when the specified return type is StructType.`
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47942 from allisonwang-db/spark-49477-pandas-udf-err-msg.
Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6203d4d..0762268 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -510,8 +510,8 @@
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
- "A field of type StructType expects a pandas.DataFrame, "
- "but got: %s" % str(type(s))
+ "Invalid return type. Please make sure that the UDF returns a "
+ "pandas.DataFrame when the specified return type is StructType."
)
arrs.append(self._create_struct_array(s, t))
else:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 6720dfc..228fc30 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -339,6 +339,19 @@
self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
+ def test_pandas_udf_return_type_error(self):
+ import pandas as pd
+
+ @pandas_udf("s string")
+ def upper(s: pd.Series) -> pd.Series:
+ return s.str.upper()
+
+ df = self.spark.createDataFrame([("a",)], schema="s string")
+
+ self.assertRaisesRegex(
+ PythonException, "Invalid return type", df.select(upper("s")).collect
+ )
+
class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
pass