#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import inspect
import functools

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.session import SparkSession as ClassicSparkSession
from pyspark.sql.catalog import Catalog as ClassicCatalog
from pyspark.sql.readwriter import DataFrameReader as ClassicDataFrameReader
from pyspark.sql.readwriter import DataFrameWriter as ClassicDataFrameWriter
from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2
from pyspark.sql.window import Window as ClassicWindow
from pyspark.sql.window import WindowSpec as ClassicWindowSpec
import pyspark.sql.functions as ClassicFunctions
from pyspark.sql.group import GroupedData as ClassicGroupedData
import pyspark.sql.avro.functions as ClassicAvro
import pyspark.sql.protobuf.functions as ClassicProtobuf
from pyspark.sql.streaming.query import StreamingQuery as ClassicStreamingQuery
from pyspark.sql.streaming.query import StreamingQueryManager as ClassicStreamingQueryManager
from pyspark.sql.streaming.readwriter import DataStreamReader as ClassicDataStreamReader
from pyspark.sql.streaming.readwriter import DataStreamWriter as ClassicDataStreamWriter

if should_test_connect:
    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
    from pyspark.sql.connect.column import Column as ConnectColumn
    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
    from pyspark.sql.connect.catalog import Catalog as ConnectCatalog
    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
    from pyspark.sql.connect.readwriter import DataFrameWriter as ConnectDataFrameWriter
    from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2
    from pyspark.sql.connect.window import Window as ConnectWindow
    from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec
    import pyspark.sql.connect.functions as ConnectFunctions
    from pyspark.sql.connect.group import GroupedData as ConnectGroupedData
    import pyspark.sql.connect.avro.functions as ConnectAvro
    import pyspark.sql.connect.protobuf.functions as ConnectProtobuf
    from pyspark.sql.connect.streaming.query import StreamingQuery as ConnectStreamingQuery
    from pyspark.sql.connect.streaming.query import (
        StreamingQueryManager as ConnectStreamingQueryManager,
    )
    from pyspark.sql.connect.streaming.readwriter import DataStreamReader as ConnectDataStreamReader
    from pyspark.sql.connect.streaming.readwriter import DataStreamWriter as ConnectDataStreamWriter


class ConnectCompatibilityTestsMixin:
    def get_public_methods(self, cls):
        """
        Collect the public functions exposed by ``cls``.

        A member qualifies when its name does not start with an underscore and it is
        either a plain function or an ``lru_cache``-wrapped callable.

        Parameters
        ----------
        cls : object
            The object (class or module) whose members are inspected.

        Returns
        -------
        dict
            Mapping of member name to the member itself. Members carrying a truthy
            ``_remote_only`` attribute are mapped to ``None`` instead.
        """
        public = {}
        for name, member in inspect.getmembers(cls):
            if name.startswith("_"):
                continue
            is_plain_function = inspect.isfunction(member)
            if not (is_plain_function or isinstance(member, functools._lru_cache_wrapper)):
                continue
            # Flagged members keep their name in the mapping but lose the callable.
            public[name] = None if getattr(member, "_remote_only", False) else member
        return public

    def get_public_properties(self, cls):
        """
        Collect the public properties exposed by ``cls``.

        Parameters
        ----------
        cls : object
            The object (class or module) whose members are inspected.

        Returns
        -------
        dict
            Mapping of member name to descriptor for every member that is a
            ``property`` or ``functools.cached_property`` and whose name does not
            start with an underscore.
        """
        property_kinds = (property, functools.cached_property)
        found = {}
        for name, member in inspect.getmembers(cls):
            if not name.startswith("_") and isinstance(member, property_kinds):
                found[name] = member
        return found

    def compare_method_signatures(self, classic_cls, connect_cls, cls_name):
        """
        Assert that methods present on both sides have identical signatures.

        Only names returned by ``get_public_methods`` for both ``classic_cls`` and
        ``connect_cls`` are compared. Names are visited in sorted order so that the
        first reported mismatch is reproducible across runs.

        Parameters
        ----------
        classic_cls : object
            The classic class (or module) to compare.
        connect_cls : object
            The connect class (or module) to compare.
        cls_name : str
            Display name used in the failure message.

        Raises
        ------
        AssertionError
            Raised through ``self.assertEqual`` when a shared method's signatures differ.
        """
        classic_methods = self.get_public_methods(classic_cls)
        connect_methods = self.get_public_methods(connect_cls)

        # Cannot support RDD arguments from Spark Connect
        has_rdd_arguments = frozenset({"createDataFrame", "xml", "json"})

        for method in sorted(classic_methods.keys() & connect_methods.keys()):
            # Exempt names are skipped before any introspection is attempted.
            if method in has_rdd_arguments:
                continue

            classic_method = classic_methods[method]
            connect_method = connect_methods[method]

            # Skip non-callable, Spark Connect-specific methods
            # (get_public_methods records them as None).
            if classic_method is None or connect_method is None:
                continue

            classic_signature = inspect.signature(classic_method)
            connect_signature = inspect.signature(connect_method)

            self.assertEqual(
                classic_signature,
                connect_signature,
                f"Signature mismatch in {cls_name} method '{method}'\n"
                f"Classic: {classic_signature}\n"
                f"Connect: {connect_signature}",
            )

    def compare_property_lists(
        self,
        classic_cls,
        connect_cls,
        cls_name,
        expected_missing_connect_properties,
        expected_missing_classic_properties,
    ):
        """
        Assert that the property-name differences between both sides are as expected.

        Parameters
        ----------
        classic_cls : object
            The classic class (or module) to compare.
        connect_cls : object
            The connect class (or module) to compare.
        cls_name : str
            Display name used in the failure messages.
        expected_missing_connect_properties : set
            Property names expected on the classic side only.
        expected_missing_classic_properties : set
            Property names expected on the connect side only.
        """
        classic_names = set(self.get_public_properties(classic_cls))
        connect_names = set(self.get_public_properties(connect_cls))

        # Names that exist on exactly one side.
        classic_only_properties = classic_names.difference(connect_names)
        connect_only_properties = connect_names.difference(classic_names)

        # Classic-only names must match what is expected to be absent from Connect.
        self.assertEqual(
            classic_only_properties,
            expected_missing_connect_properties,
            f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}",
        )

        # And the other way round for Connect-only names.
        self.assertEqual(
            connect_only_properties,
            expected_missing_classic_properties,
            f"{cls_name}: Unexpected missing properties in Classic: {connect_only_properties}",
        )

    def check_missing_methods(
        self,
        classic_cls,
        connect_cls,
        cls_name,
        expected_missing_connect_methods,
        expected_missing_classic_methods,
    ):
        """
        Assert that the method-name differences between both sides are as expected.

        Parameters
        ----------
        classic_cls : object
            The classic class (or module) to compare.
        connect_cls : object
            The connect class (or module) to compare.
        cls_name : str
            Display name used in the failure messages.
        expected_missing_connect_methods : set
            Expected classic-side names that are either absent from the connect side
            or recorded as ``None`` by ``get_public_methods`` on the classic side.
        expected_missing_classic_methods : set
            Expected names that exist on the connect side only.
        """
        classic_methods = self.get_public_methods(classic_cls)
        connect_methods = self.get_public_methods(connect_cls)

        # A classic name counts as "classic only" when Connect lacks it or when the
        # classic entry was recorded as None.
        classic_only_methods = set()
        for name, method in classic_methods.items():
            if method is None or name not in connect_methods:
                classic_only_methods.add(name)

        # Dict key views support set difference directly.
        connect_only_methods = connect_methods.keys() - classic_methods.keys()

        self.assertEqual(
            classic_only_methods,
            expected_missing_connect_methods,
            f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}",
        )

        # Same check in the opposite direction.
        self.assertEqual(
            connect_only_methods,
            expected_missing_classic_methods,
            f"{cls_name}: Unexpected missing methods in Classic: {connect_only_methods}",
        )

    def check_compatibility(
        self,
        classic_cls,
        connect_cls,
        cls_name,
        expected_missing_connect_properties,
        expected_missing_classic_properties,
        expected_missing_connect_methods,
        expected_missing_classic_methods,
    ):
        """
        Run every compatibility check for one classic/connect pair.

        The checks run in this order:
        - signatures of the methods both sides share,
        - property names that exist on one side only,
        - method names that exist on one side only.

        Parameters
        ----------
        classic_cls : object
            The classic class (or module) to compare.
        connect_cls : object
            The connect class (or module) to compare.
        cls_name : str
            Display name used in failure messages.
        expected_missing_connect_properties : set
            Properties expected to be missing on the connect side.
        expected_missing_classic_properties : set
            Properties expected to be missing on the classic side.
        expected_missing_connect_methods : set
            Methods expected to be missing on the connect side.
        expected_missing_classic_methods : set
            Methods expected to be missing on the classic side.
        """
        # All three checks share the same leading positional arguments.
        target = (classic_cls, connect_cls, cls_name)

        self.compare_method_signatures(*target)
        self.compare_property_lists(
            *target,
            expected_missing_connect_properties,
            expected_missing_classic_properties,
        )
        self.check_missing_methods(
            *target,
            expected_missing_connect_methods,
            expected_missing_classic_methods,
        )

    def test_dataframe_compatibility(self):
        """Check the classic and connect DataFrame classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataFrame,
            connect_cls=ConnectDataFrame,
            cls_name="DataFrame",
            expected_missing_connect_properties={"sql_ctx"},
            expected_missing_classic_properties={"is_cached"},
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_column_compatibility(self):
        """Check the classic and connect Column classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicColumn,
            connect_cls=ConnectColumn,
            cls_name="Column",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods={"to_plan"},
        )

    def test_spark_session_compatibility(self):
        """Check the classic and connect SparkSession classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicSparkSession,
            connect_cls=ConnectSparkSession,
            cls_name="SparkSession",
            expected_missing_connect_properties={"sparkContext"},
            expected_missing_classic_properties={"is_stopped", "session_id"},
            expected_missing_connect_methods={
                "clearProgressHandlers",
                "copyFromLocalToFs",
                "newSession",
                "registerProgressHandler",
                "removeProgressHandler",
            },
            expected_missing_classic_methods=set(),
        )

    def test_catalog_compatibility(self):
        """Check the classic and connect Catalog classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicCatalog,
            connect_cls=ConnectCatalog,
            cls_name="Catalog",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_dataframe_reader_compatibility(self):
        """Check the classic and connect DataFrameReader classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataFrameReader,
            connect_cls=ConnectDataFrameReader,
            cls_name="DataFrameReader",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_dataframe_writer_compatibility(self):
        """Check the classic and connect DataFrameWriter classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataFrameWriter,
            connect_cls=ConnectDataFrameWriter,
            cls_name="DataFrameWriter",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_dataframe_writer_v2_compatibility(self):
        """Check the classic and connect DataFrameWriterV2 classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataFrameWriterV2,
            connect_cls=ConnectDataFrameWriterV2,
            cls_name="DataFrameWriterV2",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_window_compatibility(self):
        """Check the classic and connect Window classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicWindow,
            connect_cls=ConnectWindow,
            cls_name="Window",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_window_spec_compatibility(self):
        """Check the classic and connect WindowSpec classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicWindowSpec,
            connect_cls=ConnectWindowSpec,
            cls_name="WindowSpec",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_functions_compatibility(self):
        """Check the classic and connect ``functions`` modules against each other."""
        self.check_compatibility(
            classic_cls=ClassicFunctions,
            connect_cls=ConnectFunctions,
            cls_name="Functions",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods={"check_dependencies"},
        )

    def test_grouping_compatibility(self):
        """Check the classic and connect GroupedData classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicGroupedData,
            connect_cls=ConnectGroupedData,
            cls_name="Grouping",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_avro_compatibility(self):
        """Check the classic and connect Avro function modules against each other."""
        # The currently supported Avro functions are only `from_avro` and `to_avro`.
        # The names listed as missing below are just utility functions that were
        # imported in order to implement them.
        self.check_compatibility(
            classic_cls=ClassicAvro,
            connect_cls=ConnectAvro,
            cls_name="Avro",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods={
                "try_remote_avro_functions",
                "cast",
                "get_active_spark_context",
            },
            expected_missing_classic_methods={"lit", "check_dependencies"},
        )

    def test_streaming_query_compatibility(self):
        """Check the classic and connect StreamingQuery classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicStreamingQuery,
            connect_cls=ConnectStreamingQuery,
            cls_name="StreamingQuery",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_protobuf_compatibility(self):
        """Check the classic and connect Protobuf function modules against each other."""
        # The currently supported Protobuf functions are only `from_protobuf` and
        # `to_protobuf`. The names listed as missing below are just utility functions
        # that were imported in order to implement them.
        self.check_compatibility(
            classic_cls=ClassicProtobuf,
            connect_cls=ConnectProtobuf,
            cls_name="Protobuf",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods={
                "cast",
                "try_remote_protobuf_functions",
                "get_active_spark_context",
            },
            expected_missing_classic_methods={"lit", "check_dependencies"},
        )

    def test_streaming_query_manager_compatibility(self):
        """Check the classic and connect StreamingQueryManager classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicStreamingQueryManager,
            connect_cls=ConnectStreamingQueryManager,
            cls_name="StreamingQueryManager",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods={"close"},
        )

    def test_streaming_reader_compatibility(self):
        """Check the classic and connect DataStreamReader classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataStreamReader,
            connect_cls=ConnectDataStreamReader,
            cls_name="DataStreamReader",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )

    def test_streaming_writer_compatibility(self):
        """Check the classic and connect DataStreamWriter classes against each other."""
        self.check_compatibility(
            classic_cls=ClassicDataStreamWriter,
            connect_cls=ConnectDataStreamWriter,
            cls_name="DataStreamWriter",
            expected_missing_connect_properties=set(),
            expected_missing_classic_properties=set(),
            expected_missing_connect_methods=set(),
            expected_missing_classic_methods=set(),
        )


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
    """
    Concrete test case that runs the checks defined in ConnectCompatibilityTestsMixin.

    The whole class is skipped, with ``connect_requirement_message`` as the reason,
    when ``should_test_connect`` is falsy.
    """

    pass


# Entry point used when this file is executed directly as a script.
if __name__ == "__main__":
    # Pull this module's public names in through its package path.
    # NOTE(review): presumably so the tests are visible in the __main__ namespace
    # under their canonical module -- confirm against the project's test runner setup.
    from pyspark.sql.tests.test_connect_compatibility import *  # noqa: F401

    try:
        # xmlrunner is optional; when it is installed, XML reports are written
        # to target/test-reports.
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # Fall back to unittest's default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
