#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import logging
from pyspark.sql.functions import arrow_udf, ArrowUDFType
from pyspark.util import PythonEvalType, is_remote_only
from pyspark.sql import Row
from pyspark.sql.types import (
ArrayType,
YearMonthIntervalType,
StructType,
StructField,
VariantType,
VariantVal,
)
from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.utils import (
have_numpy,
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
assertDataFrameEqual,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
from typing import Iterator, Tuple
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class GroupedAggArrowUDFTestsMixin:
@property
def data(self):
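        # Test fixture: for each id in 0..9 there are ten rows with v = id + 20.0 .. id + 29.0
        # and a constant weight column w = 1.0.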
return (
self.spark.range(10)
.toDF("id")
.withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
.withColumn("v", sf.explode(sf.col("vs")))
.drop("vs")
.withColumn("w", sf.lit(1.0))
)
@property
def python_plus_one(self):
@sf.udf("double")
def plus_one(v):
assert isinstance(v, (int, float))
return float(v + 1)
return plus_one
@property
def arrow_scalar_plus_two(self):
import pyarrow as pa
@arrow_udf("double", ArrowUDFType.SCALAR)
def plus_two(v):
assert isinstance(v, pa.Array)
return pa.compute.add(v, 2).cast(pa.float64())
return plus_two
@property
def arrow_agg_mean_udf(self):
import pyarrow as pa
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def avg(v):
assert isinstance(v, pa.Array)
return pa.compute.mean(v.cast(pa.float64()))
return avg
@property
def arrow_agg_mean_arr_udf(self):
import pyarrow as pa
@arrow_udf("array<double>", ArrowUDFType.GROUPED_AGG)
def avg(v):
assert isinstance(v, pa.Array)
assert isinstance(v, pa.ListArray)
return [pa.compute.mean(v.flatten()).cast(pa.float64())]
return avg
@property
def arrow_agg_sum_udf(self):
import pyarrow as pa
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def sum(v):
assert isinstance(v, pa.Array)
return pa.compute.sum(v).cast(pa.float64())
return sum
@property
def arrow_agg_weighted_mean_udf(self):
import pyarrow as pa
import numpy as np
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def weighted_mean(v, w):
assert isinstance(v, pa.Array)
assert isinstance(w, pa.Array)
return float(np.average(v, weights=w))
return weighted_mean
def test_manual(self):
df = self.data
sum_udf = self.arrow_agg_sum_udf
mean_udf = self.arrow_agg_mean_udf
mean_arr_udf = self.arrow_agg_mean_arr_udf
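        # For id = 0, v takes the values 20.0..29.0, so sum(v) = 245.0 and avg(v) = 24.5;
        # each later id shifts v up by one, so avg grows by 1.0 and sum by 10.0 per group.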
result = (
df.groupby("id")
.agg(sum_udf(df.v), mean_udf(df.v), mean_arr_udf(sf.array(df.v)))
.sort("id")
)
expected = self.spark.createDataFrame(
[
[0, 245.0, 24.5, [24.5]],
[1, 255.0, 25.5, [25.5]],
[2, 265.0, 26.5, [26.5]],
[3, 275.0, 27.5, [27.5]],
[4, 285.0, 28.5, [28.5]],
[5, 295.0, 29.5, [29.5]],
[6, 305.0, 30.5, [30.5]],
[7, 315.0, 31.5, [31.5]],
[8, 325.0, 32.5, [32.5]],
[9, 335.0, 33.5, [33.5]],
],
["id", "sum(v)", "avg(v)", "avg(array(v))"],
).collect()
self.assertEqual(expected, result.collect())
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_basic(self):
df = self.data
weighted_mean_udf = self.arrow_agg_weighted_mean_udf
# Groupby one column and aggregate one UDF with literal
result1 = df.groupby("id").agg(weighted_mean_udf(df.v, sf.lit(1.0))).sort("id")
expected1 = (
df.groupby("id").agg(sf.mean(df.v).alias("weighted_mean(v, 1.0)")).sort("id").collect()
)
self.assertEqual(expected1, result1.collect())
# Groupby one expression and aggregate one UDF with literal
result2 = (
df.groupby((sf.col("id") + 1)).agg(weighted_mean_udf(df.v, sf.lit(1.0))).sort(df.id + 1)
)
expected2 = (
df.groupby((sf.col("id") + 1))
.agg(sf.mean(df.v).alias("weighted_mean(v, 1.0)"))
.sort(df.id + 1)
).collect()
self.assertEqual(expected2, result2.collect())
# Groupby one column and aggregate one UDF without literal
result3 = df.groupby("id").agg(weighted_mean_udf(df.v, df.w)).sort("id")
expected3 = (
df.groupby("id").agg(sf.mean(df.v).alias("weighted_mean(v, w)")).sort("id").collect()
)
self.assertEqual(expected3, result3.collect())
# Groupby one expression and aggregate one UDF without literal
result4 = (
df.groupby((sf.col("id") + 1).alias("id")).agg(weighted_mean_udf(df.v, df.w)).sort("id")
)
expected4 = (
df.groupby((sf.col("id") + 1).alias("id"))
.agg(sf.mean(df.v).alias("weighted_mean(v, w)"))
.sort("id")
).collect()
self.assertEqual(expected4, result4.collect())
def test_alias(self):
df = self.data
mean_udf = self.arrow_agg_mean_udf
result = df.groupby("id").agg(mean_udf(df.v).alias("mean_alias"))
expected = df.groupby("id").agg(sf.mean(df.v).alias("mean_alias")).collect()
self.assertEqual(expected, result.collect())
def test_mixed_sql(self):
"""
        Test mixing group aggregate Arrow UDF with sql expression.
"""
df = self.data
sum_udf = self.arrow_agg_sum_udf
        # Mix group aggregate Arrow UDF with sql expression
result1 = df.groupby("id").agg(sum_udf(df.v) + 1).sort("id")
expected1 = df.groupby("id").agg(sf.sum(df.v) + 1).sort("id").collect()
        # Mix group aggregate Arrow UDF with sql expression (order swapped)
result2 = df.groupby("id").agg(sum_udf(df.v + 1)).sort("id")
expected2 = df.groupby("id").agg(sf.sum(df.v + 1)).sort("id").collect()
        # Wrap group aggregate Arrow UDF with two sql expressions
result3 = df.groupby("id").agg(sum_udf(df.v + 1) + 2).sort("id")
expected3 = df.groupby("id").agg(sf.sum(df.v + 1) + 2).sort("id").collect()
self.assertEqual(expected1, result1.collect())
self.assertEqual(expected2, result2.collect())
self.assertEqual(expected3, result3.collect())
def test_mixed_udfs(self):
"""
        Test mixing group aggregate Arrow UDF with python UDF and scalar Arrow UDF.
"""
df = self.data
plus_one = self.python_plus_one
plus_two = self.arrow_scalar_plus_two
sum_udf = self.arrow_agg_sum_udf
        # Mix group aggregate Arrow UDF and python UDF
result1 = df.groupby("id").agg(plus_one(sum_udf(df.v))).sort("id")
expected1 = df.groupby("id").agg(plus_one(sf.sum(df.v))).sort("id").collect()
        # Mix group aggregate Arrow UDF and python UDF (order swapped)
result2 = df.groupby("id").agg(sum_udf(plus_one(df.v))).sort("id")
expected2 = df.groupby("id").agg(sf.sum(plus_one(df.v))).sort("id").collect()
        # Mix group aggregate Arrow UDF and scalar Arrow UDF
result3 = df.groupby("id").agg(sum_udf(plus_two(df.v))).sort("id")
expected3 = df.groupby("id").agg(sf.sum(plus_two(df.v))).sort("id").collect()
        # Mix group aggregate Arrow UDF and scalar Arrow UDF (order swapped)
result4 = df.groupby("id").agg(plus_two(sum_udf(df.v))).sort("id")
expected4 = df.groupby("id").agg(plus_two(sf.sum(df.v))).sort("id").collect()
        # Wrap group aggregate Arrow UDF with two python UDFs and use python UDF in groupby
result5 = (
df.groupby(plus_one(df.id)).agg(plus_one(sum_udf(plus_one(df.v)))).sort("plus_one(id)")
)
expected5 = (
df.groupby(plus_one(df.id)).agg(plus_one(sf.sum(plus_one(df.v)))).sort("plus_one(id)")
).collect()
        # Wrap group aggregate Arrow UDF with two scalar Arrow UDFs and use a scalar Arrow UDF
        # in groupby
result6 = (
df.groupby(plus_two(df.id)).agg(plus_two(sum_udf(plus_two(df.v)))).sort("plus_two(id)")
)
expected6 = (
df.groupby(plus_two(df.id)).agg(plus_two(sf.sum(plus_two(df.v)))).sort("plus_two(id)")
).collect()
self.assertEqual(expected1, result1.collect())
self.assertEqual(expected2, result2.collect())
self.assertEqual(expected3, result3.collect())
self.assertEqual(expected4, result4.collect())
self.assertEqual(expected5, result5.collect())
self.assertEqual(expected6, result6.collect())
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_multiple_udfs(self):
"""
        Test multiple group aggregate Arrow UDFs in one agg function.
"""
df = self.data
mean_udf = self.arrow_agg_mean_udf
sum_udf = self.arrow_agg_sum_udf
weighted_mean_udf = self.arrow_agg_weighted_mean_udf
result = (
df.groupBy("id")
.agg(mean_udf(df.v), sum_udf(df.v), weighted_mean_udf(df.v, df.w))
.sort("id")
)
expected = (
df.groupBy("id")
.agg(sf.mean(df.v), sf.sum(df.v), sf.mean(df.v).alias("weighted_mean(v, w)"))
.sort("id")
.collect()
)
self.assertEqual(expected, result.collect())
def test_complex_groupby(self):
df = self.data
sum_udf = self.arrow_agg_sum_udf
plus_one = self.python_plus_one
plus_two = self.arrow_scalar_plus_two
# groupby one expression
result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
expected1 = df.groupby(df.v % 2).agg(sf.sum(df.v)).collect()
# empty groupby
result2 = df.groupby().agg(sum_udf(df.v))
expected2 = df.groupby().agg(sf.sum(df.v)).collect()
# groupby one column and one sql expression
result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
expected3 = df.groupby(df.id, df.v % 2).agg(sf.sum(df.v)).orderBy(df.id, df.v % 2).collect()
# groupby one python UDF
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)).sort("plus_one(id)")
expected4 = df.groupby(plus_one(df.id)).agg(sf.sum(df.v)).sort("plus_one(id)").collect()
        # groupby one scalar Arrow UDF
result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)).sort("sum(v)")
expected5 = df.groupby(plus_two(df.id)).agg(sf.sum(df.v)).sort("sum(v)").collect()
# groupby one expression and one python UDF
result6 = (
df.groupby(df.v % 2, plus_one(df.id))
.agg(sum_udf(df.v))
.sort(["(v % 2)", "plus_one(id)"])
)
expected6 = (
df.groupby(df.v % 2, plus_one(df.id))
.agg(sf.sum(df.v))
.sort(["(v % 2)", "plus_one(id)"])
).collect()
        # groupby one expression and one scalar Arrow UDF
result7 = (
df.groupby(df.v % 2, plus_two(df.id))
.agg(sum_udf(df.v))
.sort(["sum(v)", "plus_two(id)"])
)
expected7 = (
df.groupby(df.v % 2, plus_two(df.id)).agg(sf.sum(df.v)).sort(["sum(v)", "plus_two(id)"])
).collect()
self.assertEqual(expected1, result1.collect())
self.assertEqual(expected2, result2.collect())
self.assertEqual(expected3, result3.collect())
self.assertEqual(expected4, result4.collect())
self.assertEqual(expected5, result5.collect())
self.assertEqual(expected6, result6.collect())
self.assertEqual(expected7, result7.collect())
def test_complex_expressions(self):
df = self.data
plus_one = self.python_plus_one
plus_two = self.arrow_scalar_plus_two
sum_udf = self.arrow_agg_sum_udf
# Test complex expressions with sql expression, python UDF and
        # group aggregate Arrow UDF
result1 = (
df.withColumn("v1", plus_one(df.v))
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum_udf(sf.col("v")),
sum_udf(sf.col("v1") + 3),
sum_udf(sf.col("v2")) + 5,
plus_one(sum_udf(sf.col("v1"))),
sum_udf(plus_one(sf.col("v2"))),
)
.sort(["id", "(v % 2)"])
)
expected1 = (
df.withColumn("v1", df.v + 1)
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sf.sum(sf.col("v")),
sf.sum(sf.col("v1") + 3),
sf.sum(sf.col("v2")) + 5,
plus_one(sf.sum(sf.col("v1"))),
sf.sum(plus_one(sf.col("v2"))),
)
.sort(["id", "(v % 2)"])
.collect()
)
        # Test complex expressions with sql expression, scalar Arrow UDF and
        # group aggregate Arrow UDF
result2 = (
df.withColumn("v1", plus_one(df.v))
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum_udf(sf.col("v")),
sum_udf(sf.col("v1") + 3),
sum_udf(sf.col("v2")) + 5,
plus_two(sum_udf(sf.col("v1"))),
sum_udf(plus_two(sf.col("v2"))),
)
.sort(["id", "(v % 2)"])
)
expected2 = (
df.withColumn("v1", df.v + 1)
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sf.sum(sf.col("v")),
sf.sum(sf.col("v1") + 3),
sf.sum(sf.col("v2")) + 5,
plus_two(sf.sum(sf.col("v1"))),
sf.sum(plus_two(sf.col("v2"))),
)
.sort(["id", "(v % 2)"])
.collect()
)
# Test sequential groupby aggregate
result3 = (
df.groupby("id")
.agg(sum_udf(df.v).alias("v"))
.groupby("id")
.agg(sum_udf(sf.col("v")))
.sort("id")
)
expected3 = (
df.groupby("id")
.agg(sf.sum(df.v).alias("v"))
.groupby("id")
.agg(sf.sum(sf.col("v")))
.sort("id")
.collect()
)
self.assertEqual(expected1, result1.collect())
self.assertEqual(expected2, result2.collect())
self.assertEqual(expected3, result3.collect())
def test_retain_group_columns(self):
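        # With spark.sql.retainGroupColumns disabled the grouping column is dropped from the
        # aggregate output; the Arrow UDF aggregate should match the built-in aggregate.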
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.arrow_agg_sum_udf
result1 = df.groupby(df.id).agg(sum_udf(df.v))
expected1 = df.groupby(df.id).agg(sf.sum(df.v)).collect()
self.assertEqual(expected1, result1.collect())
def test_array_type(self):
df = self.data
array_udf = arrow_udf(lambda x: [1.0, 2.0], "array<double>", ArrowUDFType.GROUPED_AGG)
result1 = df.groupby("id").agg(array_udf(df["v"]).alias("v2"))
self.assertEqual(result1.first()["v2"], [1.0, 2.0])
def test_invalid_args(self):
with self.quiet():
self.check_invalid_args()
def check_invalid_args(self):
df = self.data
plus_one = self.python_plus_one
mean_udf = self.arrow_agg_mean_udf
with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"):
df.groupby(df.id).agg(plus_one(df.v)).collect()
with self.assertRaisesRegex(
AnalysisException, "aggregate function.*argument.*aggregate function"
):
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
with self.assertRaisesRegex(
AnalysisException,
"The group aggregate pandas UDF `avg` cannot be invoked together with as other, "
"non-pandas aggregate functions.",
):
df.groupby(df.id).agg(mean_udf(df.v), sf.mean(df.v)).collect()
def test_register_vectorized_udf_basic(self):
import pyarrow as pa
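        # Registering the UDF via spark.udf.register should preserve the GROUPED_AGG eval type,
        # so the function can be used as an aggregate in SQL.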
sum_arrow_udf = arrow_udf(
lambda v: pa.compute.sum(v).cast(pa.int32()),
"integer",
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
)
self.assertEqual(sum_arrow_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
with self.temp_func("sum_arrow_udf"):
            group_agg_arrow_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
            self.assertEqual(
                group_agg_arrow_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
)
q = """
SELECT sum_arrow_udf(v1)
FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2
"""
actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
expected = [1, 5]
self.assertEqual(actual, expected)
def test_grouped_with_empty_partition(self):
import pyarrow as pa
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
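        # Use more partitions than rows so that at least one partition is guaranteed to be empty.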
num_parts = len(data) + 1
df = self.spark.createDataFrame(data).repartition(num_parts)
f = arrow_udf(lambda x: pa.compute.sum(x).cast(pa.int32()), "int", ArrowUDFType.GROUPED_AGG)
result = df.groupBy("id").agg(f(df["x"]).alias("sum")).sort("id")
self.assertEqual(result.collect(), expected)
def test_grouped_without_group_by_clause(self):
import pyarrow as pa
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def max_udf(v):
return float(pa.compute.max(v).as_py())
df = self.spark.range(0, 100)
with self.temp_view("table"), self.temp_func("max_udf"):
df.createTempView("table")
self.spark.udf.register("max_udf", max_udf)
agg1 = df.agg(max_udf(df["id"]))
agg2 = self.spark.sql("select max_udf(id) from table")
self.assertEqual(agg1.collect(), agg2.collect())
def test_no_predicate_pushdown_through(self):
import pyarrow as pa
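        # The filter on the aggregated column must stay above the grouped aggregation;
        # it should not be pushed down past the Arrow UDF aggregate.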
@arrow_udf("float", ArrowUDFType.GROUPED_AGG)
def mean(x):
return pa.compute.mean(x).cast(pa.float32())
df = self.spark.createDataFrame([Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)])
agg = df.groupBy("id").agg(mean("foo").alias("mean"))
filtered = agg.filter(agg["mean"] > 40.0)
        self.assertEqual(filtered.collect()[0]["mean"], 42.0)
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
with self.temp_view("v"), self.temp_func("weighted_mean"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
self.assertEqual(
aggregated.collect(),
df.groupby("id").agg(sf.mean(df.v).alias("wm")).collect(),
)
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.arrow_agg_weighted_mean_udf
with self.temp_view("v"), self.temp_func("weighted_mean"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
):
self.spark.sql(
"SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
).show()
def test_kwargs(self):
df = self.data
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np
return np.average(kwargs["v"], weights=kwargs["w"])
with self.temp_view("v"), self.temp_func("weighted_mean"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
self.assertEqual(
aggregated.collect(),
df.groupby("id").agg(sf.mean(df.v).alias("wm")).collect(),
)
# negative
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()
def test_named_arguments_and_defaults(self):
import pyarrow as pa
df = self.data
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def biased_sum(v, w=None):
return pa.compute.sum(v).as_py() + (pa.compute.sum(w).as_py() if w is not None else 100)
with self.temp_view("v"), self.temp_func("biased_sum"):
df.createOrReplaceTempView("v")
self.spark.udf.register("biased_sum", biased_sum)
# without "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=False, query_no=i):
self.assertEqual(
aggregated.collect(),
df.groupby("id").agg((sf.sum(df.v) + sf.lit(100)).alias("s")).collect(),
)
# with "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=True, query_no=i):
self.assertEqual(
aggregated.collect(),
df.groupby("id").agg((sf.sum(df.v) + sf.sum(df.w)).alias("s")).collect(),
)
def test_complex_agg_collect_set(self):
import pyarrow as pa
df = self.spark.createDataFrame([(1, 1), (1, 2), (2, 3), (2, 5), (2, 1)], ("id", "v"))
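        # No ArrowUDFType is passed below; the grouped-aggregate eval type is inferred from the
        # pa.Array -> pa.Scalar type hints.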
@arrow_udf("array<int>")
def arrow_collect_set(v: pa.Array) -> pa.Scalar:
assert isinstance(v, pa.Array), str(type(v))
s = sorted([x.as_py() for x in pa.compute.unique(v)])
t = pa.list_(pa.int32())
return pa.scalar(value=s, type=t)
result1 = df.select(
arrow_collect_set(df["id"]).alias("ids"),
arrow_collect_set(df["v"]).alias("vs"),
)
expected1 = df.select(
sf.sort_array(sf.collect_set(df["id"])).alias("ids"),
sf.sort_array(sf.collect_set(df["v"])).alias("vs"),
)
self.assertEqual(expected1.collect(), result1.collect())
def test_complex_agg_collect_list(self):
import pyarrow as pa
df = self.spark.createDataFrame([(1, 1), (1, 2), (2, 3), (2, 5), (2, 1)], ("id", "v"))
@arrow_udf("array<int>")
def arrow_collect_list(v: pa.Array) -> pa.Scalar:
assert isinstance(v, pa.Array), str(type(v))
s = sorted([x.as_py() for x in v])
t = pa.list_(pa.int32())
return pa.scalar(value=s, type=t)
result1 = df.select(
arrow_collect_list(df["id"]).alias("ids"),
arrow_collect_list(df["v"]).alias("vs"),
)
expected1 = df.select(
sf.sort_array(sf.collect_list(df["id"])).alias("ids"),
sf.sort_array(sf.collect_list(df["v"])).alias("vs"),
)
self.assertEqual(expected1.collect(), result1.collect())
def test_complex_agg_collect_as_map(self):
import pyarrow as pa
df = self.spark.createDataFrame([(1, 1), (2, 2), (3, 5)], ("id", "v"))
@arrow_udf("map<int, int>")
def arrow_collect_as_map(id: pa.Array, v: pa.Array) -> pa.Scalar:
assert isinstance(id, pa.Array), str(type(id))
assert isinstance(v, pa.Array), str(type(v))
d = {i: j for i, j in zip(id.to_pylist(), v.to_pylist())}
t = pa.map_(pa.int32(), pa.int32())
return pa.scalar(value=d, type=t)
result1 = df.select(
arrow_collect_as_map("id", "v").alias("map"),
)
expected1 = df.select(
sf.map_from_arrays(sf.collect_list("id"), sf.collect_list("v")).alias("map"),
)
self.assertEqual(expected1.collect(), result1.collect())
def test_complex_agg_min_max_struct(self):
import pyarrow as pa
df = self.spark.createDataFrame([(1, 1), (2, 2), (3, 5)], ("id", "v"))
@arrow_udf("struct<m1: int, m2:int>")
def arrow_collect_min_max(id: pa.Array, v: pa.Array) -> pa.Scalar:
assert isinstance(id, pa.Array), str(type(id))
assert isinstance(v, pa.Array), str(type(v))
m1 = pa.compute.min(id)
m2 = pa.compute.max(v)
t = pa.struct([pa.field("m1", pa.int32()), pa.field("m2", pa.int32())])
return pa.scalar(value={"m1": m1.as_py(), "m2": m2.as_py()}, type=t)
result1 = df.select(
arrow_collect_min_max("id", "v").alias("struct"),
)
expected1 = df.select(
sf.struct(sf.min("id").alias("m1"), sf.max("v").alias("m2")).alias("struct"),
)
self.assertEqual(expected1.collect(), result1.collect())
def test_time_min(self):
import pyarrow as pa
df = self.spark.sql(
"""
SELECT * FROM VALUES
(1, TIME '12:34:56'),
(1, TIME '1:2:3'),
(2, TIME '0:58:59'),
(2, TIME '10:58:59'),
(2, TIME '10:00:03')
AS tab(i, t)
"""
)
@arrow_udf("time", ArrowUDFType.GROUPED_AGG)
def agg_min_time(v):
assert isinstance(v, pa.Array)
assert isinstance(v, pa.Time64Array)
return pa.compute.min(v)
expected1 = df.select(sf.min("t").alias("res"))
result1 = df.select(agg_min_time("t").alias("res"))
self.assertEqual(expected1.collect(), result1.collect())
expected2 = df.groupby("i").agg(sf.min("t").alias("res")).sort("i")
result2 = df.groupby("i").agg(agg_min_time("t").alias("res")).sort("i")
self.assertEqual(expected2.collect(), result2.collect())
def test_input_output_variant(self):
import pyarrow as pa
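        # A variant column reaches the Arrow UDF as a StructArray with binary "metadata" and
        # "value" children, and a returned variant scalar is read back as a VariantVal.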
@arrow_udf("variant")
def first_variant(v: pa.Array) -> pa.Scalar:
assert isinstance(v, pa.Array)
assert isinstance(v, pa.StructArray)
assert isinstance(v.field("metadata"), pa.BinaryArray)
assert isinstance(v.field("value"), pa.BinaryArray)
return v[0]
@arrow_udf("variant")
def last_variant(v: pa.Array) -> pa.Scalar:
assert isinstance(v, pa.Array)
assert isinstance(v, pa.StructArray)
assert isinstance(v.field("metadata"), pa.BinaryArray)
assert isinstance(v.field("value"), pa.BinaryArray)
return v[-1]
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
result = df.select(
first_variant("v").alias("first"),
last_variant("v").alias("last"),
)
self.assertEqual(
result.schema,
StructType(
[
StructField("first", VariantType(), True),
StructField("last", VariantType(), True),
]
),
)
row = result.first()
self.assertIsInstance(row.first, VariantVal)
self.assertIsInstance(row.last, VariantVal)
def test_return_type_coercion(self):
import pyarrow as pa
df = self.spark.range(10)
@arrow_udf("long", ArrowUDFType.GROUPED_AGG)
def agg_long(id: pa.Array) -> int:
assert isinstance(id, pa.Array), str(type(id))
return pa.scalar(value=len(id), type=pa.int64())
result1 = df.select(agg_long("id").alias("res"))
self.assertEqual(1, len(result1.collect()))
# long -> int coercion
@arrow_udf("int", ArrowUDFType.GROUPED_AGG)
def agg_int1(id: pa.Array) -> int:
assert isinstance(id, pa.Array), str(type(id))
return pa.scalar(value=len(id), type=pa.int64())
result2 = df.select(agg_int1("id").alias("res"))
self.assertEqual(1, len(result2.collect()))
# long -> int coercion, overflow
@arrow_udf("int", ArrowUDFType.GROUPED_AGG)
def agg_int2(id: pa.Array) -> int:
assert isinstance(id, pa.Array), str(type(id))
return pa.scalar(value=len(id) + 2147483647, type=pa.int64())
result3 = df.select(agg_int2("id").alias("res"))
with self.assertRaises(Exception):
# pyarrow.lib.ArrowInvalid:
# Integer value 2147483657 not in range: -2147483648 to 2147483647
result3.collect()
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_return_numpy_scalar(self):
import numpy as np
import pyarrow as pa
@arrow_udf("long")
def np_max_udf(v: pa.Array) -> np.int64:
assert isinstance(v, pa.Array)
return np.max(v)
@arrow_udf("long")
def np_min_udf(v: pa.Array) -> np.int64:
assert isinstance(v, pa.Array)
return np.min(v)
@arrow_udf("double")
def np_avg_udf(v: pa.Array) -> np.float64:
assert isinstance(v, pa.Array)
return np.mean(v)
df = self.spark.range(10)
expected = df.select(
sf.max("id").alias("max"),
sf.min("id").alias("min"),
sf.avg("id").alias("avg"),
)
result = df.select(
np_max_udf("id").alias("max"),
np_min_udf("id").alias("min"),
np_avg_udf("id").alias("avg"),
)
self.assertEqual(expected.collect(), result.collect())
def test_unsupported_return_types(self):
import pyarrow as pa
with self.quiet():
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type with grouped aggregate "
"Arrow UDFs.*ArrayType.*YearMonthIntervalType",
):
arrow_udf(
lambda x: x,
ArrayType(ArrayType(YearMonthIntervalType())),
ArrowUDFType.GROUPED_AGG,
)
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type with grouped aggregate "
"Arrow UDFs.*ArrayType.*YearMonthIntervalType",
):
@arrow_udf(ArrayType(ArrayType(YearMonthIntervalType())), ArrowUDFType.GROUPED_AGG)
def func_a(a: pa.Array) -> pa.Scalar:
return pa.compute.max(a)
def test_0_args(self):
import pyarrow as pa
df = self.spark.range(10).withColumn("k", sf.col("id") % 3)
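        # arrow_lit_1 takes no input columns; a zero-argument grouped aggregate UDF still
        # produces one value per group, here the literal 1, matching sf.lit(1).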
@arrow_udf("long", ArrowUDFType.GROUPED_AGG)
def arrow_max(v) -> int:
return pa.compute.max(v).as_py()
@arrow_udf("long", ArrowUDFType.GROUPED_AGG)
def arrow_lit_1() -> int:
return 1
expected1 = df.select(sf.max("id").alias("res1"), sf.lit(1).alias("res1"))
result1 = df.select(arrow_max("id").alias("res1"), arrow_lit_1().alias("res1"))
self.assertEqual(expected1.collect(), result1.collect())
expected2 = (
df.groupby("k")
.agg(
sf.max("id").alias("res1"),
sf.lit(1).alias("res1"),
)
.sort("k")
)
result2 = (
df.groupby("k")
.agg(
arrow_max("id").alias("res1"),
arrow_lit_1().alias("res1"),
)
.sort("k")
)
self.assertEqual(expected2.collect(), result2.collect())
def test_arrow_batch_slicing(self):
import pyarrow as pa
df = self.spark.range(10000000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
@arrow_udf("long", ArrowUDFType.GROUPED_AGG)
def arrow_max(v):
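            # Even when maxRecordsPerBatch / maxBytesPerBatch slice the input into multiple
            # Arrow batches, the grouped aggregate UDF still receives each group
            # (5,000,000 rows per key) as a single array.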
assert len(v) == 10000000 / 2, len(v)
return pa.compute.max(v)
expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df.groupBy("key").agg(arrow_max("v").alias("res")).sort("key")
).collect()
self.assertEqual(expected, result)
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_grouped_agg_arrow_udf_with_logging(self):
import pyarrow as pa
@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def my_grouped_agg_arrow_udf(x):
assert isinstance(x, pa.Array)
logger = logging.getLogger("test_grouped_agg_arrow")
logger.warning(f"grouped agg arrow udf: {len(x)}")
return pa.compute.sum(x)
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
[Row(id=1, result=3.0), Row(id=2, result=18.0)],
)
logs = self.spark.tvf.python_worker_logs()
assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"grouped agg arrow udf: {n}",
context={"func_name": my_grouped_agg_arrow_udf.__name__},
logger="test_grouped_agg_arrow",
)
for n in [2, 3]
],
)
def test_iterator_grouped_agg_single_column(self):
"""
Test iterator API for grouped aggregation with single column.
"""
import pyarrow as pa
from typing import Iterator
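        # With the iterator API the rows of a group may arrive as several Arrow batches;
        # the UDF reduces them incrementally and returns one value per group.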
@arrow_udf("double")
def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
sum_val = 0.0
cnt = 0
for v in it:
assert isinstance(v, pa.Array)
sum_val += pa.compute.sum(v).as_py()
cnt += len(v)
return sum_val / cnt if cnt > 0 else 0.0
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
result = df.groupby("id").agg(arrow_mean_iter(df["v"]).alias("mean")).sort("id")
expected = df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
self.assertEqual(expected, result.collect())
@unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_iterator_grouped_agg_multiple_columns(self):
"""
Test iterator API for grouped aggregation with multiple columns.
"""
import pyarrow as pa
import numpy as np
@arrow_udf("double")
def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
weighted_sum = 0.0
weight = 0.0
for v, w in it:
assert isinstance(v, pa.Array)
assert isinstance(w, pa.Array)
weighted_sum += np.dot(v, w)
weight += pa.compute.sum(w).as_py()
return weighted_sum / weight if weight > 0 else 0.0
df = self.spark.createDataFrame(
[(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
("id", "v", "w"),
)
result = (
df.groupby("id")
.agg(arrow_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
.sort("id")
.collect()
)
# Expected weighted means:
# Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
# Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
expected = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
self.assertEqual(len(result), len(expected))
for r, (exp_id, exp_wm) in zip(result, expected):
self.assertEqual(r["id"], exp_id)
self.assertAlmostEqual(r["wm"], exp_wm, places=5)
def test_iterator_grouped_agg_eval_type(self):
"""
Test that the eval type is correctly inferred for iterator grouped agg UDFs.
"""
import pyarrow as pa
from typing import Iterator
@arrow_udf("double")
def arrow_sum_iter(it: Iterator[pa.Array]) -> float:
total = 0.0
for v in it:
total += pa.compute.sum(v).as_py()
return total
self.assertEqual(arrow_sum_iter.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF)
def test_iterator_grouped_agg_partial_consumption(self):
"""
Test that iterator grouped agg UDF can partially consume batches.
This ensures that batches are processed one by one without loading all data into memory.
"""
import pyarrow as pa
from typing import Iterator
# Create a dataset with multiple batches per group
# Use small batch size to ensure multiple batches per group
# Use same value for all data points to avoid ordering issues
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
df = self.spark.createDataFrame(
[(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v")
)
@arrow_udf("struct<count:bigint,sum:double>")
def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict:
# Only consume first two batches, then return
# This tests that partial consumption works correctly
total = 0.0
count = 0
for i, v in enumerate(it):
if i < 2: # Only process first 2 batches
total += pa.compute.sum(v).as_py()
count += len(v)
else:
# Stop early - partial consumption
break
return {"count": count, "sum": total}
result = (
df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id")
)
# Verify results are correct for partial consumption
# With batch size = 2:
# Group 1 (id=1): 4 values in 2 batches -> processes both batches
# Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0]
# Result: count=4, sum=4.0
# Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available)
# Batch 1: [1.0, 1.0]
# Result: count=2, sum=2.0
actual = result.collect()
self.assertEqual(len(actual), 2, "Should have results for both groups")
# Verify both groups were processed correctly
# Group 1: processes 2 batches (all available)
group1_result = next(row for row in actual if row["id"] == 1)
self.assertEqual(
group1_result["result"]["count"],
4,
msg="Group 1 should process 4 values (2 batches)",
)
self.assertAlmostEqual(
group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0"
)
# Group 2: processes 1 batch (only batch available)
group2_result = next(row for row in actual if row["id"] == 2)
self.assertEqual(
group2_result["result"]["count"],
2,
msg="Group 2 should process 2 values (1 batch)",
)
self.assertAlmostEqual(
group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0"
)
def test_iterator_grouped_agg_sql_single_column(self):
"""
Test iterator API for grouped aggregation with single column in SQL.
"""
import pyarrow as pa
@arrow_udf("double")
def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
sum_val = 0.0
cnt = 0
for v in it:
assert isinstance(v, pa.Array)
sum_val += pa.compute.sum(v).as_py()
cnt += len(v)
return sum_val / cnt if cnt > 0 else 0.0
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
with self.temp_view("test_table"), self.temp_func("arrow_mean_iter"):
df.createOrReplaceTempView("test_table")
self.spark.udf.register("arrow_mean_iter", arrow_mean_iter)
# Test SQL query with GROUP BY
result_sql = self.spark.sql(
"SELECT id, arrow_mean_iter(v) as mean FROM test_table GROUP BY id ORDER BY id"
)
expected = df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
self.assertEqual(expected, result_sql.collect())
def test_iterator_grouped_agg_sql_multiple_columns(self):
"""
Test iterator API for grouped aggregation with multiple columns in SQL.
"""
import pyarrow as pa
@arrow_udf("double")
def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
weighted_sum = 0.0
weight = 0.0
for v, w in it:
assert isinstance(v, pa.Array)
assert isinstance(w, pa.Array)
weighted_sum += pa.compute.sum(pa.compute.multiply(v, w)).as_py()
weight += pa.compute.sum(w).as_py()
return weighted_sum / weight if weight > 0 else 0.0
df = self.spark.createDataFrame(
[(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
("id", "v", "w"),
)
with self.temp_view("test_table"), self.temp_func("arrow_weighted_mean_iter"):
df.createOrReplaceTempView("test_table")
self.spark.udf.register("arrow_weighted_mean_iter", arrow_weighted_mean_iter)
# Test SQL query with GROUP BY and multiple columns
result_sql = self.spark.sql(
"SELECT id, arrow_weighted_mean_iter(v, w) as wm "
"FROM test_table GROUP BY id ORDER BY id"
)
# Expected weighted means:
# Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
# Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
actual_results = result_sql.collect()
self.assertEqual(actual_results, expected)
class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
pass
if __name__ == "__main__":
from pyspark.testing import main
main()