#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import warnings
from typing import (
Dict,
List,
Sequence,
Union,
TYPE_CHECKING,
Optional,
overload,
cast,
)
from pyspark.util import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_vectorized_udf # type: ignore[attr-defined]
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
from pyspark.sql.types import NumericType, StructType
import pyspark.sql.connect.plan as plan
from pyspark.sql.column import Column
from pyspark.sql.connect.functions import builtin as F
from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
LiteralType,
PandasGroupedMapFunction,
GroupedMapPandasUserDefinedFunction,
PandasCogroupedMapFunction,
ArrowCogroupedMapFunction,
ArrowGroupedMapFunction,
PandasGroupedMapFunctionWithState,
)
from pyspark.sql.connect.dataframe import DataFrame
    from pyspark.sql.types import StructType


class GroupedData:
def __init__(
self,
df: "DataFrame",
group_type: str,
grouping_cols: Sequence[Column],
pivot_col: Optional[Column] = None,
pivot_values: Optional[Sequence["LiteralType"]] = None,
grouping_sets: Optional[Sequence[Sequence[Column]]] = None,
) -> None:
from pyspark.sql.connect.dataframe import DataFrame
assert isinstance(df, DataFrame)
self._df = df
assert isinstance(group_type, str) and group_type in [
"groupby",
"rollup",
"cube",
"pivot",
"grouping_sets",
]
self._group_type = group_type
assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
self._grouping_cols: List[Column] = grouping_cols
self._pivot_col: Optional["Column"] = None
self._pivot_values: Optional[List["Column"]] = None
if group_type == "pivot":
assert pivot_col is not None and isinstance(pivot_col, Column)
self._pivot_col = pivot_col
if pivot_values is not None:
assert isinstance(pivot_values, list)
self._pivot_values = [F.lit(v) for v in pivot_values]
self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None
if group_type == "grouping_sets":
assert grouping_sets is None or isinstance(grouping_sets, list)
self._grouping_sets = grouping_sets
def __repr__(self) -> str:
# the expressions are not resolved here,
# so the string representation can be different from classic PySpark.
grouping_str = ", ".join(str(e._expr) for e in self._grouping_cols)
grouping_str = f"grouping expressions: [{grouping_str}]"
value_str = ", ".join("%s: %s" % c for c in self._df.dtypes)
if self._group_type == "groupby":
type_str = "GroupBy"
elif self._group_type == "rollup":
type_str = "RollUp"
elif self._group_type == "cube":
type_str = "Cube"
elif self._group_type == "grouping_sets":
type_str = "GroupingSets"
else:
type_str = "Pivot"
return f"GroupedData[{grouping_str}, value: [{value_str}], type: {type_str}]"
@overload
def agg(self, *exprs: Column) -> "DataFrame":
...
@overload
def agg(self, __exprs: Dict[str, str]) -> "DataFrame":
...
def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
from pyspark.sql.connect.dataframe import DataFrame
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
# Convert the dict into key value pairs
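            # e.g. {"age": "max"} yields the single aggregate expression max(col("age")).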
aggregate_cols = [F._invoke_function(exprs[0][k], F.col(k)) for k in exprs[0]]
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
aggregate_cols = cast(List[Column], list(exprs))
return DataFrame(
plan.Aggregate(
child=self._df._plan,
group_type=self._group_type,
grouping_cols=self._grouping_cols,
aggregate_cols=aggregate_cols,
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
agg.__doc__ = PySparkGroupedData.agg.__doc__
def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
from pyspark.sql.connect.dataframe import DataFrame
assert isinstance(function, str) and function in ["min", "max", "avg", "sum"]
assert isinstance(cols, list) and all(isinstance(c, str) for c in cols)
schema = self._df.schema
numerical_cols: List[str] = [
field.name for field in schema.fields if isinstance(field.dataType, NumericType)
]
if len(cols) > 0:
invalid_cols = [c for c in cols if c not in numerical_cols]
if len(invalid_cols) > 0:
raise PySparkTypeError(
errorClass="NOT_NUMERIC_COLUMNS",
messageParameters={"invalid_columns": str(invalid_cols)},
)
agg_cols = cols
else:
# if no column is provided, then all numerical columns are selected
agg_cols = numerical_cols
return DataFrame(
plan.Aggregate(
child=self._df._plan,
group_type=self._group_type,
grouping_cols=self._grouping_cols,
aggregate_cols=[F._invoke_function(function, F.col(c)) for c in agg_cols],
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
def min(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("min", list(cols))
min.__doc__ = PySparkGroupedData.min.__doc__
def max(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("max", list(cols))
max.__doc__ = PySparkGroupedData.max.__doc__
def sum(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("sum", list(cols))
sum.__doc__ = PySparkGroupedData.sum.__doc__
def avg(self: "GroupedData", *cols: str) -> "DataFrame":
return self._numeric_agg("avg", list(cols))
avg.__doc__ = PySparkGroupedData.avg.__doc__
mean = avg
def count(self: "GroupedData") -> "DataFrame":
return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))
count.__doc__ = PySparkGroupedData.count.__doc__
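    # Usage sketch (hypothetical data; column names are illustrative only):
    #   df.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings")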
def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> "GroupedData":
if self._group_type != "groupby":
if self._group_type == "pivot":
raise PySparkNotImplementedError(
errorClass="UNSUPPORTED_OPERATION",
messageParameters={"operation": "Repeated PIVOT operation"},
)
else:
raise PySparkNotImplementedError(
errorClass="UNSUPPORTED_OPERATION",
messageParameters={"operation": f"PIVOT after {self._group_type.upper()}"},
)
if not isinstance(pivot_col, str):
raise PySparkTypeError(
errorClass="NOT_STR",
messageParameters={"arg_name": "pivot_col", "arg_type": type(pivot_col).__name__},
)
if values is not None:
if not isinstance(values, list):
raise PySparkTypeError(
errorClass="NOT_LIST",
messageParameters={"arg_name": "values", "arg_type": type(values).__name__},
)
for v in values:
if not isinstance(v, (bool, float, int, str)):
raise PySparkTypeError(
errorClass="NOT_BOOL_OR_FLOAT_OR_INT_OR_STR",
messageParameters={"arg_name": "value", "arg_type": type(v).__name__},
)
return GroupedData(
df=self._df,
group_type="pivot",
grouping_cols=self._grouping_cols,
pivot_col=self._df[pivot_col],
pivot_values=values,
)
pivot.__doc__ = PySparkGroupedData.pivot.__doc__
def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
        # Columns are special because hasattr always returns True
if (
isinstance(udf, Column)
or not hasattr(udf, "func")
or (
udf.evalType # type: ignore[attr-defined]
!= PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
)
):
raise PySparkTypeError(
errorClass="INVALID_UDF_EVAL_TYPE",
messageParameters={"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF"},
)
        warnings.warn(
            "It is preferred to use 'applyInPandas' over this "
            "API. This API will be deprecated in future releases. See SPARK-28264 for "
            "more details.",
            UserWarning,
        )
return self.applyInPandas(udf.func, schema=udf.returnType) # type: ignore[attr-defined]
apply.__doc__ = PySparkGroupedData.apply.__doc__
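    # Usage sketch (hypothetical; `pdf` is the pandas.DataFrame for one group):
    #   def subtract_mean(pdf):
    #       return pdf.assign(v=pdf.v - pdf.v.mean())
    #   df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double")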
def applyInPandas(
self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
_validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
)
res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
cols=self._df.columns,
),
session=self._df._session,
)
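        # When an explicit StructType is given, the result schema is already known on the
        # client, so cache it for later schema lookups.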
if isinstance(schema, StructType):
res._cached_schema = schema
return res
applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
def applyInPandasWithState(
self,
func: "PandasGroupedMapFunctionWithState",
outputStructType: Union[StructType, str],
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
_validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
udf_obj = UserDefinedFunction(
func,
returnType=outputStructType,
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
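        # Both the output and state schemas are sent to the server as JSON strings.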
output_schema: str = (
outputStructType.json()
if isinstance(outputStructType, StructType)
else outputStructType
)
state_schema: str = (
stateStructType.json() if isinstance(stateStructType, StructType) else stateStructType
)
return DataFrame(
plan.ApplyInPandasWithState(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
output_schema=output_schema,
state_schema=state_schema,
output_mode=outputMode,
timeout_conf=timeoutConf,
cols=self._df.columns,
),
session=self._df._session,
)
applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__
def transformWithStateInPandas(
self,
statefulProcessor: StatefulProcessor,
outputStructType: Union[StructType, str],
outputMode: str,
timeMode: str,
initialState: Optional["GroupedData"] = None,
eventTimeColumnName: str = "",
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasUdfUtils,
)
udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
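        # Without an initial state, the processor is wrapped in the plain transformWithState
        # UDF; with one, the variant that also consumes the initial state rows is used.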
if initialState is None:
udf_obj = UserDefinedFunction(
udf_util.transformWithStateUDF,
returnType=outputStructType,
evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)
initial_state_plan = None
initial_state_grouping_cols = None
else:
self._df._check_same_session(initialState._df)
udf_obj = UserDefinedFunction(
udf_util.transformWithStateWithInitStateUDF,
returnType=outputStructType,
evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
)
initial_state_plan = initialState._df._plan
initial_state_grouping_cols = initialState._grouping_cols
return DataFrame(
plan.TransformWithStateInPandas(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
output_schema=outputStructType,
output_mode=outputMode,
time_mode=timeMode,
event_time_col_name=eventTimeColumnName,
cols=self._df.columns,
initial_state_plan=initial_state_plan,
initial_state_grouping_cols=initial_state_grouping_cols,
),
session=self._df._session,
)
transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__
def transformWithState(
self,
statefulProcessor: StatefulProcessor,
outputStructType: Union[StructType, str],
outputMode: str,
timeMode: str,
initialState: Optional["GroupedData"] = None,
eventTimeColumnName: str = "",
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasUdfUtils,
)
udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
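        # Same wiring as transformWithStateInPandas, but using the Row-based eval types.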
if initialState is None:
udf_obj = UserDefinedFunction(
udf_util.transformWithStateUDF,
returnType=outputStructType,
evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
)
initial_state_plan = None
initial_state_grouping_cols = None
else:
self._df._check_same_session(initialState._df)
udf_obj = UserDefinedFunction(
udf_util.transformWithStateWithInitStateUDF,
returnType=outputStructType,
evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
)
initial_state_plan = initialState._df._plan
initial_state_grouping_cols = initialState._grouping_cols
return DataFrame(
plan.TransformWithStateInPySpark(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
output_schema=outputStructType,
output_mode=outputMode,
time_mode=timeMode,
event_time_col_name=eventTimeColumnName,
cols=self._df.columns,
initial_state_plan=initial_state_plan,
initial_state_grouping_cols=initial_state_grouping_cols,
),
session=self._df._session,
)
transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
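    # Usage sketch (hypothetical; each group arrives as a pyarrow.Table):
    #   import pyarrow as pa
    #   def count_rows(table: pa.Table) -> pa.Table:
    #       return pa.table({"n": [table.num_rows]})
    #   df.groupBy("id").applyInArrow(count_rows, schema="n long")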
def applyInArrow(
self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
        try:
            # Try to infer the eval type from the type hints of the function.
            eval_type = infer_group_arrow_eval_type_from_func(func)
        except Exception:
            warnings.warn("Cannot infer the eval type from type hints.", UserWarning)
            eval_type = None
        # Fall back to the regular grouped-map Arrow UDF eval type when the type hints
        # do not determine one.
        if eval_type is None:
            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
_validate_vectorized_udf(func, eval_type)
if isinstance(schema, str):
schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=eval_type,
)
res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
function=udf_obj,
cols=self._df.columns,
),
session=self._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res
applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__
def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
return PandasCogroupedOps(self, other)
    cogroup.__doc__ = PySparkGroupedData.cogroup.__doc__


GroupedData.__doc__ = PySparkGroupedData.__doc__


class PandasCogroupedOps:
def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
gd1._df._check_same_session(gd2._df)
self._gd1 = gd1
self._gd2 = gd2
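    # Usage sketch (hypothetical; `merge_func(left_pdf, right_pdf)` returns a pandas.DataFrame):
    #   df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    #       merge_func, schema="id long, v double"
    #   )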
def applyInPandas(
self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str]
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
_validate_vectorized_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
)
res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
other=self._gd2._df._plan,
other_grouping_cols=self._gd2._grouping_cols,
function=udf_obj,
),
session=self._gd1._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res
applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
def applyInArrow(
self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str]
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
_validate_vectorized_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
if isinstance(schema, str):
schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
)
res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
other=self._gd2._df._plan,
other_grouping_cols=self._gd2._grouping_cols,
function=udf_obj,
),
session=self._gd1._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res
    applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__


PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.group
from pyspark.testing.utils import have_pandas, have_pyarrow
globs = pyspark.sql.connect.group.__dict__.copy()
if not have_pandas or not have_pyarrow:
del pyspark.sql.connect.group.GroupedData.agg.__doc__
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.group tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.group,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
globs["spark"].stop()
if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
_test()