#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql.connect.utils import check_dependencies

# Run the Spark Connect dependency check for this module before the remaining
# imports below are executed.
# NOTE(review): presumably raises when packages required by Spark Connect are
# missing -- confirm against pyspark.sql.connect.utils.check_dependencies.
check_dependencies(__name__)

import warnings
from typing import (
    Dict,
    List,
    Sequence,
    Union,
    TYPE_CHECKING,
    Optional,
    overload,
    cast,
)

from pyspark.util import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_vectorized_udf  # type: ignore[attr-defined]
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
from pyspark.sql.types import NumericType, StructType

import pyspark.sql.connect.plan as plan
from pyspark.sql.column import Column
from pyspark.sql.connect.functions import builtin as F
from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

if TYPE_CHECKING:
    from pyspark.sql.connect._typing import (
        LiteralType,
        PandasGroupedMapFunction,
        GroupedMapPandasUserDefinedFunction,
        PandasCogroupedMapFunction,
        ArrowCogroupedMapFunction,
        ArrowGroupedMapFunction,
        PandasGroupedMapFunctionWithState,
    )
    from pyspark.sql.connect.dataframe import DataFrame
    from pyspark.sql.types import StructType


class GroupedData:
    # Spark Connect counterpart of classic PySpark's ``GroupedData``.
    #
    # Instances only record *how* a DataFrame is grouped (group type, grouping
    # columns, optional pivot / grouping-set information); each public method
    # wraps the parent plan in a new ``plan.*`` node and returns a new Connect
    # DataFrame bound to the same session.
    #
    # Public docstrings are not written here: after each public method the
    # ``__doc__`` of the same-named method on ``pyspark.sql.group.GroupedData``
    # is assigned to it, so any docstring placed on those methods would be
    # overwritten.

    def __init__(
        self,
        df: "DataFrame",
        group_type: str,
        grouping_cols: Sequence[Column],
        pivot_col: Optional[Column] = None,
        pivot_values: Optional[Sequence["LiteralType"]] = None,
        grouping_sets: Optional[Sequence[Sequence[Column]]] = None,
    ) -> None:
        """
        Record the grouping specification for ``df``.

        Parameters
        ----------
        df : DataFrame
            The Spark Connect DataFrame being grouped.
        group_type : str
            One of "groupby", "rollup", "cube", "pivot" or "grouping_sets".
        grouping_cols : list of Column
            The grouping expressions (must be a ``list``).
        pivot_col : Column, optional
            Required when ``group_type`` is "pivot"; ignored otherwise.
        pivot_values : list, optional
            Literal pivot values; only read when ``group_type`` is "pivot".
            Each value is wrapped with ``F.lit``.
        grouping_sets : list of list of Column, optional
            Only stored when ``group_type`` is "grouping_sets".

        Notes
        -----
        Arguments are validated with ``assert`` statements, so the checks are
        skipped when Python runs with ``-O``.
        """
        from pyspark.sql.connect.dataframe import DataFrame

        assert isinstance(df, DataFrame)
        self._df = df

        assert isinstance(group_type, str) and group_type in [
            "groupby",
            "rollup",
            "cube",
            "pivot",
            "grouping_sets",
        ]
        self._group_type = group_type

        assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
        self._grouping_cols: List[Column] = grouping_cols

        # Pivot information is only kept for the "pivot" group type.
        self._pivot_col: Optional["Column"] = None
        self._pivot_values: Optional[List["Column"]] = None
        if group_type == "pivot":
            assert pivot_col is not None and isinstance(pivot_col, Column)
            self._pivot_col = pivot_col

            if pivot_values is not None:
                assert isinstance(pivot_values, list)
                # Store the pivot values as literal Columns.
                self._pivot_values = [F.lit(v) for v in pivot_values]

        # Grouping sets are only kept for the "grouping_sets" group type.
        self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None
        if group_type == "grouping_sets":
            assert grouping_sets is None or isinstance(grouping_sets, list)
            self._grouping_sets = grouping_sets

    def __repr__(self) -> str:
        """
        Return a string listing the grouping expressions, the ``name: type``
        pairs from the DataFrame's ``dtypes`` and the group type.
        """
        # the expressions are not resolved here,
        # so the string representation can be different from classic PySpark.
        grouping_str = ", ".join(str(e._expr) for e in self._grouping_cols)
        grouping_str = f"grouping expressions: [{grouping_str}]"

        value_str = ", ".join("%s: %s" % c for c in self._df.dtypes)

        if self._group_type == "groupby":
            type_str = "GroupBy"
        elif self._group_type == "rollup":
            type_str = "RollUp"
        elif self._group_type == "cube":
            type_str = "Cube"
        elif self._group_type == "grouping_sets":
            type_str = "GroupingSets"
        else:
            # __init__ restricts group_type, so the only remaining value is "pivot".
            type_str = "Pivot"

        return f"GroupedData[{grouping_str}, value: [{value_str}], type: {type_str}]"

    @overload
    def agg(self, *exprs: Column) -> "DataFrame":
        ...

    @overload
    def agg(self, __exprs: Dict[str, str]) -> "DataFrame":
        ...

    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
        # Build an Aggregate plan from either
        #   * a single dict mapping column name -> aggregate function name, or
        #   * one or more Column expressions.
        # An empty ``exprs`` or a non-Column positional argument fails an assert.
        from pyspark.sql.connect.dataframe import DataFrame

        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            # Convert the dict into key value pairs
            aggregate_cols = [F._invoke_function(exprs[0][k], F.col(k)) for k in exprs[0]]
        else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            aggregate_cols = cast(List[Column], list(exprs))

        return DataFrame(
            plan.Aggregate(
                child=self._df._plan,
                group_type=self._group_type,
                grouping_cols=self._grouping_cols,
                aggregate_cols=aggregate_cols,
                pivot_col=self._pivot_col,
                pivot_values=self._pivot_values,
                grouping_sets=self._grouping_sets,
            ),
            session=self._df._session,
        )

    agg.__doc__ = PySparkGroupedData.agg.__doc__

    def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
        """
        Apply the aggregate ``function`` to numeric columns of the grouped data.

        Parameters
        ----------
        function : str
            One of "min", "max", "avg" or "sum".
        cols : list of str
            Column names to aggregate. When empty, every field of the
            DataFrame schema whose data type is a ``NumericType`` is used.

        Returns
        -------
        DataFrame
            A DataFrame backed by an ``Aggregate`` plan.

        Raises
        ------
        PySparkTypeError
            With error class ``NOT_NUMERIC_COLUMNS`` if any name in ``cols``
            is rejected by ``verify_numeric_col_name``.
        """
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.connect.types import verify_numeric_col_name

        assert isinstance(function, str) and function in ["min", "max", "avg", "sum"]

        assert isinstance(cols, list) and all(isinstance(c, str) for c in cols)

        schema = self._df.schema

        if len(cols) > 0:
            invalid_cols = [c for c in cols if not verify_numeric_col_name(c, schema)]
            if len(invalid_cols) > 0:
                raise PySparkTypeError(
                    errorClass="NOT_NUMERIC_COLUMNS",
                    messageParameters={"invalid_columns": str(invalid_cols)},
                )
            agg_cols = cols
        else:
            # if no column is provided, then all numerical columns are selected
            agg_cols = [
                field.name for field in schema.fields if isinstance(field.dataType, NumericType)
            ]

        return DataFrame(
            plan.Aggregate(
                child=self._df._plan,
                group_type=self._group_type,
                grouping_cols=self._grouping_cols,
                aggregate_cols=[F._invoke_function(function, F.col(c)) for c in agg_cols],
                pivot_col=self._pivot_col,
                pivot_values=self._pivot_values,
                grouping_sets=self._grouping_sets,
            ),
            session=self._df._session,
        )

    # min / max / sum / avg are thin wrappers that forward their column names
    # to _numeric_agg with the matching function name.

    def min(self: "GroupedData", *cols: str) -> "DataFrame":
        return self._numeric_agg("min", list(cols))

    min.__doc__ = PySparkGroupedData.min.__doc__

    def max(self: "GroupedData", *cols: str) -> "DataFrame":
        return self._numeric_agg("max", list(cols))

    max.__doc__ = PySparkGroupedData.max.__doc__

    def sum(self: "GroupedData", *cols: str) -> "DataFrame":
        return self._numeric_agg("sum", list(cols))

    sum.__doc__ = PySparkGroupedData.sum.__doc__

    def avg(self: "GroupedData", *cols: str) -> "DataFrame":
        return self._numeric_agg("avg", list(cols))

    avg.__doc__ = PySparkGroupedData.avg.__doc__

    # ``mean`` is the same function object as ``avg``.
    mean = avg

    def count(self: "GroupedData") -> "DataFrame":
        # Aggregates count(lit(1)) and names the resulting column "count".
        return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))

    count.__doc__ = PySparkGroupedData.count.__doc__

    def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> "GroupedData":
        # Return a new GroupedData of type "pivot" over the same DataFrame and
        # grouping columns.
        #
        # Raises PySparkNotImplementedError (UNSUPPORTED_OPERATION) unless the
        # current group type is "groupby", and PySparkTypeError when
        # ``pivot_col`` is not a str, ``values`` is not a list, or an element of
        # ``values`` is not a bool/float/int/str.
        if self._group_type != "groupby":
            if self._group_type == "pivot":
                raise PySparkNotImplementedError(
                    errorClass="UNSUPPORTED_OPERATION",
                    messageParameters={"operation": "Repeated PIVOT operation"},
                )
            else:
                raise PySparkNotImplementedError(
                    errorClass="UNSUPPORTED_OPERATION",
                    messageParameters={"operation": f"PIVOT after {self._group_type.upper()}"},
                )

        if not isinstance(pivot_col, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": "pivot_col", "arg_type": type(pivot_col).__name__},
            )

        if values is not None:
            if not isinstance(values, list):
                raise PySparkTypeError(
                    errorClass="NOT_LIST",
                    messageParameters={"arg_name": "values", "arg_type": type(values).__name__},
                )
            for v in values:
                if not isinstance(v, (bool, float, int, str)):
                    raise PySparkTypeError(
                        errorClass="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
                        messageParameters={"arg_name": "value", "arg_type": type(v).__name__},
                    )

        return GroupedData(
            df=self._df,
            group_type="pivot",
            grouping_cols=self._grouping_cols,
            pivot_col=self._df[pivot_col],
            pivot_values=values,
        )

    pivot.__doc__ = PySparkGroupedData.pivot.__doc__

    def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
        # Legacy entry point: checks that ``udf`` is a grouped-map pandas UDF,
        # emits a UserWarning recommending applyInPandas, then delegates to
        # applyInPandas with the UDF's wrapped function and return type.
        # Raises PySparkTypeError (INVALID_UDF_EVAL_TYPE) for anything else.

        # Columns are special because hasattr always return True
        if (
            isinstance(udf, Column)
            or not hasattr(udf, "func")
            or (
                udf.evalType  # type: ignore[attr-defined]
                != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
            )
        ):
            raise PySparkTypeError(
                errorClass="INVALID_UDF_EVAL_TYPE",
                messageParameters={"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF"},
            )

        warnings.warn(
            "It is preferred to use 'applyInPandas' over this "
            "API. This API will be deprecated in the future releases. See SPARK-28264 for "
            "more details.",
            UserWarning,
        )

        return self.applyInPandas(udf.func, schema=udf.returnType)  # type: ignore[attr-defined]

    apply.__doc__ = PySparkGroupedData.apply.__doc__

    def applyInPandas(
        self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
    ) -> "DataFrame":
        # Wrap ``func`` in a UserDefinedFunction and build a GroupMap plan.
        # The eval type is inferred from the function's type hints when
        # possible; otherwise SQL_GROUPED_MAP_PANDAS_UDF is used.
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func

        # Try to infer the eval type from type hints
        eval_type = None
        try:
            eval_type = infer_group_pandas_eval_type_from_func(func)
        except Exception:
            warnings.warn("Cannot infer the eval type from type hints.", UserWarning)

        if eval_type is None:
            eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

        _validate_vectorized_udf(func, eval_type)
        if isinstance(schema, str):
            # A string schema is parsed as DDL through the session.
            schema = cast(StructType, self._df._session._parse_ddl(schema))
        udf_obj = UserDefinedFunction(
            func,
            returnType=schema,
            evalType=eval_type,
        )

        res = DataFrame(
            plan.GroupMap(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                cols=self._df.columns,
            ),
            session=self._df._session,
        )
        if isinstance(schema, StructType):
            # The output schema is already known, so store it on the result.
            res._cached_schema = schema
        return res

    applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__

    def applyInPandasWithState(
        self,
        func: "PandasGroupedMapFunctionWithState",
        outputStructType: Union[StructType, str],
        stateStructType: Union[StructType, str],
        outputMode: str,
        timeoutConf: str,
    ) -> "DataFrame":
        # Wrap ``func`` as a SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE UDF and build
        # an ApplyInPandasWithState plan. StructType schemas are passed to the
        # plan as their JSON string; string schemas are passed through as given.
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
        udf_obj = UserDefinedFunction(
            func,
            returnType=outputStructType,
            evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
        )

        output_schema: str = (
            outputStructType.json()
            if isinstance(outputStructType, StructType)
            else outputStructType
        )

        state_schema: str = (
            stateStructType.json() if isinstance(stateStructType, StructType) else stateStructType
        )

        return DataFrame(
            plan.ApplyInPandasWithState(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=output_schema,
                state_schema=state_schema,
                output_mode=outputMode,
                timeout_conf=timeoutConf,
                cols=self._df.columns,
            ),
            session=self._df._session,
        )

    applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

    def transformWithStateInPandas(
        self,
        statefulProcessor: StatefulProcessor,
        outputStructType: Union[StructType, str],
        outputMode: str,
        timeMode: str,
        initialState: Optional["GroupedData"] = None,
        eventTimeColumnName: str = "",
    ) -> "DataFrame":
        # Build a TransformWithStateInPandas plan around ``statefulProcessor``.
        # Without ``initialState`` the SQL_TRANSFORM_WITH_STATE_PANDAS_UDF eval
        # type is used; with it, the *_INIT_STATE_UDF variant is used and the
        # initial state's plan and grouping columns are attached to the plan.
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasUdfUtils,
        )

        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
        if initialState is None:
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
            )
            initial_state_plan = None
            initial_state_grouping_cols = None
        else:
            # The initial state must belong to the same session as this DataFrame.
            self._df._check_same_session(initialState._df)
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateWithInitStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
            )
            initial_state_plan = initialState._df._plan
            initial_state_grouping_cols = initialState._grouping_cols

        return DataFrame(
            plan.TransformWithStateInPandas(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=outputStructType,
                output_mode=outputMode,
                time_mode=timeMode,
                event_time_col_name=eventTimeColumnName,
                cols=self._df.columns,
                initial_state_plan=initial_state_plan,
                initial_state_grouping_cols=initial_state_grouping_cols,
            ),
            session=self._df._session,
        )

    transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__

    def transformWithState(
        self,
        statefulProcessor: StatefulProcessor,
        outputStructType: Union[StructType, str],
        outputMode: str,
        timeMode: str,
        initialState: Optional["GroupedData"] = None,
        eventTimeColumnName: str = "",
    ) -> "DataFrame":
        # Same flow as transformWithStateInPandas, but with the
        # SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_* eval types and a
        # TransformWithStateInPySpark plan node.
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasUdfUtils,
        )

        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
        if initialState is None:
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
            )
            initial_state_plan = None
            initial_state_grouping_cols = None
        else:
            # The initial state must belong to the same session as this DataFrame.
            self._df._check_same_session(initialState._df)
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateWithInitStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
            )
            initial_state_plan = initialState._df._plan
            initial_state_grouping_cols = initialState._grouping_cols

        return DataFrame(
            plan.TransformWithStateInPySpark(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=outputStructType,
                output_mode=outputMode,
                time_mode=timeMode,
                event_time_col_name=eventTimeColumnName,
                cols=self._df.columns,
                initial_state_plan=initial_state_plan,
                initial_state_grouping_cols=initial_state_grouping_cols,
            ),
            session=self._df._session,
        )

    transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__

    def applyInArrow(
        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
        # Arrow counterpart of applyInPandas: wrap ``func`` in a
        # UserDefinedFunction and build a GroupMap plan. The eval type is
        # inferred from the function's type hints when possible; otherwise
        # SQL_GROUPED_MAP_ARROW_UDF is used.
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        # Try to infer the eval type from type hints.
        # ``eval_type`` must be bound before the ``try``: if inference raises,
        # the handler only warns, and the fallback check below would otherwise
        # fail with UnboundLocalError instead of using the default eval type.
        eval_type = None
        try:
            eval_type = infer_group_arrow_eval_type_from_func(func)
        except Exception:
            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)

        if eval_type is None:
            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF

        _validate_vectorized_udf(func, eval_type)
        if isinstance(schema, str):
            # A string schema is parsed as DDL through the session.
            schema = cast(StructType, self._df._session._parse_ddl(schema))
        udf_obj = UserDefinedFunction(
            func,
            returnType=schema,
            evalType=eval_type,
        )

        res = DataFrame(
            plan.GroupMap(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                cols=self._df.columns,
            ),
            session=self._df._session,
        )
        if isinstance(schema, StructType):
            # The output schema is already known, so store it on the result.
            res._cached_schema = schema
        return res

    applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__

    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
        # Pair this grouped data with ``other`` for cogrouped map operations.
        return PandasCogroupedOps(self, other)

    cogroup.__doc__ = PySparkGroupedData.cogroup.__doc__


# Reuse the class docstring of classic PySpark's GroupedData (pyspark.sql.group).
GroupedData.__doc__ = PySparkGroupedData.__doc__


class PandasCogroupedOps:
    # Spark Connect counterpart of classic PySpark's ``PandasCogroupedOps``:
    # holds two GroupedData objects and turns a user function into a
    # CoGroupMap plan. The public method docstrings are copied from the classic
    # class right after each definition.

    def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
        # Both grouped DataFrames have to live in the same session.
        gd1._df._check_same_session(gd2._df)
        self._gd1, self._gd2 = gd1, gd2

    def applyInPandas(
        self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str]
    ) -> "DataFrame":
        # Cogrouped map with a pandas function (SQL_COGROUPED_MAP_PANDAS_UDF).
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        pandas_eval_type = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
        left, right = self._gd1, self._gd2
        session = left._df._session

        _validate_vectorized_udf(func, pandas_eval_type)

        # A string schema is parsed as DDL through the session.
        if isinstance(schema, str):
            schema = cast(StructType, session._parse_ddl(schema))

        cogroup_udf = UserDefinedFunction(func, returnType=schema, evalType=pandas_eval_type)
        cogroup_plan = plan.CoGroupMap(
            input=left._df._plan,
            input_grouping_cols=left._grouping_cols,
            other=right._df._plan,
            other_grouping_cols=right._grouping_cols,
            function=cogroup_udf,
        )

        result = DataFrame(cogroup_plan, session=session)
        # The output schema is already known, so store it on the result.
        if isinstance(schema, StructType):
            result._cached_schema = schema
        return result

    applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__

    def applyInArrow(
        self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
        # Cogrouped map with an Arrow function (SQL_COGROUPED_MAP_ARROW_UDF).
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        arrow_eval_type = PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF
        left, right = self._gd1, self._gd2
        session = left._df._session

        _validate_vectorized_udf(func, arrow_eval_type)

        # A string schema is parsed as DDL through the session.
        if isinstance(schema, str):
            schema = cast(StructType, session._parse_ddl(schema))

        cogroup_udf = UserDefinedFunction(func, returnType=schema, evalType=arrow_eval_type)
        cogroup_plan = plan.CoGroupMap(
            input=left._df._plan,
            input_grouping_cols=left._grouping_cols,
            other=right._df._plan,
            other_grouping_cols=right._grouping_cols,
            function=cogroup_udf,
        )

        result = DataFrame(cogroup_plan, session=session)
        # The output schema is already known, so store it on the result.
        if isinstance(schema, StructType):
            result._cached_schema = schema
        return result

    applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__


# Reuse the class docstring of classic PySpark's PandasCogroupedOps
# (pyspark.sql.pandas.group_ops).
PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__


def _test() -> None:
    # Run this module's doctests against a Spark Connect session and exit with
    # status -1 if any of them fails.
    import os
    import sys
    import doctest
    from pyspark.sql import SparkSession as PySparkSession
    import pyspark.sql.connect.group
    from pyspark.testing.utils import have_pandas, have_pyarrow

    module = pyspark.sql.connect.group
    globs = module.__dict__.copy()

    # Drop the examples attached to ``agg`` unless both pandas and pyarrow are available.
    if not (have_pandas and have_pyarrow):
        del module.GroupedData.agg.__doc__

    # The remote address comes from the environment, defaulting to a local 4-thread master.
    builder = PySparkSession.builder.appName("sql.connect.group tests")
    remote = os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")
    session = builder.remote(remote).getOrCreate()
    globs["spark"] = session

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF
    failed, _attempted = doctest.testmod(module, globs=globs, optionflags=flags)

    session.stop()

    if failed:
        sys.exit(-1)


# Run the doctests when this file is executed as a script.
if __name__ == "__main__":
    _test()
