# blob: ca1f828ef4d78d2e1f1c8b842f645399845b4020 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import inspect
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.session import SparkSession as ClassicSparkSession
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
class ConnectCompatibilityTestsMixin:
    """Checks that keep the Spark Connect API surface aligned with the classic one.

    The tests compare the classic and Connect flavours of ``DataFrame``,
    ``Column`` and ``SparkSession``.  They call ``assertEqual`` and ``subTest``
    on ``self``, so the mixin has to be combined with a ``unittest.TestCase``
    subclass to be runnable.
    """

    def get_public_methods(self, cls):
        """Get public methods of a class.

        Returns a ``{name: function}`` dict holding every member of ``cls``
        accepted by ``inspect.isfunction`` whose name does not start with an
        underscore.
        """
        return {
            name: method
            for name, method in inspect.getmembers(cls, predicate=inspect.isfunction)
            if not name.startswith("_")
        }

    def get_public_properties(self, cls):
        """Get public properties of a class.

        Returns a ``{name: property}`` dict holding every ``property`` member
        of ``cls`` whose name does not start with an underscore.
        """
        return {
            name: member
            for name, member in inspect.getmembers(cls)
            if isinstance(member, property) and not name.startswith("_")
        }

    def test_signature_comparison_between_classic_and_connect(self):
        """Methods available in both APIs must expose identical signatures."""

        def compare_method_signatures(classic_cls, connect_cls, cls_name):
            """Compare method signatures between classic and connect classes."""
            classic_methods = self.get_public_methods(classic_cls)
            connect_methods = self.get_public_methods(connect_cls)

            # Sorted so the checks run, and are reported, in a stable order.
            for method in sorted(classic_methods.keys() & connect_methods.keys()):
                # createDataFrame cannot be the same since RDD is not supported from Spark Connect
                if method == "createDataFrame":
                    continue

                # One sub-test per method, so that a mismatch does not hide the
                # mismatches of the methods checked after it.
                with self.subTest(api=cls_name, method=method):
                    classic_signature = inspect.signature(classic_methods[method])
                    connect_signature = inspect.signature(connect_methods[method])
                    self.assertEqual(
                        classic_signature,
                        connect_signature,
                        f"Signature mismatch in {cls_name} method '{method}'\n"
                        f"Classic: {classic_signature}\n"
                        f"Connect: {connect_signature}",
                    )

        # DataFrame API signature comparison
        compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame")

        # Column API signature comparison
        compare_method_signatures(ClassicColumn, ConnectColumn, "Column")

        # SparkSession API signature comparison
        compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession")

    def test_property_comparison_between_classic_and_connect(self):
        """Properties that exist only in the classic API must match the expected sets."""

        def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties):
            """Compare properties between classic and connect classes."""
            classic_properties = self.get_public_properties(classic_cls)
            connect_properties = self.get_public_properties(connect_cls)

            # Identify missing properties
            classic_only_properties = set(classic_properties.keys()) - set(
                connect_properties.keys()
            )

            # Compare the actual missing properties with the expected ones.
            # Wrapped in a sub-test so a failure for one class does not prevent
            # the remaining classes from being checked.
            with self.subTest(api=cls_name):
                self.assertEqual(
                    classic_only_properties,
                    expected_missing_properties,
                    f"{cls_name}: Unexpected missing properties in Connect: "
                    f"{classic_only_properties}",
                )

        # Expected missing properties for DataFrame
        expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"}

        # DataFrame properties comparison
        compare_property_lists(
            ClassicDataFrame,
            ConnectDataFrame,
            "DataFrame",
            expected_missing_properties_for_dataframe,
        )

        # Expected missing properties for Column (if any, replace with actual values)
        expected_missing_properties_for_column = set()

        # Column properties comparison
        compare_property_lists(
            ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column
        )

        # Expected missing properties for SparkSession
        expected_missing_properties_for_spark_session = {"sparkContext", "version"}

        # SparkSession properties comparison
        compare_property_lists(
            ClassicSparkSession,
            ConnectSparkSession,
            "SparkSession",
            expected_missing_properties_for_spark_session,
        )

    def test_missing_methods(self):
        """Methods that exist only in the classic API must match the expected sets."""

        def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods):
            """Check for expected missing methods between classic and connect classes."""
            classic_methods = self.get_public_methods(classic_cls)
            connect_methods = self.get_public_methods(connect_cls)

            # Identify missing methods
            classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys())

            # Compare the actual missing methods with the expected ones.
            # Wrapped in a sub-test so a failure for one class does not prevent
            # the remaining classes from being checked.
            with self.subTest(api=cls_name):
                self.assertEqual(
                    classic_only_methods,
                    expected_missing_methods,
                    f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}",
                )

        # Expected missing methods for DataFrame
        expected_missing_methods_for_dataframe = {
            "inputFiles",
            "isLocal",
            "semanticHash",
            "isEmpty",
        }

        # DataFrame missing method check
        check_missing_methods(
            ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe
        )

        # Expected missing methods for Column (if any, replace with actual values)
        expected_missing_methods_for_column = set()

        # Column missing method check
        check_missing_methods(
            ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column
        )

        # Expected missing methods for SparkSession (if any, replace with actual values)
        expected_missing_methods_for_spark_session = {"newSession"}

        # SparkSession missing method check
        check_missing_methods(
            ClassicSparkSession,
            ConnectSparkSession,
            "SparkSession",
            expected_missing_methods_for_spark_session,
        )
class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
    """Concrete test case: the mixin's checks combined with ``ReusedSQLTestCase``."""
if __name__ == "__main__":
    from pyspark.sql.tests.test_connect_compatibility import *  # noqa: F401

    # Default to unittest's own runner; switch to the XML-reporting runner
    # only when the optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass

    unittest.main(testRunner=testRunner, verbosity=2)