[SPARK-49810][PYTHON] Extract the preparation of `DataFrame.sort` to parent class
### What changes were proposed in this pull request?
Extract the preparation of the sort columns for `df.sort` and `df.sortWithinPartitions` into the parent class, so the classic and Spark Connect implementations share it.
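For context, an illustrative sketch of the sort semantics the shared helper normalizes (the session and sample data here are made up for the example, not part of this PR; the behaviors follow the extracted code below):
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "name"])

df.sort("id")                                   # by column name or Column
df.sort(1)                                      # 1-based ordinal: first column, ascending
df.sort(-2)                                     # negative ordinal: second column, descending
df.sort("id", "name", ascending=[True, False])  # per-column sort direction
df.sort(0)                                      # raises PySparkIndexError (ZERO_INDEX)
df.sort("id", ascending="yes")                  # raises PySparkTypeError
```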
### Why are the changes needed?
Deduplicate code; the logic in the two classes is similar.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48282 from zhengruifeng/py_sql_sort.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 0dd66a9..9f9dedb 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -55,6 +55,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.column import Column
+from pyspark.sql.functions import builtin as F
from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.merge import MergeIntoWriter
@@ -873,7 +874,8 @@
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+ jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, _to_java_column))
return DataFrame(jdf, self.sparkSession)
def sort(
@@ -881,7 +883,8 @@
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
+ jdf = self._jdf.sort(self._jseq(_cols, _to_java_column))
return DataFrame(jdf, self.sparkSession)
orderBy = sort
@@ -928,51 +931,6 @@
_cols.append(c) # type: ignore[arg-type]
return self._jseq(_cols, _to_java_column)
- def _sort_cols(
- self,
- cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
- kwargs: Dict[str, Any],
- ) -> "JavaObject":
- """Return a JVM Seq of Columns that describes the sort order"""
- if not cols:
- raise PySparkValueError(
- errorClass="CANNOT_BE_EMPTY",
- messageParameters={"item": "column"},
- )
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
-
- jcols = []
- for c in cols:
- if isinstance(c, int) and not isinstance(c, bool):
- # ordinal is 1-based
- if c > 0:
- _c = self[c - 1]
- # negative ordinal means sort by desc
- elif c < 0:
- _c = self[-c - 1].desc()
- else:
- raise PySparkIndexError(
- errorClass="ZERO_INDEX",
- messageParameters={},
- )
- else:
- _c = c # type: ignore[assignment]
- jcols.append(_to_java_column(cast("ColumnOrName", _c)))
-
- ascending = kwargs.get("ascending", True)
- if isinstance(ascending, (bool, int)):
- if not ascending:
- jcols = [jc.desc() for jc in jcols]
- elif isinstance(ascending, list):
- jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)]
- else:
- raise PySparkTypeError(
- errorClass="NOT_BOOL_OR_LIST",
- messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
- )
- return self._jseq(jcols)
-
def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 146cfe1..136fe60 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -739,62 +739,16 @@
def tail(self, num: int) -> List[Row]:
return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
- def _sort_cols(
- self,
- cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
- kwargs: Dict[str, Any],
- ) -> List[Column]:
- """Return a JVM Seq of Columns that describes the sort order"""
- if cols is None:
- raise PySparkValueError(
- errorClass="CANNOT_BE_EMPTY",
- messageParameters={"item": "cols"},
- )
-
- if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
-
- _cols: List[Column] = []
- for c in cols:
- if isinstance(c, int) and not isinstance(c, bool):
- # ordinal is 1-based
- if c > 0:
- _c = self[c - 1]
- # negative ordinal means sort by desc
- elif c < 0:
- _c = self[-c - 1].desc()
- else:
- raise PySparkIndexError(
- errorClass="ZERO_INDEX",
- messageParameters={},
- )
- else:
- _c = c # type: ignore[assignment]
- _cols.append(F._to_col(cast("ColumnOrName", _c)))
-
- ascending = kwargs.get("ascending", True)
- if isinstance(ascending, (bool, int)):
- if not ascending:
- _cols = [c.desc() for c in _cols]
- elif isinstance(ascending, list):
- _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
- else:
- raise PySparkTypeError(
- errorClass="NOT_BOOL_OR_LIST",
- messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
- )
-
- return [F._sort_col(c) for c in _cols]
-
def sort(
self,
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
res = DataFrame(
plan.Sort(
self._plan,
- columns=self._sort_cols(cols, kwargs),
+ columns=[F._sort_col(c) for c in _cols],
is_global=True,
),
session=self._session,
@@ -809,10 +763,11 @@
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
+ _cols = self._preapare_cols_for_sort(F.col, cols, kwargs)
res = DataFrame(
plan.Sort(
self._plan,
- columns=self._sort_cols(cols, kwargs),
+ columns=[F._sort_col(c) for c in _cols],
is_global=False,
),
session=self._session,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1420345..5906108 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2891,6 +2891,62 @@
"""
...
+ def _preapare_cols_for_sort(
+ self,
+ _to_col: Callable[[str], Column],
+ cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+ kwargs: Dict[str, Any],
+ ) -> Sequence[Column]:
+ from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError
+
+ if not cols:
+ raise PySparkValueError(
+ errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "cols"}
+ )
+
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
+
+ _cols: List[Column] = []
+ for c in cols:
+ if isinstance(c, int) and not isinstance(c, bool):
+ # ordinal is 1-based
+ if c > 0:
+ _cols.append(self[c - 1])
+ # negative ordinal means sort by desc
+ elif c < 0:
+ _cols.append(self[-c - 1].desc())
+ else:
+ raise PySparkIndexError(
+ errorClass="ZERO_INDEX",
+ messageParameters={},
+ )
+ elif isinstance(c, Column):
+ _cols.append(c)
+ elif isinstance(c, str):
+ _cols.append(_to_col(c))
+ else:
+ raise PySparkTypeError(
+ errorClass="NOT_COLUMN_OR_INT_OR_STR",
+ messageParameters={
+ "arg_name": "col",
+ "arg_type": type(c).__name__,
+ },
+ )
+
+ ascending = kwargs.get("ascending", True)
+ if isinstance(ascending, (bool, int)):
+ if not ascending:
+ _cols = [c.desc() for c in _cols]
+ elif isinstance(ascending, list):
+ _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+ else:
+ raise PySparkTypeError(
+ errorClass="NOT_BOOL_OR_LIST",
+ messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
+ )
+ return _cols
+
orderBy = sort
@dispatch_df_method