[SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata
### What changes were proposed in this pull request?
Introduce a helper function to drop metadata
### Why are the changes needed?
The existing helper function `remove_metadata` in PS doesn't support nested types, so it cannot be reused in other places.
### Does this PR introduce _any_ user-facing change?
no, test only
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46466 from zhengruifeng/py_drop_meta.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 767ec9a..8ab8d79 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@
Window,
)
from pyspark.sql.types import ( # noqa: F401
+ _drop_metadata,
BooleanType,
DataType,
LongType,
@@ -761,14 +762,8 @@
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
- def remove_metadata(struct_field: StructField) -> StructField:
- new_struct_field = StructField(
- struct_field.name, struct_field.dataType, struct_field.nullable
- )
- return new_struct_field
-
assert all(
- remove_metadata(index_field.struct_field) == remove_metadata(struct_field)
+ _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)
else:
@@ -795,14 +790,8 @@
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
- def remove_metadata(struct_field: StructField) -> StructField:
- new_struct_field = StructField(
- struct_field.name, struct_field.dataType, struct_field.nullable
- )
- return new_struct_field
-
assert all(
- remove_metadata(data_field.struct_field) == remove_metadata(struct_field)
+ _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)
else:
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9d4db8c..0f0abfd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -21,7 +21,14 @@
from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql import SparkSession as PySparkSession
-from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
+from pyspark.sql.types import (
+ _drop_metadata,
+ StringType,
+ StructType,
+ StructField,
+ ArrayType,
+ IntegerType,
+)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
@@ -1668,7 +1675,7 @@
)
# TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
- # self.assertEqual(cdf.schema, sdf.schema)
+ self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
self.assertEqual(cdf.collect(), sdf.collect())
def test_csv_functions(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 48aa3e8..41be126 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1569,6 +1569,19 @@
_INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
+def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
+ assert isinstance(d, (DataType, StructField))
+ if isinstance(d, StructField):
+ return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
+ elif isinstance(d, StructType):
+ return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
+ elif isinstance(d, ArrayType):
+ return ArrayType(_drop_metadata(d.elementType), d.containsNull)
+ elif isinstance(d, MapType):
+ return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
+ return d
+
+
def _parse_datatype_string(s: str) -> DataType:
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals