[SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs
### What changes were proposed in this pull request?
Limit Arrow batch sizes in window aggregation UDFs: route `SQL_WINDOW_AGG_ARROW_UDF` through the renamed `ArrowStreamAggArrowUDFSerializer` and `SQL_WINDOW_AGG_PANDAS_UDF` through `GroupPandasUDFSerializer`, and mix `GroupedPythonArrowInput` into the runner in `ArrowWindowPythonEvaluatorFactory`, so the JVM->Python Arrow batches respect `spark.sql.execution.arrow.maxRecordsPerBatch` and `spark.sql.execution.arrow.maxBytesPerBatch`.
### Why are the changes needed?
To avoid OOM on the JVM side, by slicing the JVM->Python Arrow batches according to the configured limits instead of sending the data as a single large batch.
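For illustration only (not part of the patch), a minimal PySpark sketch mirroring the new tests: it sets the two batch-size limits and runs a grouped-agg pandas UDF over a window. Config and column names are taken from the tests below; the values are illustrative.

```python
from pyspark.sql import SparkSession, functions as sf
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# The limits exercised by the new tests; values are illustrative.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", str(2**31 - 1))

df = spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
w = Window.partitionBy("key").orderBy("v")


@pandas_udf("long", PandasUDFType.GROUPED_AGG)
def pandas_sum(v):
    # The UDF still sees the full window frame; only the JVM->Python
    # transfer is sliced into size-limited Arrow batches (see the
    # unbounded-window assert in the new tests).
    return v.sum()


df.select("*", pandas_sum("v").over(w).alias("res")).show()
```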
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests: `test_arrow_batch_slicing` added to both `test_arrow_udf_window.py` and `test_pandas_udf_window.py`, covering combinations of `maxRecordsPerBatch` and `maxBytesPerBatch`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52608 from zhengruifeng/limit_win_agg.
Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 323aea3..89b3066 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1143,7 +1143,8 @@
return "GroupArrowUDFSerializer"
-class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
def __init__(
self,
timezone,
@@ -1183,7 +1184,7 @@
)
def __repr__(self):
- return "AggArrowUDFSerializer"
+ return "ArrowStreamAggArrowUDFSerializer"
class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 1d30159..b3ed4c0 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -758,6 +758,52 @@
)
self.assertEqual(expected.collect(), result.collect())
+ def test_arrow_batch_slicing(self):
+ import pyarrow as pa
+
+ df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+ w1 = Window.partitionBy("key").orderBy("v")
+ w2 = (
+ Window.partitionBy("key")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_sum(v):
+ return pa.compute.sum(v)
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_sum_unbounded(v):
+ assert len(v) == 1000 / 2, len(v)
+ return pa.compute.sum(v)
+
+ expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+ expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+ for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result1 = (
+ df.select("*", arrow_sum("v").over(w1).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected1, result1)
+
+ result2 = (
+ df.select("*", arrow_sum_unbounded("v").over(w2).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected2, result2)
+
class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 2f534b8..fbc2b32 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -20,19 +20,8 @@
from decimal import Decimal
from pyspark.errors import AnalysisException, PythonException
-from pyspark.sql.functions import (
- array,
- explode,
- col,
- lit,
- mean,
- min,
- max,
- rank,
- udf,
- pandas_udf,
- PandasUDFType,
-)
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.window import Window
from pyspark.sql.types import (
DecimalType,
@@ -64,10 +53,10 @@
return (
self.spark.range(10)
.toDF("id")
- .withColumn("vs", array([lit(i * 1.0) + col("id") for i in range(20, 30)]))
- .withColumn("v", explode(col("vs")))
+ .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
+ .withColumn("v", sf.explode(sf.col("vs")))
.drop("vs")
- .withColumn("w", lit(1.0))
+ .withColumn("w", sf.lit(1.0))
)
@property
@@ -172,10 +161,10 @@
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("mean_v", mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("mean_v", mean(df["v"]).over(w))
+ expected1 = df.withColumn("mean_v", sf.mean(df["v"]).over(w))
result2 = df.select(mean_udf(df["v"]).over(w))
- expected2 = df.select(mean(df["v"]).over(w))
+ expected2 = df.select(sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -191,9 +180,9 @@
)
expected1 = (
- df.withColumn("mean_v", mean(df["v"]).over(w))
- .withColumn("max_v", max(df["v"]).over(w))
- .withColumn("min_w", min(df["w"]).over(w))
+ df.withColumn("mean_v", sf.mean(df["v"]).over(w))
+ .withColumn("max_v", sf.max(df["v"]).over(w))
+ .withColumn("min_w", sf.min(df["w"]).over(w))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -203,7 +192,7 @@
w = self.unbounded_window
result1 = df.withColumn("v", self.pandas_agg_mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("v", mean(df["v"]).over(w))
+ expected1 = df.withColumn("v", sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -213,7 +202,7 @@
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v", mean_udf(df["v"] * 2).over(w) + 1)
- expected1 = df.withColumn("v", mean(df["v"] * 2).over(w) + 1)
+ expected1 = df.withColumn("v", sf.mean(df["v"] * 2).over(w) + 1)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -226,10 +215,10 @@
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v2", plus_one(mean_udf(plus_one(df["v"])).over(w)))
- expected1 = df.withColumn("v2", plus_one(mean(plus_one(df["v"])).over(w)))
+ expected1 = df.withColumn("v2", plus_one(sf.mean(plus_one(df["v"])).over(w)))
result2 = df.withColumn("v2", time_two(mean_udf(time_two(df["v"])).over(w)))
- expected2 = df.withColumn("v2", time_two(mean(time_two(df["v"])).over(w)))
+ expected2 = df.withColumn("v2", time_two(sf.mean(time_two(df["v"])).over(w)))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -240,10 +229,10 @@
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v2", mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("v2", mean(df["v"]).over(w))
+ expected1 = df.withColumn("v2", sf.mean(df["v"]).over(w))
result2 = df.select(mean_udf(df["v"]).over(w))
- expected2 = df.select(mean(df["v"]).over(w))
+ expected2 = df.select(sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -256,26 +245,28 @@
min_udf = self.pandas_agg_min_udf
result1 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - min_udf(df["v"]).over(w))
- expected1 = df.withColumn("v_diff", max(df["v"]).over(w) - min(df["v"]).over(w))
+ expected1 = df.withColumn("v_diff", sf.max(df["v"]).over(w) - sf.min(df["v"]).over(w))
# Test mixing sql window function and window udf in the same expression
- result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - min(df["v"]).over(w))
+ result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - sf.min(df["v"]).over(w))
expected2 = expected1
# Test chaining sql aggregate function and udf
result3 = (
df.withColumn("max_v", max_udf(df["v"]).over(w))
- .withColumn("min_v", min(df["v"]).over(w))
- .withColumn("v_diff", col("max_v") - col("min_v"))
+ .withColumn("min_v", sf.min(df["v"]).over(w))
+ .withColumn("v_diff", sf.col("max_v") - sf.col("min_v"))
.drop("max_v", "min_v")
)
expected3 = expected1
# Test mixing sql window function and udf
result4 = df.withColumn("max_v", max_udf(df["v"]).over(w)).withColumn(
- "rank", rank().over(ow)
+ "rank", sf.rank().over(ow)
)
- expected4 = df.withColumn("max_v", max(df["v"]).over(w)).withColumn("rank", rank().over(ow))
+ expected4 = df.withColumn("max_v", sf.max(df["v"]).over(w)).withColumn(
+ "rank", sf.rank().over(ow)
+ )
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -303,8 +294,6 @@
df.withColumn("v2", foo_udf(df["v"]).over(w)).schema
def test_bounded_simple(self):
- from pyspark.sql.functions import mean, max, min, count
-
df = self.data
w1 = self.sliding_row_window
w2 = self.shrinking_range_window
@@ -323,17 +312,15 @@
)
expected1 = (
- df.withColumn("mean_v", mean(plus_one(df["v"])).over(w1))
- .withColumn("count_v", count(df["v"]).over(w2))
- .withColumn("max_v", max(df["v"]).over(w2))
- .withColumn("min_v", min(df["v"]).over(w1))
+ df.withColumn("mean_v", sf.mean(plus_one(df["v"])).over(w1))
+ .withColumn("count_v", sf.count(df["v"]).over(w2))
+ .withColumn("max_v", sf.max(df["v"]).over(w2))
+ .withColumn("min_v", sf.min(df["v"]).over(w1))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_growing_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.growing_row_window
w2 = self.growing_range_window
@@ -344,15 +331,13 @@
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_sliding_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.sliding_row_window
w2 = self.sliding_range_window
@@ -363,15 +348,13 @@
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_shrinking_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.shrinking_row_window
w2 = self.shrinking_range_window
@@ -382,15 +365,13 @@
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_bounded_mixed(self):
- from pyspark.sql.functions import mean, max
-
df = self.data
w1 = self.sliding_row_window
w2 = self.unbounded_window
@@ -405,9 +386,9 @@
)
expected1 = (
- df.withColumn("mean_v", mean(df["v"]).over(w1))
- .withColumn("max_v", max(df["v"]).over(w2))
- .withColumn("mean_unbounded_v", mean(df["v"]).over(w1))
+ df.withColumn("mean_v", sf.mean(df["v"]).over(w1))
+ .withColumn("max_v", sf.max(df["v"]).over(w2))
+ .withColumn("mean_unbounded_v", sf.mean(df["v"]).over(w1))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -425,7 +406,7 @@
]
):
with self.subTest(bound=bound, query_no=i):
- assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+ assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
with self.tempView("v"):
df.createOrReplaceTempView("v")
@@ -521,7 +502,7 @@
]
):
with self.subTest(bound=bound, query_no=i):
- assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+ assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
with self.tempView("v"):
df.createOrReplaceTempView("v")
@@ -608,6 +589,50 @@
result = df.select(mean_udf(df["v"]).over(w)).first()[0]
assert result == 123
+ def test_arrow_batch_slicing(self):
+ df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+ w1 = Window.partitionBy("key").orderBy("v")
+ w2 = (
+ Window.partitionBy("key")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_sum(v):
+ return v.sum()
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_sum_unbounded(v):
+ assert len(v) == 1000 / 2, len(v)
+ return v.sum()
+
+ expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+ expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+ for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result1 = (
+ df.select("*", pandas_sum("v").over(w1).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected1, result1)
+
+ result2 = (
+ df.select("*", pandas_sum_unbounded("v").over(w2).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected2, result2)
+
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c3ba8bc..d94ba8f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,7 +52,6 @@
from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
- AggArrowUDFSerializer,
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
GroupPandasUDFSerializer,
@@ -67,6 +66,7 @@
TransformWithStateInPySparkRowSerializer,
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamArrowUDFSerializer,
+ ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
ArrowStreamArrowUDTFSerializer,
@@ -2612,11 +2612,15 @@
or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
):
ser = GroupArrowUDFSerializer(_assign_cols_by_name)
- elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
- ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
+ elif eval_type in (
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+ ):
+ ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
elif eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
@@ -2703,7 +2707,6 @@
elif eval_type in (
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
- PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
):
# Arrow cast and safe check are always enabled
ser = ArrowStreamArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 1643a8d..82c03b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -368,7 +368,7 @@
}
}
- val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
+ val runner = new ArrowPythonWithNamedArgumentRunner(
pyFuncs,
evalType,
argMetas,
@@ -378,7 +378,9 @@
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- profiler).compute(pythonInput, context.partitionId(), context)
+ profiler) with GroupedPythonArrowInput
+
+ val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow