#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from typing import cast
from pyspark.util import PythonEvalType
from pyspark.sql import Row
from pyspark.sql.functions import (
array,
explode,
col,
lit,
mean,
sum,
udf,
pandas_udf,
PandasUDFType,
)
from pyspark.sql.types import ArrayType, YearMonthIntervalType
from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual

if have_pandas:
import pandas as pd
from pandas.testing import assert_frame_equal


@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class GroupedAggPandasUDFTestsMixin:
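    """
    Tests for grouped aggregate pandas UDFs (PandasUDFType.GROUPED_AGG).
    """
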
@property
def data(self):
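        # 100 rows: ids 0-9, each with ten "v" values (id + 20.0 .. id + 29.0)
        # and a constant weight column "w" of 1.0.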
return (
self.spark.range(10)
.toDF("id")
.withColumn("vs", array([lit(i * 1.0) + col("id") for i in range(20, 30)]))
.withColumn("v", explode(col("vs")))
.drop("vs")
.withColumn("w", lit(1.0))
)

    @property
def python_plus_one(self):
@udf("double")
def plus_one(v):
assert isinstance(v, (int, float))
return float(v + 1)
return plus_one

    @property
def pandas_scalar_plus_two(self):
@pandas_udf("double", PandasUDFType.SCALAR)
def plus_two(v):
assert isinstance(v, pd.Series)
return v + 2
return plus_two

    @property
def pandas_agg_mean_udf(self):
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
return avg

    @property
def pandas_agg_sum_udf(self):
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def sum(v):
return v.sum()
return sum

    @property
def pandas_agg_weighted_mean_udf(self):
import numpy as np
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
return np.average(v, weights=w)
return weighted_mean
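
    # Manually check sum/mean/array-returning grouped agg UDFs against
    # hard-coded expected values.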
def test_manual(self):
df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf
mean_arr_udf = pandas_udf(
self.pandas_agg_mean_udf.func,
ArrayType(self.pandas_agg_mean_udf.returnType),
self.pandas_agg_mean_udf.evalType,
)
result1 = (
df.groupby("id")
.agg(sum_udf(df.v), mean_udf(df.v), mean_arr_udf(array(df.v)))
.sort("id")
)
expected1 = self.spark.createDataFrame(
[
[0, 245.0, 24.5, [24.5]],
[1, 255.0, 25.5, [25.5]],
[2, 265.0, 26.5, [26.5]],
[3, 275.0, 27.5, [27.5]],
[4, 285.0, 28.5, [28.5]],
[5, 295.0, 29.5, [29.5]],
[6, 305.0, 30.5, [30.5]],
[7, 315.0, 31.5, [31.5]],
[8, 325.0, 32.5, [32.5]],
[9, 335.0, 33.5, [33.5]],
],
["id", "sum(v)", "avg(v)", "avg(array(v))"],
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
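
    # With a uniform weight (literal 1.0 or the constant "w" column), the
    # weighted mean UDF must agree with the built-in mean() aggregate.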
def test_basic(self):
df = self.data
weighted_mean_udf = self.pandas_agg_weighted_mean_udf
# Groupby one column and aggregate one UDF with literal
result1 = df.groupby("id").agg(weighted_mean_udf(df.v, lit(1.0))).sort("id")
expected1 = df.groupby("id").agg(mean(df.v).alias("weighted_mean(v, 1.0)")).sort("id")
assert_frame_equal(expected1.toPandas(), result1.toPandas())
# Groupby one expression and aggregate one UDF with literal
result2 = df.groupby((col("id") + 1)).agg(weighted_mean_udf(df.v, lit(1.0))).sort(df.id + 1)
expected2 = (
df.groupby((col("id") + 1))
.agg(mean(df.v).alias("weighted_mean(v, 1.0)"))
.sort(df.id + 1)
)
assert_frame_equal(expected2.toPandas(), result2.toPandas())
# Groupby one column and aggregate one UDF without literal
result3 = df.groupby("id").agg(weighted_mean_udf(df.v, df.w)).sort("id")
expected3 = df.groupby("id").agg(mean(df.v).alias("weighted_mean(v, w)")).sort("id")
assert_frame_equal(expected3.toPandas(), result3.toPandas())
# Groupby one expression and aggregate one UDF without literal
result4 = (
df.groupby((col("id") + 1).alias("id")).agg(weighted_mean_udf(df.v, df.w)).sort("id")
)
expected4 = (
df.groupby((col("id") + 1).alias("id"))
.agg(mean(df.v).alias("weighted_mean(v, w)"))
.sort("id")
)
assert_frame_equal(expected4.toPandas(), result4.toPandas())
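
    # Unsupported return types (interval arrays and structs) should raise
    # NOT_IMPLEMENTED when the UDF is defined.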
def test_unsupported_types(self):
with self.quiet():
self.check_unsupported_types()

    def check_unsupported_types(self):
with self.assertRaises(PySparkNotImplementedError) as pe:
pandas_udf(
lambda x: x,
ArrayType(ArrayType(YearMonthIntervalType())),
PandasUDFType.GROUPED_AGG,
)
self.check_error(
exception=pe.exception,
errorClass="NOT_IMPLEMENTED",
messageParameters={
"feature": "Invalid return type with grouped aggregate Pandas UDFs: "
"ArrayType(ArrayType(YearMonthIntervalType(0, 1), True), True)"
},
)
with self.assertRaises(PySparkNotImplementedError) as pe:
@pandas_udf("mean double, std double", PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return v.mean(), v.std()
self.check_error(
exception=pe.exception,
errorClass="NOT_IMPLEMENTED",
messageParameters={
"feature": "Invalid return type with grouped aggregate Pandas UDFs: "
"StructType([StructField('mean', DoubleType(), True), "
"StructField('std', DoubleType(), True)])"
},
)
with self.assertRaises(PySparkNotImplementedError) as pe:
@pandas_udf(ArrayType(YearMonthIntervalType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v): # noqa: F811
return {v.mean(): v.std()}
self.check_error(
exception=pe.exception,
errorClass="NOT_IMPLEMENTED",
messageParameters={
"feature": "Invalid return type with grouped aggregate Pandas UDFs: "
"ArrayType(YearMonthIntervalType(0, 1), True)"
},
)

    def test_alias(self):
df = self.data
mean_udf = self.pandas_agg_mean_udf
result1 = df.groupby("id").agg(mean_udf(df.v).alias("mean_alias"))
expected1 = df.groupby("id").agg(mean(df.v).alias("mean_alias"))
assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_sql(self):
"""
Test mixing group aggregate pandas UDF with sql expression.
"""
df = self.data
sum_udf = self.pandas_agg_sum_udf
# Mix group aggregate pandas UDF with sql expression
result1 = df.groupby("id").agg(sum_udf(df.v) + 1).sort("id")
expected1 = df.groupby("id").agg(sum(df.v) + 1).sort("id")
# Mix group aggregate pandas UDF with sql expression (order swapped)
result2 = df.groupby("id").agg(sum_udf(df.v + 1)).sort("id")
expected2 = df.groupby("id").agg(sum(df.v + 1)).sort("id")
# Wrap group aggregate pandas UDF with two sql expressions
result3 = df.groupby("id").agg(sum_udf(df.v + 1) + 2).sort("id")
expected3 = df.groupby("id").agg(sum(df.v + 1) + 2).sort("id")
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
assert_frame_equal(expected3.toPandas(), result3.toPandas())

    def test_mixed_udfs(self):
"""
Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
"""
df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
sum_udf = self.pandas_agg_sum_udf
# Mix group aggregate pandas UDF and python UDF
result1 = df.groupby("id").agg(plus_one(sum_udf(df.v))).sort("id")
expected1 = df.groupby("id").agg(plus_one(sum(df.v))).sort("id")
# Mix group aggregate pandas UDF and python UDF (order swapped)
result2 = df.groupby("id").agg(sum_udf(plus_one(df.v))).sort("id")
expected2 = df.groupby("id").agg(sum(plus_one(df.v))).sort("id")
# Mix group aggregate pandas UDF and scalar pandas UDF
result3 = df.groupby("id").agg(sum_udf(plus_two(df.v))).sort("id")
expected3 = df.groupby("id").agg(sum(plus_two(df.v))).sort("id")
# Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
result4 = df.groupby("id").agg(plus_two(sum_udf(df.v))).sort("id")
expected4 = df.groupby("id").agg(plus_two(sum(df.v))).sort("id")
# Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
result5 = (
df.groupby(plus_one(df.id)).agg(plus_one(sum_udf(plus_one(df.v)))).sort("plus_one(id)")
)
expected5 = (
df.groupby(plus_one(df.id)).agg(plus_one(sum(plus_one(df.v)))).sort("plus_one(id)")
)
        # Wrap group aggregate pandas UDF with two scalar pandas UDFs and use scalar pandas UDF
        # in groupby
result6 = (
df.groupby(plus_two(df.id)).agg(plus_two(sum_udf(plus_two(df.v)))).sort("plus_two(id)")
)
expected6 = (
df.groupby(plus_two(df.id)).agg(plus_two(sum(plus_two(df.v)))).sort("plus_two(id)")
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
assert_frame_equal(expected3.toPandas(), result3.toPandas())
assert_frame_equal(expected4.toPandas(), result4.toPandas())
assert_frame_equal(expected5.toPandas(), result5.toPandas())
assert_frame_equal(expected6.toPandas(), result6.toPandas())

    def test_multiple_udfs(self):
"""
Test multiple group aggregate pandas UDFs in one agg function.
"""
df = self.data
mean_udf = self.pandas_agg_mean_udf
sum_udf = self.pandas_agg_sum_udf
weighted_mean_udf = self.pandas_agg_weighted_mean_udf
result1 = (
df.groupBy("id")
.agg(mean_udf(df.v), sum_udf(df.v), weighted_mean_udf(df.v, df.w))
.sort("id")
.toPandas()
)
expected1 = (
df.groupBy("id")
.agg(mean(df.v), sum(df.v), mean(df.v).alias("weighted_mean(v, w)"))
.sort("id")
.toPandas()
)
assert_frame_equal(expected1, result1)
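
    # Grouping keys may be sql expressions, empty, or outputs of python /
    # scalar pandas UDFs, in any combination.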
def test_complex_groupby(self):
df = self.data
sum_udf = self.pandas_agg_sum_udf
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
# groupby one expression
result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
expected1 = df.groupby(df.v % 2).agg(sum(df.v))
# empty groupby
result2 = df.groupby().agg(sum_udf(df.v))
expected2 = df.groupby().agg(sum(df.v))
# groupby one column and one sql expression
result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)
# groupby one python UDF
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)).sort("plus_one(id)")
expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)).sort("plus_one(id)")
# groupby one scalar pandas UDF
result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)).sort("sum(v)")
expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)).sort("sum(v)")
# groupby one expression and one python UDF
result6 = (
df.groupby(df.v % 2, plus_one(df.id))
.agg(sum_udf(df.v))
.sort(["(v % 2)", "plus_one(id)"])
)
expected6 = (
df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)).sort(["(v % 2)", "plus_one(id)"])
)
# groupby one expression and one scalar pandas UDF
result7 = (
df.groupby(df.v % 2, plus_two(df.id))
.agg(sum_udf(df.v))
.sort(["sum(v)", "plus_two(id)"])
)
expected7 = (
df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort(["sum(v)", "plus_two(id)"])
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
assert_frame_equal(expected3.toPandas(), result3.toPandas())
assert_frame_equal(expected4.toPandas(), result4.toPandas())
assert_frame_equal(expected5.toPandas(), result5.toPandas())
assert_frame_equal(expected6.toPandas(), result6.toPandas())
assert_frame_equal(expected7.toPandas(), result7.toPandas())

    def test_complex_expressions(self):
df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
sum_udf = self.pandas_agg_sum_udf
# Test complex expressions with sql expression, python UDF and
# group aggregate pandas UDF
result1 = (
df.withColumn("v1", plus_one(df.v))
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum_udf(col("v")),
sum_udf(col("v1") + 3),
sum_udf(col("v2")) + 5,
plus_one(sum_udf(col("v1"))),
sum_udf(plus_one(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
.sort_values(by=["id", "(v % 2)"])
)
expected1 = (
df.withColumn("v1", df.v + 1)
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum(col("v")),
sum(col("v1") + 3),
sum(col("v2")) + 5,
plus_one(sum(col("v1"))),
sum(plus_one(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
.sort_values(by=["id", "(v % 2)"])
)
        # Test complex expressions with sql expression, scalar pandas UDF and
# group aggregate pandas UDF
result2 = (
df.withColumn("v1", plus_one(df.v))
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum_udf(col("v")),
sum_udf(col("v1") + 3),
sum_udf(col("v2")) + 5,
plus_two(sum_udf(col("v1"))),
sum_udf(plus_two(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
.sort_values(by=["id", "(v % 2)"])
)
expected2 = (
df.withColumn("v1", df.v + 1)
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
sum(col("v")),
sum(col("v1") + 3),
sum(col("v2")) + 5,
plus_two(sum(col("v1"))),
sum(plus_two(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
.sort_values(by=["id", "(v % 2)"])
)
# Test sequential groupby aggregate
result3 = (
df.groupby("id")
.agg(sum_udf(df.v).alias("v"))
.groupby("id")
.agg(sum_udf(col("v")))
.sort("id")
.toPandas()
)
expected3 = (
df.groupby("id")
.agg(sum(df.v).alias("v"))
.groupby("id")
.agg(sum(col("v")))
.sort("id")
.toPandas()
)
assert_frame_equal(expected1, result1)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected3, result3)
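
    # Grouped agg pandas UDFs should respect spark.sql.retainGroupColumns=false
    # just like built-in aggregates.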
def test_retain_group_columns(self):
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf
result1 = df.groupby(df.id).agg(sum_udf(df.v))
expected1 = df.groupby(df.id).agg(sum(df.v))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
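
    # A grouped agg pandas UDF may declare an array return type.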
def test_array_type(self):
df = self.data
array_udf = pandas_udf(lambda x: [1.0, 2.0], "array<double>", PandasUDFType.GROUPED_AGG)
result1 = df.groupby("id").agg(array_udf(df["v"]).alias("v2"))
self.assertEqual(result1.first()["v2"], [1.0, 2.0])
def test_invalid_args(self):
with self.quiet():
self.check_invalid_args()

    def check_invalid_args(self):
df = self.data
plus_one = self.python_plus_one
mean_udf = self.pandas_agg_mean_udf
with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"):
df.groupby(df.id).agg(plus_one(df.v)).collect()
with self.assertRaisesRegex(
AnalysisException, "aggregate function.*argument.*aggregate function"
):
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
with self.assertRaisesRegex(
AnalysisException,
"The group aggregate pandas UDF `avg` cannot be invoked together with as other, "
"non-pandas aggregate functions.",
):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
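
    # Registration via spark.udf.register keeps the GROUPED_AGG eval type and
    # makes the UDF callable from SQL with GROUP BY.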
def test_register_vectorized_udf_basic(self):
sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
)
self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
expected = [1, 5]
self.assertEqual(actual, expected)
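
    # Repartitioning into more partitions than rows guarantees empty
    # partitions, which the grouped aggregation must tolerate.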
def test_grouped_with_empty_partition(self):
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
num_parts = len(data) + 1
df = self.spark.createDataFrame(data).repartition(num_parts)
f = pandas_udf(lambda x: x.sum(), "int", PandasUDFType.GROUPED_AGG)
result = df.groupBy("id").agg(f(df["x"]).alias("sum")).sort("id").collect()
self.assertEqual(result, expected)
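
    # Without a GROUP BY clause the grouped agg UDF acts as a global
    # aggregate over the whole DataFrame.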
def test_grouped_without_group_by_clause(self):
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def max_udf(v):
return v.max()
df = self.spark.range(0, 100)
self.spark.udf.register("max_udf", max_udf)
with self.tempView("table"):
df.createTempView("table")
agg1 = df.agg(max_udf(df["id"]))
agg2 = self.spark.sql("select max_udf(id) from table")
assert_frame_equal(agg1.toPandas(), agg2.toPandas())

    def test_no_predicate_pushdown_through(self):
# SPARK-30921: We should not pushdown predicates of PythonUDFs through Aggregate.
import numpy as np
@pandas_udf("float", PandasUDFType.GROUPED_AGG)
def mean(x):
return np.mean(x)
df = self.spark.createDataFrame([Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)])
agg = df.groupBy("id").agg(mean("foo").alias("mean"))
filtered = agg.filter(agg["mean"] > 40.0)
assert filtered.collect()[0]["mean"] == 42.0
def test_named_arguments(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

    def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
):
self.spark.sql(
"SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
).show()
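
    # A UDF declared with **kwargs receives all of its arguments by name.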
def test_kwargs(self):
df = self.data
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np
return np.average(kwargs["v"], weights=kwargs["w"])
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)
for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
# negative
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()
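
    # A parameter with a default value may be omitted, or supplied
    # positionally or by name.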
def test_named_arguments_and_defaults(self):
df = self.data
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def biased_sum(v, w=None):
return v.sum() + (w.sum() if w is not None else 100)
with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("biased_sum", biased_sum)
# without "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=False, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
)
# with "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=True, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
)
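
    # Results of various numpy numeric dtypes should be cast to the declared
    # decimal(10,0) return type.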
def test_arrow_cast_enabled_numeric_to_decimal(self):
import numpy as np
from decimal import Decimal
columns = [
"int8",
"int16",
"int32",
"uint8",
"uint16",
"uint32",
"float64",
]
pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
df = self.spark.range(2).repartition(1)
for column in columns:
with self.subTest(column=column):
@pandas_udf("decimal(10,0)", PandasUDFType.GROUPED_AGG)
def test(series):
return pdf[column].iloc[0]
row = df.groupby("id").agg(test(df.id)).first()
res = row[1]
self.assertEqual(res, Decimal("1"))
def test_arrow_cast_enabled_str_to_numeric(self):
df = self.spark.range(2).repartition(1)
types = ["int", "long", "float", "double"]
for type_str in types:
with self.subTest(type=type_str):
@pandas_udf(type_str, PandasUDFType.GROUPED_AGG)
def test(series):
                return "123"
row = df.groupby("id").agg(test(df.id)).first()
self.assertEqual(row[1], 123)


class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)