[SPARK-49810][PYTHON] Extract the preparation of `DataFrame.sort` to parent class

### What changes were proposed in this pull request?
Extract the preparation of `df.sort` arguments into the parent class, so that the classic and Spark Connect implementations share it.

### Why are the changes needed?
Deduplicate code: the logic in the two classes (classic and Spark Connect) is nearly identical.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48282 from zhengruifeng/py_sql_sort.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 0dd66a9..9f9dedb 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -55,6 +55,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.column import Column
+from pyspark.sql.functions import builtin as F
 from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
@@ -873,7 +874,8 @@
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+        _cols = self._prepare_cols_for_sort(F.col, cols, kwargs)
+        jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, _to_java_column))
         return DataFrame(jdf, self.sparkSession)
 
     def sort(
@@ -881,7 +883,8 @@
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
+        _cols = self._prepare_cols_for_sort(F.col, cols, kwargs)
+        jdf = self._jdf.sort(self._jseq(_cols, _to_java_column))
         return DataFrame(jdf, self.sparkSession)
 
     orderBy = sort
@@ -928,51 +931,6 @@
                 _cols.append(c)  # type: ignore[arg-type]
         return self._jseq(_cols, _to_java_column)
 
-    def _sort_cols(
-        self,
-        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
-        kwargs: Dict[str, Any],
-    ) -> "JavaObject":
-        """Return a JVM Seq of Columns that describes the sort order"""
-        if not cols:
-            raise PySparkValueError(
-                errorClass="CANNOT_BE_EMPTY",
-                messageParameters={"item": "column"},
-            )
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-
-        jcols = []
-        for c in cols:
-            if isinstance(c, int) and not isinstance(c, bool):
-                # ordinal is 1-based
-                if c > 0:
-                    _c = self[c - 1]
-                # negative ordinal means sort by desc
-                elif c < 0:
-                    _c = self[-c - 1].desc()
-                else:
-                    raise PySparkIndexError(
-                        errorClass="ZERO_INDEX",
-                        messageParameters={},
-                    )
-            else:
-                _c = c  # type: ignore[assignment]
-            jcols.append(_to_java_column(cast("ColumnOrName", _c)))
-
-        ascending = kwargs.get("ascending", True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                jcols = [jc.desc() for jc in jcols]
-        elif isinstance(ascending, list):
-            jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)]
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_LIST",
-                messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
-            )
-        return self._jseq(jcols)
-
     def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]  # type: ignore[assignment]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 146cfe1..136fe60 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -739,62 +739,16 @@
     def tail(self, num: int) -> List[Row]:
         return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
 
-    def _sort_cols(
-        self,
-        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
-        kwargs: Dict[str, Any],
-    ) -> List[Column]:
-        """Return a JVM Seq of Columns that describes the sort order"""
-        if cols is None:
-            raise PySparkValueError(
-                errorClass="CANNOT_BE_EMPTY",
-                messageParameters={"item": "cols"},
-            )
-
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-
-        _cols: List[Column] = []
-        for c in cols:
-            if isinstance(c, int) and not isinstance(c, bool):
-                # ordinal is 1-based
-                if c > 0:
-                    _c = self[c - 1]
-                # negative ordinal means sort by desc
-                elif c < 0:
-                    _c = self[-c - 1].desc()
-                else:
-                    raise PySparkIndexError(
-                        errorClass="ZERO_INDEX",
-                        messageParameters={},
-                    )
-            else:
-                _c = c  # type: ignore[assignment]
-            _cols.append(F._to_col(cast("ColumnOrName", _c)))
-
-        ascending = kwargs.get("ascending", True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                _cols = [c.desc() for c in _cols]
-        elif isinstance(ascending, list):
-            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_LIST",
-                messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
-            )
-
-        return [F._sort_col(c) for c in _cols]
-
     def sort(
         self,
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
+        _cols = self._prepare_cols_for_sort(F.col, cols, kwargs)
         res = DataFrame(
             plan.Sort(
                 self._plan,
-                columns=self._sort_cols(cols, kwargs),
+                columns=[F._sort_col(c) for c in _cols],
                 is_global=True,
             ),
             session=self._session,
@@ -809,10 +763,11 @@
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
+        _cols = self._prepare_cols_for_sort(F.col, cols, kwargs)
         res = DataFrame(
             plan.Sort(
                 self._plan,
-                columns=self._sort_cols(cols, kwargs),
+                columns=[F._sort_col(c) for c in _cols],
                 is_global=False,
             ),
             session=self._session,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1420345..5906108 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2891,6 +2891,62 @@
         """
         ...
 
+    def _prepare_cols_for_sort(
+        self,
+        _to_col: Callable[[str], Column],
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
+    ) -> Sequence[Column]:
+        from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError
+
+        if not cols:
+            raise PySparkValueError(
+                errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "cols"}
+            )
+
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # ordinal is 1-based
+                if c > 0:
+                    _cols.append(self[c - 1])
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _cols.append(self[-c - 1].desc())
+                else:
+                    raise PySparkIndexError(
+                        errorClass="ZERO_INDEX",
+                        messageParameters={},
+                    )
+            elif isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(_to_col(c))
+            else:
+                raise PySparkTypeError(
+                    errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                    messageParameters={
+                        "arg_name": "col",
+                        "arg_type": type(c).__name__,
+                    },
+                )
+
+        ascending = kwargs.get("ascending", True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                _cols = [c.desc() for c in _cols]
+        elif isinstance(ascending, list):
+            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+        else:
+            raise PySparkTypeError(
+                errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
+            )
+        return _cols
+
     orderBy = sort
 
     @dispatch_df_method