[SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs

### What changes were proposed in this pull request?
Limit Arrow batch sizes in window agg UDFs

### Why are the changes needed?
to avoid OOM in the JVM side, by batching the JVM->Python Arrow Batches

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #52608 from zhengruifeng/limit_win_agg.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 323aea3..89b3066 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1143,7 +1143,8 @@
         return "GroupArrowUDFSerializer"
 
 
-class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
     def __init__(
         self,
         timezone,
@@ -1183,7 +1184,7 @@
                 )
 
     def __repr__(self):
-        return "AggArrowUDFSerializer"
+        return "ArrowStreamAggArrowUDFSerializer"
 
 
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 1d30159..b3ed4c0 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -758,6 +758,52 @@
         )
         self.assertEqual(expected.collect(), result.collect())
 
+    def test_arrow_batch_slicing(self):
+        import pyarrow as pa
+
+        df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+        w1 = Window.partitionBy("key").orderBy("v")
+        w2 = (
+            Window.partitionBy("key")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+            .orderBy("v")
+        )
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_sum(v):
+            return pa.compute.sum(v)
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_sum_unbounded(v):
+            assert len(v) == 1000 / 2, len(v)
+            return pa.compute.sum(v)
+
+        expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+        expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+        for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result1 = (
+                        df.select("*", arrow_sum("v").over(w1).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected1, result1)
+
+                    result2 = (
+                        df.select("*", arrow_sum_unbounded("v").over(w2).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected2, result2)
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 2f534b8..fbc2b32 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -20,19 +20,8 @@
 from decimal import Decimal
 
 from pyspark.errors import AnalysisException, PythonException
-from pyspark.sql.functions import (
-    array,
-    explode,
-    col,
-    lit,
-    mean,
-    min,
-    max,
-    rank,
-    udf,
-    pandas_udf,
-    PandasUDFType,
-)
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 from pyspark.sql.window import Window
 from pyspark.sql.types import (
     DecimalType,
@@ -64,10 +53,10 @@
         return (
             self.spark.range(10)
             .toDF("id")
-            .withColumn("vs", array([lit(i * 1.0) + col("id") for i in range(20, 30)]))
-            .withColumn("v", explode(col("vs")))
+            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
+            .withColumn("v", sf.explode(sf.col("vs")))
             .drop("vs")
-            .withColumn("w", lit(1.0))
+            .withColumn("w", sf.lit(1.0))
         )
 
     @property
@@ -172,10 +161,10 @@
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("mean_v", mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("mean_v", mean(df["v"]).over(w))
+        expected1 = df.withColumn("mean_v", sf.mean(df["v"]).over(w))
 
         result2 = df.select(mean_udf(df["v"]).over(w))
-        expected2 = df.select(mean(df["v"]).over(w))
+        expected2 = df.select(sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -191,9 +180,9 @@
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(df["v"]).over(w))
-            .withColumn("max_v", max(df["v"]).over(w))
-            .withColumn("min_w", min(df["w"]).over(w))
+            df.withColumn("mean_v", sf.mean(df["v"]).over(w))
+            .withColumn("max_v", sf.max(df["v"]).over(w))
+            .withColumn("min_w", sf.min(df["w"]).over(w))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -203,7 +192,7 @@
         w = self.unbounded_window
 
         result1 = df.withColumn("v", self.pandas_agg_mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v", mean(df["v"]).over(w))
+        expected1 = df.withColumn("v", sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
@@ -213,7 +202,7 @@
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v", mean_udf(df["v"] * 2).over(w) + 1)
-        expected1 = df.withColumn("v", mean(df["v"] * 2).over(w) + 1)
+        expected1 = df.withColumn("v", sf.mean(df["v"] * 2).over(w) + 1)
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
@@ -226,10 +215,10 @@
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v2", plus_one(mean_udf(plus_one(df["v"])).over(w)))
-        expected1 = df.withColumn("v2", plus_one(mean(plus_one(df["v"])).over(w)))
+        expected1 = df.withColumn("v2", plus_one(sf.mean(plus_one(df["v"])).over(w)))
 
         result2 = df.withColumn("v2", time_two(mean_udf(time_two(df["v"])).over(w)))
-        expected2 = df.withColumn("v2", time_two(mean(time_two(df["v"])).over(w)))
+        expected2 = df.withColumn("v2", time_two(sf.mean(time_two(df["v"])).over(w)))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -240,10 +229,10 @@
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v2", mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v2", mean(df["v"]).over(w))
+        expected1 = df.withColumn("v2", sf.mean(df["v"]).over(w))
 
         result2 = df.select(mean_udf(df["v"]).over(w))
-        expected2 = df.select(mean(df["v"]).over(w))
+        expected2 = df.select(sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -256,26 +245,28 @@
         min_udf = self.pandas_agg_min_udf
 
         result1 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - min_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v_diff", max(df["v"]).over(w) - min(df["v"]).over(w))
+        expected1 = df.withColumn("v_diff", sf.max(df["v"]).over(w) - sf.min(df["v"]).over(w))
 
         # Test mixing sql window function and window udf in the same expression
-        result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - min(df["v"]).over(w))
+        result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - sf.min(df["v"]).over(w))
         expected2 = expected1
 
         # Test chaining sql aggregate function and udf
         result3 = (
             df.withColumn("max_v", max_udf(df["v"]).over(w))
-            .withColumn("min_v", min(df["v"]).over(w))
-            .withColumn("v_diff", col("max_v") - col("min_v"))
+            .withColumn("min_v", sf.min(df["v"]).over(w))
+            .withColumn("v_diff", sf.col("max_v") - sf.col("min_v"))
             .drop("max_v", "min_v")
         )
         expected3 = expected1
 
         # Test mixing sql window function and udf
         result4 = df.withColumn("max_v", max_udf(df["v"]).over(w)).withColumn(
-            "rank", rank().over(ow)
+            "rank", sf.rank().over(ow)
         )
-        expected4 = df.withColumn("max_v", max(df["v"]).over(w)).withColumn("rank", rank().over(ow))
+        expected4 = df.withColumn("max_v", sf.max(df["v"]).over(w)).withColumn(
+            "rank", sf.rank().over(ow)
+        )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -303,8 +294,6 @@
             df.withColumn("v2", foo_udf(df["v"]).over(w)).schema
 
     def test_bounded_simple(self):
-        from pyspark.sql.functions import mean, max, min, count
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.shrinking_range_window
@@ -323,17 +312,15 @@
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(plus_one(df["v"])).over(w1))
-            .withColumn("count_v", count(df["v"]).over(w2))
-            .withColumn("max_v", max(df["v"]).over(w2))
-            .withColumn("min_v", min(df["v"]).over(w1))
+            df.withColumn("mean_v", sf.mean(plus_one(df["v"])).over(w1))
+            .withColumn("count_v", sf.count(df["v"]).over(w2))
+            .withColumn("max_v", sf.max(df["v"]).over(w2))
+            .withColumn("min_v", sf.min(df["v"]).over(w1))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_growing_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.growing_row_window
         w2 = self.growing_range_window
@@ -344,15 +331,13 @@
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_sliding_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.sliding_range_window
@@ -363,15 +348,13 @@
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_shrinking_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.shrinking_row_window
         w2 = self.shrinking_range_window
@@ -382,15 +365,13 @@
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_bounded_mixed(self):
-        from pyspark.sql.functions import mean, max
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.unbounded_window
@@ -405,9 +386,9 @@
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(df["v"]).over(w1))
-            .withColumn("max_v", max(df["v"]).over(w2))
-            .withColumn("mean_unbounded_v", mean(df["v"]).over(w1))
+            df.withColumn("mean_v", sf.mean(df["v"]).over(w1))
+            .withColumn("max_v", sf.max(df["v"]).over(w2))
+            .withColumn("mean_unbounded_v", sf.mean(df["v"]).over(w1))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -425,7 +406,7 @@
                 ]
             ):
                 with self.subTest(bound=bound, query_no=i):
-                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+                    assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
 
         with self.tempView("v"):
             df.createOrReplaceTempView("v")
@@ -521,7 +502,7 @@
                 ]
             ):
                 with self.subTest(bound=bound, query_no=i):
-                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+                    assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
 
         with self.tempView("v"):
             df.createOrReplaceTempView("v")
@@ -608,6 +589,50 @@
                 result = df.select(mean_udf(df["v"]).over(w)).first()[0]
                 assert result == 123
 
+    def test_arrow_batch_slicing(self):
+        df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+        w1 = Window.partitionBy("key").orderBy("v")
+        w2 = (
+            Window.partitionBy("key")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+            .orderBy("v")
+        )
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_sum(v):
+            return v.sum()
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_sum_unbounded(v):
+            assert len(v) == 1000 / 2, len(v)
+            return v.sum()
+
+        expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+        expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+        for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result1 = (
+                        df.select("*", pandas_sum("v").over(w1).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected1, result1)
+
+                    result2 = (
+                        df.select("*", pandas_sum_unbounded("v").over(w2).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected2, result2)
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c3ba8bc..d94ba8f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,7 +52,6 @@
 from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
-    AggArrowUDFSerializer,
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupPandasUDFSerializer,
@@ -67,6 +66,7 @@
     TransformWithStateInPySparkRowSerializer,
     TransformWithStateInPySparkRowInitStateSerializer,
     ArrowStreamArrowUDFSerializer,
+    ArrowStreamAggArrowUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
     ArrowStreamArrowUDTFSerializer,
@@ -2612,11 +2612,15 @@
             or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
         ):
             ser = GroupArrowUDFSerializer(_assign_cols_by_name)
-        elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
-            ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
+        elif eval_type in (
+            PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+            PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+        ):
+            ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+            PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         ):
             ser = GroupPandasUDFSerializer(
                 timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
@@ -2703,7 +2707,6 @@
         elif eval_type in (
             PythonEvalType.SQL_SCALAR_ARROW_UDF,
             PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
-            PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
         ):
             # Arrow cast and safe check are always enabled
             ser = ArrowStreamArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 1643a8d..82c03b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -368,7 +368,7 @@
         }
       }
 
-      val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
+      val runner = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         evalType,
         argMetas,
@@ -378,7 +378,9 @@
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
-        profiler).compute(pythonInput, context.partitionId(), context)
+        profiler) with GroupedPythonArrowInput
+
+      val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow